From 749e7fe47ead44e332c6d1a499e09d0498d66a77 Mon Sep 17 00:00:00 2001 From: Thomas Letan Date: Thu, 28 Jul 2022 15:55:47 +0200 Subject: [PATCH] SCORU,WASM: Provide combinators to deal with absent keys --- src/lib_scoru_wasm/test/test_encoding.ml | 121 ++++++++++++++---- src/lib_scoru_wasm/tree_decoding.ml | 52 ++++++-- src/lib_scoru_wasm/tree_decoding.mli | 37 ++++-- src/lib_scoru_wasm/tree_encoding.ml | 7 + src/lib_scoru_wasm/tree_encoding.mli | 5 + src/lib_scoru_wasm/tree_encoding_decoding.ml | 20 ++- src/lib_scoru_wasm/tree_encoding_decoding.mli | 20 ++- 7 files changed, 202 insertions(+), 60 deletions(-) diff --git a/src/lib_scoru_wasm/test/test_encoding.ml b/src/lib_scoru_wasm/test/test_encoding.ml index c0ec80abf954..363f4aa6477e 100644 --- a/src/lib_scoru_wasm/test/test_encoding.ml +++ b/src/lib_scoru_wasm/test/test_encoding.ml @@ -27,7 +27,7 @@ ------- Component: Tree_encoding_decoding Invocation: dune exec src/lib_scoru_wasm/test/test_scoru_wasm.exe \ - -- test "$Encodings^" + -- test "^Encodings$" Subject: Encoding tests for the tezos-scoru-wasm library *) @@ -84,8 +84,31 @@ let test_encode_decode enc value f = let*! value' = Merklizer.decode enc tree in f value' +let test_decode_encode_decode tree enc f = + let open Lwt_syntax in + let* value = Merklizer.decode enc tree in + let* tree = Merklizer.encode enc value tree in + let* value' = Merklizer.decode enc tree in + f value value' + let encode_decode enc value = test_encode_decode enc value Lwt.return +let decode_encode_decode tree enc = + test_decode_encode_decode tree enc (fun x y -> Lwt.return (x, y)) + +let assert_value tree enc v = + let open Lwt_result_syntax in + let*! v' = Merklizer.decode enc tree in + assert (v = v') ; + return_unit + +let assert_missing_value tree key = + let open Lwt_result_syntax in + let*! 
candidate = Tree.find tree key in + match candidate with + | None -> return_unit + | Some _ -> failwith "value should be missing" + let assert_round_trip enc value equal = let open Lwt_syntax in let* value' = encode_decode enc value in @@ -93,6 +116,12 @@ let assert_round_trip enc value equal = assert (equal value' value) ; return_unit +let assert_decode_round_trip tree enc equal = + let open Lwt_result_syntax in + let*! value, value' = decode_encode_decode tree enc in + assert (equal value' value) ; + return_unit + let test_string () = let enc = Merklizer.value ["key"] Data_encoding.string in assert_round_trip enc "Hello" String.equal @@ -123,34 +152,36 @@ type contact = | Address of {street : string; number : int} | No_address -let test_tagged_union () = +let contact_enc ?default () = let open Merklizer in + tagged_union + ?default + (value [] Data_encoding.string) + [ + case + "Email" + (value [] Data_encoding.string) + (function Email s -> Some s | _ -> None) + (fun s -> Email s); + case + "Address" + (tup2 + ~flatten:false + (value ["street"] Data_encoding.string) + (value ["number"] Data_encoding.int31)) + (function + | Address {street; number} -> Some (street, number) | _ -> None) + (fun (street, number) -> Address {street; number}); + case + "No Address" + (value [] Data_encoding.unit) + (function No_address -> Some () | _ -> None) + (fun () -> No_address); + ] + +let test_tagged_union () = let open Lwt_result_syntax in - let enc = - tagged_union - (value [] Data_encoding.string) - [ - case - "Email" - (value [] Data_encoding.string) - (function Email s -> Some s | _ -> None) - (fun s -> Email s); - case - "Address" - (tup2 - ~flatten:false - (value ["street"] Data_encoding.string) - (value ["number"] Data_encoding.int31)) - (function - | Address {street; number} -> Some (street, number) | _ -> None) - (fun (street, number) -> Address {street; number}); - case - "No Address" - (value [] Data_encoding.unit) - (function No_address -> Some () | _ -> None) - 
(fun () -> No_address); - ] - in + let enc = contact_enc () in let* () = assert_round_trip enc No_address Stdlib.( = ) in let* () = assert_round_trip enc (Email "foo@bar.com") Stdlib.( = ) in let* () = @@ -161,6 +192,20 @@ let test_tagged_union () = in return_unit +let test_tagged_union_default () = + let open Lwt_result_syntax in + let enc = contact_enc ~default:No_address () in + let*! empty_tree = empty_tree () in + let* () = assert_value empty_tree enc No_address in + let* () = assert_round_trip enc No_address Stdlib.( = ) in + let* () = + assert_round_trip + enc + (Address {street = "Main Street"; number = 10}) + Stdlib.( = ) + in + return_unit + let test_lazy_mapping () = let open Merklizer in let open Lwt_result_syntax in @@ -315,6 +360,25 @@ let test_value_option () = let* () = assert_round_trip enc None Stdlib.( = ) in return_unit +let test_value_default () = + let open Merklizer in + let open Lwt_result_syntax in + let*! tree = empty_tree () in + let enc = value ~default:42 [] Data_encoding.int31 in + assert_value tree enc 42 + +let test_optional () = + let open Merklizer in + let open Lwt_result_syntax in + let key = [] in + let enc = optional key Data_encoding.int31 in + let*! tree = empty_tree () in + let*! tree = Merklizer.encode enc (Some 0) tree in + let* () = assert_value tree enc (Some 0) in + let*! 
tree = Merklizer.encode enc None tree in + let* () = assert_missing_value tree key in + return_unit + type cyclic = {name : string; self : unit -> cyclic} let test_with_self_ref () = @@ -346,6 +410,7 @@ let tests = tztest "Raw" `Quick test_raw; tztest "Convert" `Quick test_conv; tztest "Tagged-union" `Quick test_tagged_union; + tztest "Tagged-union ~default" `Quick test_tagged_union_default; tztest "Lazy mapping" `Quick test_lazy_mapping; tztest "Add element to decoded empty map" @@ -356,5 +421,7 @@ let tests = tztest "Tuples" `Quick test_tuples; tztest "Option" `Quick test_option; tztest "Value Option" `Quick test_value_option; + tztest "Value ~default" `Quick test_value_default; + tztest "Optional" `Quick test_optional; tztest "Self ref" `Quick test_with_self_ref; ] diff --git a/src/lib_scoru_wasm/tree_decoding.ml b/src/lib_scoru_wasm/tree_decoding.ml index 908c845d4c6f..27d01acb5708 100644 --- a/src/lib_scoru_wasm/tree_decoding.ml +++ b/src/lib_scoru_wasm/tree_decoding.ml @@ -42,7 +42,9 @@ module type S = sig val raw : key -> bytes t - val value : key -> 'a Data_encoding.t -> 'a t + val optional : key -> 'a Data_encoding.t -> 'a option t + + val value : ?default:'a -> key -> 'a Data_encoding.t -> 'a t val scope : key -> 'a t -> 'a t @@ -58,7 +60,7 @@ module type S = sig val case_lwt : 'tag -> 'b t -> ('b -> 'a Lwt.t) -> ('tag, 'a) case - val tagged_union : 'tag t -> ('tag, 'a) case list -> 'a t + val tagged_union : ?default:'a -> 'tag t -> ('tag, 'a) case list -> 'a t module Syntax : sig val return : 'a -> 'a t @@ -135,16 +137,24 @@ module Make (T : Tree.S) : S with type tree = T.tree = struct let+ value = Tree.find tree key in match value with Some value -> value | None -> raise (Key_not_found key) - let value key decoder tree prefix = + let optional key decoder tree prefix = let open Lwt_syntax in let key = prefix key in let* value = Tree.find tree key in match value with | Some value -> ( match Data_encoding.Binary.of_bytes decoder value with - | Ok value -> 
return value + | Ok value -> return_some value | Error error -> raise (Decode_error {key; error})) - | None -> raise (Key_not_found key) + | None -> return_none + + let value ?default key decoder tree prefix = + let open Lwt_syntax in + let* value = optional key decoder tree prefix in + match (value, default) with + | Some value, _ -> return value + | None, Some default -> return default + | None, None -> raise (Key_not_found (prefix key)) let scope key dec tree prefix = dec tree (append_key prefix key) @@ -159,14 +169,28 @@ module Make (T : Tree.S) : S with type tree = T.tree = struct let case tag decode extract = case_lwt tag decode (fun x -> Lwt.return @@ extract x) - let tagged_union decode_tag cases input_tree prefix = + let tagged_union ?default decode_tag cases input_tree prefix = let open Lwt_syntax in - let* target_tag = scope ["tag"] decode_tag input_tree prefix in - (* Search through the cases to find a matching branch. *) - cases - |> List.find_map (fun (Case {tag; decode; extract}) -> - if tag = target_tag then - Some (map_lwt extract (scope ["value"] decode) input_tree prefix) - else None) - |> Option.value_f ~default:(fun _ -> raise No_tag_matched_on_decoding) + Lwt.try_bind + (fun () -> scope ["tag"] decode_tag input_tree prefix) + (fun target_tag -> + (* Search through the cases to find a matching branch. 
*) + let candidate = + List.find_map + (fun (Case {tag; decode; extract}) -> + if tag = target_tag then + Some + (map_lwt extract (scope ["value"] decode) input_tree prefix) + else None) + cases + in + match candidate with + | Some case -> case + | None -> raise No_tag_matched_on_decoding) + (function + | Key_not_found _ as exn -> ( + match default with + | Some default -> return default + | None -> raise exn) + | exn -> raise exn) end diff --git a/src/lib_scoru_wasm/tree_decoding.mli b/src/lib_scoru_wasm/tree_decoding.mli index 88b17c07e720..38bf255ccb3c 100644 --- a/src/lib_scoru_wasm/tree_decoding.mli +++ b/src/lib_scoru_wasm/tree_decoding.mli @@ -54,13 +54,26 @@ module type S = sig *) val raw : key -> bytes t - (** [value key data_encoding] retrieves the value at a given [key] by decoding - its raw value using the provided [data_encoding]. + (** [optional key data_encoding] tries to retrieve the value at a + given [key] by decoding its raw value using the provided + [data_encoding], or returns [None] if [key] is missing. + + @raises Decode_error when decoding of the value fails + *) + val optional : key -> 'a Data_encoding.t -> 'a option t + + (** [value ?default key data_encoding] retrieves the value at a + given [key] by decoding its raw value using the provided + [data_encoding]. + + The [default] labeled argument can be provided to specify a + fallback value for when the key is absent from the tree. @raises Key_not_found when the requested key is not presented + and the [default] argument is omitted. @raises Decode_error when decoding of the value fails *) - val value : key -> 'a Data_encoding.t -> 'a t + val value : ?default:'a -> key -> 'a Data_encoding.t -> 'a t (** [scope key decoder] applies a tree decoder for a provided [key]. @@ -98,15 +111,21 @@ module type S = sig an [Lwt] value.
*) val case_lwt : 'tag -> 'b t -> ('b -> 'a Lwt.t) -> ('tag, 'a) case - (** [tagged_union tag_dec cases] returns a decoder that use [tag_dec] for - decoding the value of a field [tag]. The decoder searches through the list - of cases for a matching branch. When a matching branch is found, it uses - its embedded decoder for the value. This function is used for constructing - decoders for sum-types. + (** [tagged_union ?default tag_dec cases] returns a decoder that uses + [tag_dec] for decoding the value of a field [tag]. The decoder + searches through the list of cases for a matching branch. When a + matching branch is found, it uses its embedded decoder for the + value. This function is used for constructing decoders for + sum-types. + + [default] is an optional labeled argument that can be provided + in order to have a fallback to use in case the tag is absent + from the tree (meaning that the value has not yet been + initialized in the tree). If an insufficient list of cases are provided, the resulting encoder may fail with a [No_tag_matched] error when [run]. *) - val tagged_union : 'tag t -> ('tag, 'a) case list -> 'a t + val tagged_union : ?default:'a -> 'tag t -> ('tag, 'a) case list -> 'a t (** Syntax module for the {!Tree_decoding}. This is intended to be opened locally in functions.
Within the scope of this module, the code can diff --git a/src/lib_scoru_wasm/tree_encoding.ml b/src/lib_scoru_wasm/tree_encoding.ml index 681895242b6f..0594374d2cef 100644 --- a/src/lib_scoru_wasm/tree_encoding.ml +++ b/src/lib_scoru_wasm/tree_encoding.ml @@ -40,6 +40,8 @@ module type S = sig val raw : key -> bytes t + val optional : key -> 'a Data_encoding.t -> 'a option t + val value : key -> 'a Data_encoding.t -> 'a t val scope : key -> 'a t -> 'a t @@ -103,6 +105,11 @@ module Make (T : Tree.S) = struct let value suffix enc = contramap (Data_encoding.Binary.to_bytes_exn enc) (raw suffix) + let optional key encoding v prefix tree = + match v with + | Some v -> value key encoding v prefix tree + | None -> T.remove tree (prefix key) + let scope key enc value prefix tree = enc value (append_key prefix key) tree let lazy_mapping to_key enc_value bindings prefix tree = diff --git a/src/lib_scoru_wasm/tree_encoding.mli b/src/lib_scoru_wasm/tree_encoding.mli index 5b7c73812c5e..d6a70ae18b84 100644 --- a/src/lib_scoru_wasm/tree_encoding.mli +++ b/src/lib_scoru_wasm/tree_encoding.mli @@ -54,6 +54,11 @@ module type S = sig (** [raw key] returns an encoder that encodes raw bytes at the given key. *) val raw : key -> bytes t + (** [optional key enc] encodes the value at a given [key] using the + provided [enc] encoder for the value, or removes any previous + value stored at [key] if [None] is provided. *) + val optional : key -> 'a Data_encoding.t -> 'a option t + (** [value key enc] encodes the value at a given [key] using the provided [enc] encoder for the value.
*) val value : key -> 'a Data_encoding.t -> 'a t diff --git a/src/lib_scoru_wasm/tree_encoding_decoding.ml b/src/lib_scoru_wasm/tree_encoding_decoding.ml index dab460e2a86b..b160fe4fd80f 100644 --- a/src/lib_scoru_wasm/tree_encoding_decoding.ml +++ b/src/lib_scoru_wasm/tree_encoding_decoding.ml @@ -122,7 +122,9 @@ module type S = sig val raw : key -> bytes t - val value : key -> 'a Data_encoding.t -> 'a t + val optional : key -> 'a Data_encoding.t -> 'a option t + + val value : ?default:'a -> key -> 'a Data_encoding.t -> 'a t val value_option : key -> 'a Data_encoding.t -> 'a option t @@ -145,7 +147,7 @@ module type S = sig ('b -> 'a Lwt.t) -> ('tag, 'a) case - val tagged_union : 'tag t -> ('tag, 'a) case list -> 'a t + val tagged_union : ?default:'a -> 'tag t -> ('tag, 'a) case list -> 'a t val option : 'a t -> 'a option t @@ -320,7 +322,8 @@ module Make let raw key = {encode = E.raw key; decode = D.raw key} - let value key de = {encode = E.value key de; decode = D.value key de} + let value ?default key de = + {encode = E.value key de; decode = D.value ?default key de} let value_option key de = value key (Data_encoding.option de) @@ -400,7 +403,7 @@ module Make (fun x -> Option.map Lwt.return @@ probe x) (fun x -> Lwt.return @@ extract x) - let tagged_union {encode; decode} cases = + let tagged_union ?default {encode; decode} cases = let to_encode_case (Case {tag; delegate; probe; extract = _}) = E.case_lwt tag delegate.encode probe in @@ -408,7 +411,14 @@ module Make D.case_lwt tag delegate.decode extract in let encode = E.tagged_union encode (List.map to_encode_case cases) in - let decode = D.tagged_union decode (List.map to_decode_case cases) in + let decode = + D.tagged_union ?default decode (List.map to_decode_case cases) + in + {encode; decode} + + let optional key encoding = + let encode = E.optional key encoding in + let decode = D.optional key encoding in {encode; decode} let option enc = diff --git a/src/lib_scoru_wasm/tree_encoding_decoding.mli 
b/src/lib_scoru_wasm/tree_encoding_decoding.mli index f5585457b87d..c22087252c2e 100644 --- a/src/lib_scoru_wasm/tree_encoding_decoding.mli +++ b/src/lib_scoru_wasm/tree_encoding_decoding.mli @@ -192,9 +192,16 @@ module type S = sig (** [raw key] is an encoder for bytes under the given [key]. *) val raw : key -> bytes t - (** [value key enc] creates an encoder under the given [key] using the - provided data-encoding [enc] for encoding/decoding values. *) - val value : key -> 'a Data_encoding.t -> 'a t + (** [optional key encoding] returns an encoder that uses [encoding] + for present values, decodes to [None] when [key] is absent, and + removes [key] from the tree when encoding [None]. *) + val optional : key -> 'a Data_encoding.t -> 'a option t + + (** [value ?default key enc] creates an encoder under the given + [key] using the provided data-encoding [enc] for + encoding/decoding values, and using [default] as a fallback when + decoding in case the [key] is absent from the tree. *) + val value : ?default:'a -> key -> 'a Data_encoding.t -> 'a t (** [value_option key enc] creates an encoder for optional values under the given [key] using the provided data-encoding [enc]. Note that the value is @@ -239,8 +246,11 @@ module type S = sig encoding the value of a field [tag]. The encoder searches through the list of cases for a matching branch. When a matching branch is found, it uses its embedded encoder for the value. This function is used for constructing - encoders for sum-types. *) - val tagged_union : 'tag t -> ('tag, 'a) case list -> 'a t + encoders for sum-types. + + The [default] labeled argument can be provided to have a + fallback when decoding, in case the tag is missing from the tree. *) + val tagged_union : ?default:'a -> 'tag t -> ('tag, 'a) case list -> 'a t (** [option enc] lifts the given encoding [enc] to one that can encode optional values. *) -- GitLab