diff --git a/src/lib_scoru_wasm/init_encodings.ml b/src/lib_scoru_wasm/init_encodings.ml index d061f328f8ea871bf5a3b51107c8f35e0f736fdb..89abb29e60f159ecef0ae2cd69792920de80be29 100644 --- a/src/lib_scoru_wasm/init_encodings.ml +++ b/src/lib_scoru_wasm/init_encodings.ml @@ -50,6 +50,13 @@ let map_kont_encoding enc_a enc_b = (scope ["destination"] enc_b) (value ["offset"] Data_encoding.int32) +let tick_map_kont_encoding enc_kont enc_a enc_b = + conv (fun (tick, map) -> {tick; map}) (fun {tick; map} -> (tick, map)) + @@ tup2 + ~flatten:true + (option (scope ["inner_kont"] enc_kont)) + (scope ["map_kont"] (map_kont_encoding enc_a enc_b)) + let concat_kont_encoding enc_a = conv (fun (lv, rv, res, offset) -> {lv; rv; res; offset}) @@ -76,10 +83,10 @@ let lazy_vec_encoding enc = int32_lazy_vector (value [] Data_encoding.int32) enc type (_, _) eq = Eq : ('a, 'a) eq let init_section_eq : - type a b c d. - (a, b) init_section -> - (c, d) init_section -> - ((a, b) init_section, (c, d) init_section) eq option = + type kont kont' a b c d. + (kont, a, b) init_section -> + (kont', c, d) init_section -> + ((kont, a, b) init_section, (kont', c, d) init_section) eq option = fun sec1 sec2 -> match (sec1, sec2) with | Func, Func -> Some Eq @@ -89,10 +96,14 @@ let init_section_eq : | _, _ -> None let aggregate_cases : - type a b. - string -> (a, b) init_section -> a t -> b t -> (string, init_kont) case list - = - fun name sec enc_a enc_b -> + type kont a b.
+ string -> + (kont, a, b) init_section -> + kont t -> + a t -> + b t -> + (string, init_kont) case list = + fun name sec enc_kont enc_a enc_b -> [ case Format.(sprintf "IK_Aggregate_%s" name) @@ -101,7 +112,8 @@ let aggregate_cases : (scope ["module"] Wasm_encoding.module_instance_encoding) (scope ["kont"] - (map_kont_encoding + (tick_map_kont_encoding + enc_kont (lazy_vec_encoding enc_a) (lazy_vec_encoding enc_b)))) (function @@ -126,6 +138,16 @@ let aggregate_cases : (function m, t -> IK_Aggregate_concat (m, sec, t)); ] +let aggregate_cases_either : + type a b. + string -> + ((a, b) Either.t, a, b) init_section -> + a t -> + b t -> + (string, init_kont) case list = + fun name sec enc_a enc_b -> + aggregate_cases name sec (either enc_a enc_b) enc_a enc_b + let join_kont_encoding enc_b = tagged_union tag_encoding @@ -200,22 +222,22 @@ let init_kont_encoding ~host_funcs = (function IK_Type (m, t) -> Some (m, t) | _ -> None) (function m, t -> IK_Type (m, t)); ] - @ aggregate_cases + @ aggregate_cases_either "func" Func Parser.Code.func_encoding Wasm_encoding.function_encoding - @ aggregate_cases + @ aggregate_cases_either "global" Global (value [] Interpreter_encodings.Ast.global_encoding) Wasm_encoding.global_encoding - @ aggregate_cases + @ aggregate_cases_either "table" Table (value [] Interpreter_encodings.Ast.table_encoding) Wasm_encoding.table_encoding - @ aggregate_cases + @ aggregate_cases_either "memory" Memory (value [] Interpreter_encodings.Ast.memory_encoding) diff --git a/src/lib_tree_encoding/tree_encoding.ml b/src/lib_tree_encoding/tree_encoding.ml index d0df478e14de2f7dc1083da41c178197308714f2..d872f4ba8c73c66928fcdd989dd2f118aa7f2765 100644 --- a/src/lib_tree_encoding/tree_encoding.ml +++ b/src/lib_tree_encoding/tree_encoding.ml @@ -327,6 +327,22 @@ let delayed f = in {encode; decode} +let either enc_a enc_b = + tagged_union + (value [] Data_encoding.string) + [ + case + "Left" + enc_a + (function Either.Left x -> Some x | _ -> None) + (function x
-> Left x); + case + "Right" + enc_b + (function Either.Right x -> Some x | _ -> None) + (function x -> Right x); + ] + module Runner = struct module type TREE = S diff --git a/src/lib_tree_encoding/tree_encoding.mli b/src/lib_tree_encoding/tree_encoding.mli index 80a54a7ae5953800d6cd0c25babaef75256a4a17..f5071f5f98f1b2ba4d8dab6e401e9133af8c40bd 100644 --- a/src/lib_tree_encoding/tree_encoding.mli +++ b/src/lib_tree_encoding/tree_encoding.mli @@ -234,6 +234,10 @@ val option : 'a t -> 'a option t to allow for directly recursive encoders/decoders. *) val delayed : (unit -> 'a t) -> 'a t +(** [either enc_a enc_b] returns an encoder where [enc_a] is used for + the [Left] case of [Either.t], and [enc_b] for the [Right] case. *) +val either : 'a t -> 'b t -> ('a, 'b) Either.t t + module Runner : sig module type TREE = sig type tree diff --git a/src/lib_webassembly/exec/eval.ml b/src/lib_webassembly/exec/eval.ml index 77e39101acdeae4cc3eac40fe117466150a0e58a..56239cabff20461342d3d110e8c201140cc97631 100644 --- a/src/lib_webassembly/exec/eval.ml +++ b/src/lib_webassembly/exec/eval.ml @@ -1131,14 +1131,44 @@ let fold_left_s_step {origin; acc; offset} f = let+ acc = f acc x in {origin; acc; offset = Int32.succ offset} -type (_, _) init_section = - | Func : (func, func_inst) init_section - | Global : (global, global_inst) init_section - | Table : (table, table_inst) init_section - | Memory : (memory, memory_inst) init_section +type ('kont, 'a, 'b) tick_map_kont = { + tick : 'kont option; + map : ('a, 'b) map_kont; +} + +let tick_map_completed {map; _} = map_completed map + +let tick_map_kont v = {tick = None; map = map_kont v} + +let tick_map_step first_kont kont_completed kont_step = function + | {map; _} when map_completed map -> assert false + | {tick = None; map} -> + let+ x = Vector.get map.offset map.origin in + let tick = first_kont x in + {tick = Some tick; map} + | {tick = Some tick; map} -> ( + match kont_completed tick with + | Some v -> + let map = + { + map with
destination = Vector.set map.offset v map.destination; + offset = Int32.succ map.offset; + } + in + Lwt.return {tick = None; map} + | None -> + let+ tick = kont_step tick in + {tick = Some tick; map}) + +type (_, _, _) init_section = + | Func : ((func, func_inst) Either.t, func, func_inst) init_section + | Global : ((global, global_inst) Either.t, global, global_inst) init_section + | Table : ((table, table_inst) Either.t, table, table_inst) init_section + | Memory : ((memory, memory_inst) Either.t, memory, memory_inst) init_section let section_fetch_vec : - type a b. module_inst -> (a, b) init_section -> b Vector.t = + type kont a b. module_inst -> (kont, a, b) init_section -> b Vector.t = fun inst sec -> match sec with | Func -> inst.funcs @@ -1147,7 +1177,8 @@ let section_fetch_vec : | Memory -> inst.memories let section_set_vec : - type a b. module_inst -> (a, b) init_section -> b Vector.t -> module_inst = + type kont a b. + module_inst -> (kont, a, b) init_section -> b Vector.t -> module_inst = fun inst sec vec -> match (sec, vec) with | Func, funcs -> {inst with funcs} @@ -1220,10 +1251,10 @@ type init_kont = | IK_Add_import of (extern, import, module_inst) fold_right2_kont | IK_Type of module_inst * (type_, func_type) map_kont | IK_Aggregate : - module_inst * ('a, 'b) init_section * ('a, 'b) map_kont + module_inst * ('kont, 'a, 'b) init_section * ('kont, 'a, 'b) tick_map_kont -> init_kont | IK_Aggregate_concat : - module_inst * ('a, 'b) init_section * 'b concat_kont + module_inst * ('kont, 'a, 'b) init_section * 'b concat_kont -> init_kont | IK_Exports of module_inst * (export, extern NameMap.t) fold_left_kont | IK_Elems of module_inst * (elem_segment, elem_inst) map_kont @@ -1238,25 +1269,57 @@ type init_kont = | IK_Stop of module_inst let section_next_init_kont : - type a b. module_ -> module_inst -> (a, b) init_section -> init_kont = + type kont a b.
+ module_ -> module_inst -> (kont, a, b) init_section -> init_kont = fun m inst0 sec -> match sec with - | Func -> IK_Aggregate (inst0, Global, map_kont m.it.globals) - | Global -> IK_Aggregate (inst0, Table, map_kont m.it.tables) - | Table -> IK_Aggregate (inst0, Memory, map_kont m.it.memories) + | Func -> IK_Aggregate (inst0, Global, tick_map_kont m.it.globals) + | Global -> IK_Aggregate (inst0, Table, tick_map_kont m.it.tables) + | Table -> IK_Aggregate (inst0, Memory, tick_map_kont m.it.memories) | Memory -> IK_Exports (inst0, fold_left_kont m.it.exports (NameMap.create ())) -let section_step : - type a b. - module_inst ModuleMap.t -> module_key -> (a, b) init_section -> a -> b Lwt.t - = - fun module_reg self -> function - | Func -> create_func module_reg self - | Global -> create_global module_reg self - | Table -> fun x -> Lwt.return (create_table x) - | Memory -> fun x -> Lwt.return (create_memory x) +let section_inner_kont : type kont a b. (kont, a, b) init_section -> a -> kont = + fun sec x -> + match sec with + | Func -> Either.Left x + | Global -> Left x + | Table -> Left x + | Memory -> Left x + +let section_inner_completed : + type kont a b. (kont, a, b) init_section -> kont -> b option = + fun sec kont -> + match (sec, kont) with + | Func, Right y -> Some y + | Global, Right y -> Some y + | Table, Right y -> Some y + | Memory, Right y -> Some y + | _ -> None + +let section_inner_step : + type kont a b. + module_inst ModuleMap.t -> + module_key -> + (kont, a, b) init_section -> + kont -> + kont Lwt.t = + fun module_reg self -> + let lift_either f = + let open Either in + function + | Left x -> + let+ y = f x in + Right y + | Right _ -> assert false + in + function + | Func -> lift_either (create_func module_reg self) + | Global -> lift_either (create_global module_reg self) + | Table -> lift_either (fun x -> Lwt.return (create_table x)) + | Memory -> lift_either (fun x -> Lwt.return (create_memory x)) -let section_update_module_ref : type a b.
(a, b) init_section -> bool = function +let section_update_module_ref : type kont a b. (kont, a, b) init_section -> bool + = function | Func -> true | Global -> false | Table -> false @@ -1287,18 +1350,24 @@ let init_step ~module_reg ~self host_funcs (m : module_) (exts : extern list) = {inst0 with types = tick.destination; allocations = m.it.allocations} in update_module_ref module_reg self inst0 ; - Lwt.return (IK_Aggregate (inst0, Func, map_kont m.it.funcs)) + Lwt.return (IK_Aggregate (inst0, Func, tick_map_kont m.it.funcs)) | IK_Type (inst0, tick) -> let+ tick = map_step tick (fun x -> x.it) in IK_Type (inst0, tick) - | IK_Aggregate (inst0, sec, tick) when map_completed tick -> + | IK_Aggregate (inst0, sec, tick) when tick_map_completed tick -> Lwt.return (IK_Aggregate_concat ( inst0, sec, - concat_kont (section_fetch_vec inst0 sec) tick.destination )) + concat_kont (section_fetch_vec inst0 sec) tick.map.destination )) | IK_Aggregate (inst0, sec, tick) -> - let+ tick = map_s_step tick (section_step module_reg self sec) in + let+ tick = + tick_map_step + (section_inner_kont sec) + (section_inner_completed sec) + (section_inner_step module_reg self sec) + tick + in IK_Aggregate (inst0, sec, tick) | IK_Aggregate_concat (inst0, sec, tick) when concat_completed tick -> let inst1 = section_set_vec inst0 sec tick.res in diff --git a/src/lib_webassembly/exec/eval.mli b/src/lib_webassembly/exec/eval.mli index 2e8e09da9915c932419f3fb76a9c6033da9db5aa..ba3751a9c3fb4919c3d74c00fd9315ca63780bb0 100644 --- a/src/lib_webassembly/exec/eval.mli +++ b/src/lib_webassembly/exec/eval.mli @@ -57,11 +57,20 @@ type 'a concat_kont = { type ('a, 'b) fold_left_kont = {origin : 'a Vector.t; acc : 'b; offset : int32} -type (_, _) init_section = - | Func : (Ast.func, func_inst) init_section - | Global : (Ast.global, global_inst) init_section - | Table : (Ast.table, table_inst) init_section - | Memory : (Ast.memory, memory_inst) init_section +type (_, _, _) init_section = + | Func :
((Ast.func, func_inst) Either.t, Ast.func, func_inst) init_section + | Global + : ( (Ast.global, global_inst) Either.t, + Ast.global, + global_inst ) + init_section + | Table + : ((Ast.table, table_inst) Either.t, Ast.table, table_inst) init_section + | Memory + : ( (Ast.memory, memory_inst) Either.t, + Ast.memory, + memory_inst ) + init_section type 'b join_kont = | J_Init of 'b Vector.t Vector.t @@ -72,15 +81,20 @@ type ('a, 'b) map_concat_kont = | MC_Map of ('a, 'b Vector.t) map_kont | MC_Join of 'b join_kont +type ('kont, 'a, 'b) tick_map_kont = { + tick : 'kont option; + map : ('a, 'b) map_kont; +} + type init_kont = | IK_Start (** Very first tick of the [init] function *) | IK_Add_import of (extern, Ast.import, module_inst) fold_right2_kont | IK_Type of module_inst * (Ast.type_, Types.func_type) map_kont | IK_Aggregate : - module_inst * ('a, 'b) init_section * ('a, 'b) map_kont + module_inst * ('kont, 'a, 'b) init_section * ('kont, 'a, 'b) tick_map_kont -> init_kont | IK_Aggregate_concat : - module_inst * ('a, 'b) init_section * 'b concat_kont + module_inst * ('kont, 'a, 'b) init_section * 'b concat_kont -> init_kont | IK_Exports of module_inst * (Ast.export, extern NameMap.t) fold_left_kont | IK_Elems of module_inst * (Ast.elem_segment, elem_inst) map_kont