From 38badbb8b5fa1e5c34cce3a574d631a6f688a08b Mon Sep 17 00:00:00 2001 From: Gabriel Moise Date: Fri, 8 Sep 2023 16:24:31 +0100 Subject: [PATCH] Proto/Sapling: Replace legacy bindings with let constructs --- .../lib_protocol/sapling_storage.ml | 228 ++++++++++-------- .../lib_protocol/sapling_validator.ml | 125 +++++----- 2 files changed, 199 insertions(+), 154 deletions(-) diff --git a/src/proto_alpha/lib_protocol/sapling_storage.ml b/src/proto_alpha/lib_protocol/sapling_storage.ml index 1deae40db03d..b2ef8eb25a91 100644 --- a/src/proto_alpha/lib_protocol/sapling_storage.ml +++ b/src/proto_alpha/lib_protocol/sapling_storage.ml @@ -118,13 +118,15 @@ module Commitments : COMMITMENTS = struct let init = Storage.Sapling.commitments_init let get_root_height ctx id node height = + let open Lwt_result_syntax in assert_node node height ; assert_height height ; - Storage.Sapling.Commitments.find (ctx, id) node >|=? function - | ctx, None -> + let+ ctx, cm_opt = Storage.Sapling.Commitments.find (ctx, id) node in + match cm_opt with + | None -> let hash = H.uncommitted ~height in (ctx, hash) - | ctx, Some hash -> (ctx, hash) + | Some hash -> (ctx, hash) let left node = Int64.mul node 2L @@ -150,50 +152,58 @@ module Commitments : COMMITMENTS = struct Post: incremental tree /\ to_list (insert tree height pos cms) = to_list t @ cms *) let rec insert ctx id node height pos cms = + let open Lwt_result_syntax in assert_node node height ; assert_height height ; assert_pos pos height ; match (height, cms) with | _, [] -> - get_root_height ctx id node height >|=? fun (ctx, h) -> (ctx, 0, h) + let+ ctx, h = get_root_height ctx id node height in + (ctx, 0, h) | 0, [cm] -> let h = H.of_commitment cm in - Storage.Sapling.Commitments.init (ctx, id) node h - >|=? fun (ctx, size) -> (ctx, size, h) + let+ ctx, size = Storage.Sapling.Commitments.init (ctx, id) node h in + (ctx, size, h) | _ -> let height = height - 1 in - (if Compare.Int64.(pos < pow2 height) then - let at = Int64.(sub (pow2 height) pos) in - let cml, cmr = split_at at cms in - insert ctx id (left node) height pos cml >>=? fun (ctx, size_l, hl) -> - insert ctx id (right node) height 0L cmr >|=? fun (ctx, size_r, hr) -> - (ctx, size_l + size_r, hl, hr) - else - get_root_height ctx id (left node) height >>=? fun (ctx, hl) -> - let pos = Int64.(sub pos (pow2 height)) in - insert ctx id (right node) height pos cms - >|=? fun (ctx, size_r, hr) -> (ctx, size_r, hl, hr)) - >>=? fun (ctx, size_children, hl, hr) -> + let* ctx, size_children, hl, hr = + if Compare.Int64.(pos < pow2 height) then + let at = Int64.(sub (pow2 height) pos) in + let cml, cmr = split_at at cms in + let* ctx, size_l, hl = insert ctx id (left node) height pos cml in + let+ ctx, size_r, hr = insert ctx id (right node) height 0L cmr in + (ctx, size_l + size_r, hl, hr) + else + let* ctx, hl = get_root_height ctx id (left node) height in + let pos = Int64.(sub pos (pow2 height)) in + let+ ctx, size_r, hr = insert ctx id (right node) height pos cms in + (ctx, size_r, hl, hr) + in let h = H.merkle_hash ~height hl hr in - Storage.Sapling.Commitments.add (ctx, id) node h - >|=? fun (ctx, size, _existing) -> (ctx, size + size_children, h) + let+ ctx, size, _existing = + Storage.Sapling.Commitments.add (ctx, id) node h + in + (ctx, size + size_children, h) let rec fold_from_height ctx id node ~pos ~f ~acc height = + let open Lwt_result_syntax in assert_node node height ; assert_height height ; assert_pos pos height ; - Storage.Sapling.Commitments.find (ctx, id) node - (* we don't count gas for this function, it is called only by RPC *) - >>=? - function - | _ctx, None -> return acc - | _ctx, Some h -> + let* _ctx, cm_opt = + Storage.Sapling.Commitments.find (ctx, id) node + (* we don't count gas for this function, it is called only by RPC *) + in + match cm_opt with + | None -> return acc + | Some h -> if Compare.Int.(height = 0) then return (f acc h) else let full = pow2 (height - 1) in if Compare.Int64.(pos < full) then - fold_from_height ctx id (left node) ~pos ~f ~acc (height - 1) - >>=? fun acc -> + let* acc = + fold_from_height ctx id (left node) ~pos ~f ~acc (height - 1) + in (* Setting pos to 0 folds on the whole right subtree *) fold_from_height ctx id (right node) ~pos:0L ~f ~acc (height - 1) else @@ -212,23 +222,27 @@ module Commitments : COMMITMENTS = struct list of commitments. The use of [split_at] has O(n logn) complexity that is less relevant on a smaller list. *) let add ctx id cms pos = + let open Lwt_result_syntax in let l = List.length cms in assert (Compare.Int.(l <= 1000)) ; let n' = Int64.(add pos (of_int l)) in assert (Compare.Int64.(n' <= max_size)) ; - insert ctx id root_node max_height pos cms >|=? fun (ctx, size, _h) -> + let+ ctx, size, _h = insert ctx id root_node max_height pos cms in (ctx, size) let get_from ctx id pos = - fold_from_height - ctx - id - root_node - ~pos - ~f:(fun acc c -> H.to_commitment c :: acc) - ~acc:[] - max_height - >|=? fun l -> List.rev l + let open Lwt_result_syntax in + let+ l = + fold_from_height + ctx + id + root_node + ~pos + ~f:(fun acc c -> H.to_commitment c :: acc) + ~acc:[] + max_height + in + List.rev l end module Ciphertexts = struct @@ -239,8 +253,9 @@ module Ciphertexts = struct let add ctx id c pos = Storage.Sapling.Ciphertexts.init (ctx, id) pos c let get_from ctx id offset = + let open Lwt_result_syntax in let rec aux (ctx, acc) pos = - Storage.Sapling.Ciphertexts.find (ctx, id) pos >>=? fun (ctx, c) -> + let* ctx, c = Storage.Sapling.Ciphertexts.find (ctx, id) pos in match c with | None -> return (ctx, List.rev acc) | Some c -> aux (ctx, c :: acc) (Int64.succ pos) @@ -261,22 +276,27 @@ module Nullifiers = struct (* Allows for duplicates as they are already checked by verify_update before updating the state. *) let add ctx id nfs = - size ctx id >>=? fun nf_start_pos -> - List.fold_left_es - (fun (ctx, pos, acc_size) nf -> - Storage.Sapling.Nullifiers_hashed.init (ctx, id) nf - >>=? fun (ctx, size) -> - Storage.Sapling.Nullifiers_ordered.init (ctx, id) pos nf >|=? fun ctx -> - (ctx, Int64.succ pos, Z.add acc_size (Z.of_int size))) - (ctx, nf_start_pos, Z.zero) - (List.rev nfs) - >>=? fun (ctx, nf_end_pos, size) -> - Storage.Sapling.Nullifiers_size.update (ctx, id) nf_end_pos >|=? fun ctx -> + let open Lwt_result_syntax in + let* nf_start_pos = size ctx id in + let* ctx, nf_end_pos, size = + List.fold_left_es + (fun (ctx, pos, acc_size) nf -> + let* ctx, size = + Storage.Sapling.Nullifiers_hashed.init (ctx, id) nf + in + let+ ctx = Storage.Sapling.Nullifiers_ordered.init (ctx, id) pos nf in + (ctx, Int64.succ pos, Z.add acc_size (Z.of_int size))) + (ctx, nf_start_pos, Z.zero) + (List.rev nfs) + in + let+ ctx = Storage.Sapling.Nullifiers_size.update (ctx, id) nf_end_pos in (ctx, size) let get_from ctx id offset = + let open Lwt_result_syntax in let rec aux acc pos = - Storage.Sapling.Nullifiers_ordered.find (ctx, id) pos >>=? function + let* nf_opt = Storage.Sapling.Nullifiers_ordered.find (ctx, id) pos in + match nf_opt with | None -> return @@ List.rev acc | Some c -> aux (c :: acc) (Int64.succ pos) in @@ -298,49 +318,57 @@ module Roots = struct (* pos is the index of the last inserted element *) let get ctx id = - Storage.Sapling.Roots_pos.get (ctx, id) >>=? fun pos -> + let open Lwt_result_syntax in + let* pos = Storage.Sapling.Roots_pos.get (ctx, id) in Storage.Sapling.Roots.get (ctx, id) pos let init ctx id = + let open Lwt_result_syntax in let rec aux ctx pos = if Compare.Int32.(pos < 0l) then return ctx else - Storage.Sapling.Roots.init (ctx, id) pos Commitments.default_root - >>=? fun ctx -> aux ctx (Int32.pred pos) + let* ctx = + Storage.Sapling.Roots.init (ctx, id) pos Commitments.default_root + in + aux ctx (Int32.pred pos) in - aux ctx (Int32.pred size) >>=? fun ctx -> - Storage.Sapling.Roots_pos.init (ctx, id) 0l >>=? fun ctx -> + let* ctx = aux ctx (Int32.pred size) in + let* ctx = Storage.Sapling.Roots_pos.init (ctx, id) 0l in let level = (Raw_context.current_level ctx).level in Storage.Sapling.Roots_level.init (ctx, id) level let mem ctx id root = - Storage.Sapling.Roots_pos.get (ctx, id) >>=? fun start_pos -> + let open Lwt_result_syntax in + let* start_pos = Storage.Sapling.Roots_pos.get (ctx, id) in let rec aux pos = - Storage.Sapling.Roots.get (ctx, id) pos >>=? fun hash -> - if Compare.Int.(Sapling.Hash.compare hash root = 0) then return true + let* hash = Storage.Sapling.Roots.get (ctx, id) pos in + if Compare.Int.(Sapling.Hash.compare hash root = 0) then return_true else let pos = Int32.(pred pos) in let pos = if Compare.Int32.(pos < 0l) then Int32.pred size else pos in - if Compare.Int32.(pos = start_pos) then return false else aux pos + if Compare.Int32.(pos = start_pos) then return_false else aux pos in aux start_pos (* allows duplicates *) let add ctx id root = - Storage.Sapling.Roots_pos.get (ctx, id) >>=? fun pos -> + let open Lwt_result_syntax in + let* pos = Storage.Sapling.Roots_pos.get (ctx, id) in let level = (Raw_context.current_level ctx).level in - Storage.Sapling.Roots_level.get (ctx, id) >>=? fun stored_level -> + let* stored_level = Storage.Sapling.Roots_level.get (ctx, id) in if Raw_level_repr.(stored_level = level) then (* if there is another add during the same level, it will over-write on the same position *) - Storage.Sapling.Roots.add (ctx, id) pos root >|= ok + let*! ctx = Storage.Sapling.Roots.add (ctx, id) pos root in + return ctx else (* it's the first add for this level *) (* TODO(samoht): why is it using [update] and not [init] then? *) - Storage.Sapling.Roots_level.update (ctx, id) level >>=? fun ctx -> + let* ctx = Storage.Sapling.Roots_level.update (ctx, id) level in let pos = Int32.rem (Int32.succ pos) size in - Storage.Sapling.Roots_pos.update (ctx, id) pos >>=? fun ctx -> - Storage.Sapling.Roots.add (ctx, id) pos root >|= ok + let* ctx = Storage.Sapling.Roots_pos.update (ctx, id) pos in + let*! ctx = Storage.Sapling.Roots.add (ctx, id) pos root in + return ctx end (** This type links the permanent state stored in the context at the specified @@ -361,7 +389,8 @@ let empty_state ?id ~memo_size () = {id; diff = empty_diff; memo_size} (** Returns a state from an existing id. *) let state_from_id ctxt id = - Storage.Sapling.Memo_size.get (ctxt, id) >|=? fun memo_size -> + let open Lwt_result_syntax in + let+ memo_size = Storage.Sapling.Memo_size.get (ctxt, id) in ({id = Some id; diff = empty_diff; memo_size}, ctxt) let rpc_arg = Storage.Sapling.rpc_arg @@ -369,46 +398,53 @@ let rpc_arg = Storage.Sapling.rpc_arg let get_memo_size ctx id = Storage.Sapling.Memo_size.get (ctx, id) let init ctx id ~memo_size = - Storage.Sapling.Memo_size.add (ctx, id) memo_size >>= fun ctx -> - Storage.Sapling.Commitments_size.add (ctx, id) Int64.zero >>= fun ctx -> - Commitments.init ctx id >>= fun ctx -> - Nullifiers.init ctx id >>= fun ctx -> - Roots.init ctx id >>=? fun ctx -> Ciphertexts.init ctx id >|= ok + let open Lwt_result_syntax in + let*! ctx = Storage.Sapling.Memo_size.add (ctx, id) memo_size in + let*! ctx = Storage.Sapling.Commitments_size.add (ctx, id) Int64.zero in + let*! ctx = Commitments.init ctx id in + let*! ctx = Nullifiers.init ctx id in + let* ctx = Roots.init ctx id in + let*! ctx = Ciphertexts.init ctx id in + return ctx (** Applies a diff to a state id stored in the context. Updates Commitments, Ciphertexts and Nullifiers using the diff and updates the Roots using the new Commitments tree. *) let apply_diff ctx id diff = + let open Lwt_result_syntax in let open Sapling_repr in let nb_commitments = List.length diff.commitments_and_ciphertexts in let nb_nullifiers = List.length diff.nullifiers in let sapling_cost = Sapling_storage_costs.cost_SAPLING_APPLY_DIFF nb_nullifiers nb_commitments in - Raw_context.consume_gas ctx sapling_cost >>?= fun ctx -> - Storage.Sapling.Commitments_size.get (ctx, id) >>=? fun cm_start_pos -> + let*? ctx = Raw_context.consume_gas ctx sapling_cost in + let* cm_start_pos = Storage.Sapling.Commitments_size.get (ctx, id) in let cms = List.rev_map fst diff.commitments_and_ciphertexts in - Commitments.add ctx id cms cm_start_pos >>=? fun (ctx, size) -> - Storage.Sapling.Commitments_size.update - (ctx, id) - (Int64.add cm_start_pos (Int64.of_int nb_commitments)) - >>=? fun ctx -> - List.fold_left_es - (fun (ctx, pos, acc_size) (_cm, cp) -> - Ciphertexts.add ctx id cp pos >|=? fun (ctx, size) -> - (ctx, Int64.succ pos, Z.add acc_size (Z.of_int size))) - (ctx, cm_start_pos, Z.of_int size) - (List.rev diff.commitments_and_ciphertexts) - >>=? fun (ctx, _ct_end_pos, size) -> - Nullifiers.add ctx id diff.nullifiers >>=? fun (ctx, size_nf) -> + let* ctx, size = Commitments.add ctx id cms cm_start_pos in + let* ctx = + Storage.Sapling.Commitments_size.update + (ctx, id) + (Int64.add cm_start_pos (Int64.of_int nb_commitments)) + in + let* ctx, _ct_end_pos, size = + List.fold_left_es + (fun (ctx, pos, acc_size) (_cm, cp) -> + let+ ctx, size = Ciphertexts.add ctx id cp pos in + (ctx, Int64.succ pos, Z.add acc_size (Z.of_int size))) + (ctx, cm_start_pos, Z.of_int size) + (List.rev diff.commitments_and_ciphertexts) + in + let* ctx, size_nf = Nullifiers.add ctx id diff.nullifiers in let size = Z.add size size_nf in match diff.commitments_and_ciphertexts with | [] -> (* avoids adding duplicates to Roots *) return (ctx, size) | _ :: _ -> - Commitments.get_root ctx id >>=? fun (ctx, root) -> - Roots.add ctx id root >|=? fun ctx -> (ctx, size) + let* ctx, root = Commitments.get_root ctx id in + let+ ctx = Roots.add ctx id root in + (ctx, size) let add {id; diff; memo_size} cm_cipher_list = assert ( @@ -458,18 +494,20 @@ type root = Sapling.Hash.t let root_encoding = Sapling.Hash.encoding let get_diff ctx id ?(offset_commitment = 0L) ?(offset_nullifier = 0L) () = + let open Lwt_result_syntax in if not Sapling.Commitment.( valid_position offset_commitment && valid_position offset_nullifier) then failwith "Invalid argument." else - Commitments.get_from ctx id offset_commitment >>=? fun commitments -> - Roots.get ctx id >>=? fun root -> - Nullifiers.get_from ctx id offset_nullifier >>=? fun nullifiers -> - Ciphertexts.get_from ctx id offset_commitment - (* we don't count gas for RPCs *) - >|=? fun (_ctx, ciphertexts) -> + let* commitments = Commitments.get_from ctx id offset_commitment in + let* root = Roots.get ctx id in + let* nullifiers = Nullifiers.get_from ctx id offset_nullifier in + let+ _ctx, ciphertexts = + Ciphertexts.get_from ctx id offset_commitment + (* we don't count gas for RPCs *) + in match List.combine ~when_different_lengths:() commitments ciphertexts with | Error () -> failwith "Invalid argument." | Ok commitments_and_ciphertexts -> diff --git a/src/proto_alpha/lib_protocol/sapling_validator.ml b/src/proto_alpha/lib_protocol/sapling_validator.ml index a9784cae9421..97f1a3460fc0 100644 --- a/src/proto_alpha/lib_protocol/sapling_validator.ml +++ b/src/proto_alpha/lib_protocol/sapling_validator.ml @@ -26,17 +26,19 @@ (* Check that each nullifier is not already present in the state and add it. Important to avoid spending the same input twice in a transaction. *) let rec check_and_update_nullifiers ctxt state inputs = + let open Lwt_result_syntax in match inputs with | [] -> return (ctxt, Some state) - | input :: inputs -> ( - Sapling_storage.nullifiers_mem ctxt state Sapling.UTXO.(input.nf) - >>=? function - | ctxt, true -> return (ctxt, None) - | ctxt, false -> - let state = - Sapling_storage.nullifiers_add state Sapling.UTXO.(input.nf) - in - check_and_update_nullifiers ctxt state inputs) + | input :: inputs -> + let* ctxt, nullifier_in_state = + Sapling_storage.nullifiers_mem ctxt state Sapling.UTXO.(input.nf) + in + if nullifier_in_state then return (ctxt, None) + else + let state = + Sapling_storage.nullifiers_add state Sapling.UTXO.(input.nf) + in + check_and_update_nullifiers ctxt state inputs let verify_update : Raw_context.t -> @@ -44,65 +46,70 @@ let verify_update : Sapling_repr.transaction -> string -> (Raw_context.t * (Int64.t * Sapling_storage.state) option) tzresult Lwt.t = - fun ctxt state transaction key -> - (* Check the transaction *) - (* To avoid overflowing the balance, the number of inputs and outputs must be - bounded. - Ciphertexts' memo_size must match the state's memo_size. - These constraints are already enforced at the encoding level. *) - assert (Compare.Int.(List.compare_length_with transaction.inputs 5208 <= 0)) ; - assert (Compare.Int.(List.compare_length_with transaction.outputs 2019 <= 0)) ; - let pass = - List.for_all - (fun output -> - Compare.Int.( - Sapling.Ciphertext.get_memo_size Sapling.UTXO.(output.ciphertext) - = state.memo_size)) - transaction.outputs - in - if not pass then return (ctxt, None) - else - (* Check the root is a recent state *) - Sapling_storage.root_mem ctxt state transaction.root >>=? fun pass -> + let open Lwt_result_syntax in + fun ctxt state transaction key -> + (* Check the transaction *) + (* To avoid overflowing the balance, the number of inputs and outputs must be + bounded. + Ciphertexts' memo_size must match the state's memo_size. + These constraints are already enforced at the encoding level. *) + assert (Compare.Int.(List.compare_length_with transaction.inputs 5208 <= 0)) ; + assert (Compare.Int.(List.compare_length_with transaction.outputs 2019 <= 0)) ; + let pass = + List.for_all + (fun output -> + Compare.Int.( + Sapling.Ciphertext.get_memo_size Sapling.UTXO.(output.ciphertext) + = state.memo_size)) + transaction.outputs + in if not pass then return (ctxt, None) else - check_and_update_nullifiers ctxt state transaction.inputs >|=? function - | ctxt, None -> (ctxt, None) - | ctxt, Some state -> - Sapling.Verification.with_verification_ctx (fun vctx -> - let pass = - (* Check all the output ZK proofs *) - List.for_all - (fun output -> Sapling.Verification.check_output vctx output) - transaction.outputs - in - if not pass then (ctxt, None) - else + (* Check the root is a recent state *) + let* pass = Sapling_storage.root_mem ctxt state transaction.root in + if not pass then return (ctxt, None) + else + let+ ctxt, state_opt = + check_and_update_nullifiers ctxt state transaction.inputs + in + match state_opt with + | None -> (ctxt, None) + | Some state -> + Sapling.Verification.with_verification_ctx (fun vctx -> let pass = - (* Check all the input Zk proofs and signatures *) + (* Check all the output ZK proofs *) List.for_all - (fun input -> - Sapling.Verification.check_spend - vctx - input - transaction.root - key) - transaction.inputs + (fun output -> + Sapling.Verification.check_output vctx output) + transaction.outputs in if not pass then (ctxt, None) else let pass = - (* Check the signature and balance of the whole transaction *) - Sapling.Verification.final_check vctx transaction key + (* Check all the input Zk proofs and signatures *) + List.for_all + (fun input -> + Sapling.Verification.check_spend + vctx + input + transaction.root + key) + transaction.inputs in if not pass then (ctxt, None) else - (* update tree *) - let list_to_add = - List.map - (fun output -> - Sapling.UTXO.(output.cm, output.ciphertext)) - transaction.outputs + let pass = + (* Check the signature and balance of the whole transaction *) + Sapling.Verification.final_check vctx transaction key in - let state = Sapling_storage.add state list_to_add in - (ctxt, Some (transaction.balance, state))) + if not pass then (ctxt, None) + else + (* update tree *) + let list_to_add = + List.map + (fun output -> + Sapling.UTXO.(output.cm, output.ciphertext)) + transaction.outputs + in + let state = Sapling_storage.add state list_to_add in + (ctxt, Some (transaction.balance, state))) -- GitLab