From 8c5c03acd0fa0ebd0f083ae25f9bec5335933107 Mon Sep 17 00:00:00 2001 From: Thomas Letan Date: Fri, 26 Aug 2022 12:04:32 +0200 Subject: [PATCH] WASM: Going deeper into the rabbit hole of tickifying Eval.init Some sections of 'Eval.init', such as exports and globals, need deeper tickification; for now, their ticks remain too large. We prepare for this by first extending the 'init_section' type with a 'kont' type parameter describing the inner continuation used to process each element. --- src/lib_scoru_wasm/init_encodings.ml | 48 ++++++--- src/lib_tree_encoding/tree_encoding.ml | 16 +++ src/lib_tree_encoding/tree_encoding.mli | 4 + src/lib_webassembly/exec/eval.ml | 123 ++++++++++++++++++------ src/lib_webassembly/exec/eval.mli | 28 ++++-- 5 files changed, 172 insertions(+), 47 deletions(-) diff --git a/src/lib_scoru_wasm/init_encodings.ml b/src/lib_scoru_wasm/init_encodings.ml index d061f328f8ea..89abb29e60f1 100644 --- a/src/lib_scoru_wasm/init_encodings.ml +++ b/src/lib_scoru_wasm/init_encodings.ml @@ -50,6 +50,13 @@ let map_kont_encoding enc_a enc_b = (scope ["destination"] enc_b) (value ["offset"] Data_encoding.int32) +let tick_map_kont_encoding enc_kont enc_a enc_b = + conv (fun (tick, map) -> {tick; map}) (fun {tick; map} -> (tick, map)) + @@ tup2 + ~flatten:true + (option (scope ["inner_kont"] enc_kont)) + (scope ["map_kont"] (map_kont_encoding enc_a enc_b)) + let concat_kont_encoding enc_a = conv (fun (lv, rv, res, offset) -> {lv; rv; res; offset}) @@ -76,10 +83,10 @@ let lazy_vec_encoding enc = int32_lazy_vector (value [] Data_encoding.int32) enc type (_, _) eq = Eq : ('a, 'a) eq let init_section_eq : - type a b c d. - (a, b) init_section -> - (c, d) init_section -> - ((a, b) init_section, (c, d) init_section) eq option = + type kont kont' a b c d. + (kont, a, b) init_section -> + (kont', c, d) init_section -> + ((kont, a, b) init_section, (kont', c, d) init_section) eq option = fun sec1 sec2 -> match (sec1, sec2) with | Func, Func -> Some Eq @@ -89,10 +96,14 @@ let init_section_eq : | _, _ -> None let aggregate_cases : - type a b.
- string -> (a, b) init_section -> a t -> b t -> (string, init_kont) case list - = - fun name sec enc_a enc_b -> + type kont a b. + string -> + (kont, a, b) init_section -> + kont t -> + a t -> + b t -> + (string, init_kont) case list = + fun name sec enc_kont enc_a enc_b -> [ case Format.(sprintf "IK_Aggregate_%s" name) @@ -101,7 +112,8 @@ let aggregate_cases : (scope ["module"] Wasm_encoding.module_instance_encoding) (scope ["kont"] - (map_kont_encoding + (tick_map_kont_encoding + enc_kont (lazy_vec_encoding enc_a) (lazy_vec_encoding enc_b)))) (function @@ -126,6 +138,16 @@ let aggregate_cases : (function m, t -> IK_Aggregate_concat (m, sec, t)); ] +let aggregate_cases_either : + type a b. + string -> + ((a, b) Either.t, a, b) init_section -> + a t -> + b t -> + (string, init_kont) case list = + fun name sec enc_a enc_b -> + aggregate_cases name sec (either enc_a enc_b) enc_a enc_b + let join_kont_encoding enc_b = tagged_union tag_encoding @@ -200,22 +222,22 @@ let init_kont_encoding ~host_funcs = (function IK_Type (m, t) -> Some (m, t) | _ -> None) (function m, t -> IK_Type (m, t)); ] - @ aggregate_cases + @ aggregate_cases_either "func" Func Parser.Code.func_encoding Wasm_encoding.function_encoding - @ aggregate_cases + @ aggregate_cases_either "global" Global (value [] Interpreter_encodings.Ast.global_encoding) Wasm_encoding.global_encoding - @ aggregate_cases + @ aggregate_cases_either "table" Table (value [] Interpreter_encodings.Ast.table_encoding) Wasm_encoding.table_encoding - @ aggregate_cases + @ aggregate_cases_either "memory" Memory (value [] Interpreter_encodings.Ast.memory_encoding) diff --git a/src/lib_tree_encoding/tree_encoding.ml b/src/lib_tree_encoding/tree_encoding.ml index d0df478e14de..d872f4ba8c73 100644 --- a/src/lib_tree_encoding/tree_encoding.ml +++ b/src/lib_tree_encoding/tree_encoding.ml @@ -327,6 +327,22 @@ let delayed f = in {encode; decode} +let either enc_a enc_b = + tagged_union + (value [] Data_encoding.string) + [ + case + 
"Left" + enc_a + (function Either.Left x -> Some x | _ -> None) + (function x -> Left x); + case + "Right" + enc_b + (function Either.Right x -> Some x | _ -> None) + (function x -> Right x); + ] + module Runner = struct module type TREE = S diff --git a/src/lib_tree_encoding/tree_encoding.mli b/src/lib_tree_encoding/tree_encoding.mli index 80a54a7ae595..f5071f5f98f1 100644 --- a/src/lib_tree_encoding/tree_encoding.mli +++ b/src/lib_tree_encoding/tree_encoding.mli @@ -234,6 +234,10 @@ val option : 'a t -> 'a option t to allow for directly recursive encoders/decoders. *) val delayed : (unit -> 'a t) -> 'a t +(** [either enc_a enc_b] returns an encoder where [enc_a] is used for + the left case of [Either.t], and [enc_b] for the [Right] case. *) +val either : 'a t -> 'b t -> ('a, 'b) Either.t t + module Runner : sig module type TREE = sig type tree diff --git a/src/lib_webassembly/exec/eval.ml b/src/lib_webassembly/exec/eval.ml index 77e39101acde..56239cabff20 100644 --- a/src/lib_webassembly/exec/eval.ml +++ b/src/lib_webassembly/exec/eval.ml @@ -1131,14 +1131,44 @@ let fold_left_s_step {origin; acc; offset} f = let+ acc = f acc x in {origin; acc; offset = Int32.succ offset} -type (_, _) init_section = - | Func : (func, func_inst) init_section - | Global : (global, global_inst) init_section - | Table : (table, table_inst) init_section - | Memory : (memory, memory_inst) init_section +type ('kont, 'a, 'b) tick_map_kont = { + tick : 'kont option; + map : ('a, 'b) map_kont; +} + +let tick_map_completed {map; _} = map_completed map + +let tick_map_kont v = {tick = None; map = map_kont v} + +let tick_map_step first_kont kont_completed kont_step = function + | {map; _} when map_completed map -> assert false + | {tick = None; map} -> + let+ x = Vector.get map.offset map.origin in + let tick = first_kont x in + {tick = Some tick; map} + | {tick = Some tick; map} -> ( + match kont_completed tick with + | Some v -> + let map = + { + map with + destination = Vector.set 
map.offset v map.destination; + offset = Int32.succ map.offset; + } + in + Lwt.return {tick = None; map} + | None -> + let+ tick = kont_step tick in + {tick = Some tick; map}) + +type (_, _, _) init_section = + | Func : ((func, func_inst) Either.t, func, func_inst) init_section + | Global : ((global, global_inst) Either.t, global, global_inst) init_section + | Table : ((table, table_inst) Either.t, table, table_inst) init_section + | Memory : ((memory, memory_inst) Either.t, memory, memory_inst) init_section let section_fetch_vec : - type a b. module_inst -> (a, b) init_section -> b Vector.t = + type kont a b. module_inst -> (kont, a, b) init_section -> b Vector.t = fun inst sec -> match sec with | Func -> inst.funcs @@ -1147,7 +1177,8 @@ let section_fetch_vec : | Memory -> inst.memories let section_set_vec : - type a b. module_inst -> (a, b) init_section -> b Vector.t -> module_inst = + type kont a b. + module_inst -> (kont, a, b) init_section -> b Vector.t -> module_inst = fun inst sec vec -> match (sec, vec) with | Func, funcs -> {inst with funcs} @@ -1220,10 +1251,10 @@ type init_kont = | IK_Add_import of (extern, import, module_inst) fold_right2_kont | IK_Type of module_inst * (type_, func_type) map_kont | IK_Aggregate : - module_inst * ('a, 'b) init_section * ('a, 'b) map_kont + module_inst * ('kont, 'a, 'b) init_section * ('kont, 'a, 'b) tick_map_kont -> init_kont | IK_Aggregate_concat : - module_inst * ('a, 'b) init_section * 'b concat_kont + module_inst * ('kont, 'a, 'b) init_section * 'b concat_kont -> init_kont | IK_Exports of module_inst * (export, extern NameMap.t) fold_left_kont | IK_Elems of module_inst * (elem_segment, elem_inst) map_kont @@ -1238,25 +1269,57 @@ type init_kont = | IK_Stop of module_inst let section_next_init_kont : - type a b. module_ -> module_inst -> (a, b) init_section -> init_kont = + type kont a b. 
+ module_ -> module_inst -> (kont, a, b) init_section -> init_kont = fun m inst0 sec -> match sec with - | Func -> IK_Aggregate (inst0, Global, map_kont m.it.globals) - | Global -> IK_Aggregate (inst0, Table, map_kont m.it.tables) - | Table -> IK_Aggregate (inst0, Memory, map_kont m.it.memories) + | Func -> IK_Aggregate (inst0, Global, tick_map_kont m.it.globals) + | Global -> IK_Aggregate (inst0, Table, tick_map_kont m.it.tables) + | Table -> IK_Aggregate (inst0, Memory, tick_map_kont m.it.memories) | Memory -> IK_Exports (inst0, fold_left_kont m.it.exports (NameMap.create ())) -let section_step : - type a b. - module_inst ModuleMap.t -> module_key -> (a, b) init_section -> a -> b Lwt.t - = - fun module_reg self -> function - | Func -> create_func module_reg self - | Global -> create_global module_reg self - | Table -> fun x -> Lwt.return (create_table x) - | Memory -> fun x -> Lwt.return (create_memory x) +let section_inner_kont : type kont a b. (kont, a, b) init_section -> a -> kont = + fun sec x -> + match sec with + | Func -> Either.Left x + | Global -> Left x + | Table -> Left x + | Memory -> Left x + +let section_inner_completed : + type kont a b. (kont, a, b) init_section -> kont -> b option = + fun sec kont -> + match (sec, kont) with + | Func, Right y -> Some y + | Global, Right y -> Some y + | Table, Right y -> Some y + | Memory, Right y -> Some y + | _ -> None + +let section_inner_step : + type kont a b. + module_inst ModuleMap.t -> + module_key -> + (kont, a, b) init_section -> + kont -> + kont Lwt.t = + fun module_reg self -> + let lift_either f = + let open Either in + function + | Left x -> + let+ y = f x in + Right y + | Right _ -> assert false + in + function + | Func -> lift_either (create_func module_reg self) + | Global -> lift_either (create_global module_reg self) + | Table -> lift_either (fun x -> Lwt.return (create_table x)) + | Memory -> lift_either (fun x -> Lwt.return (create_memory x)) -let section_update_module_ref : type a b. 
(a, b) init_section -> bool = function +let section_update_module_ref : type kont a b. (kont, a, b) init_section -> bool + = function | Func -> true | Global -> false | Table -> false @@ -1287,18 +1350,24 @@ let init_step ~module_reg ~self host_funcs (m : module_) (exts : extern list) = {inst0 with types = tick.destination; allocations = m.it.allocations} in update_module_ref module_reg self inst0 ; - Lwt.return (IK_Aggregate (inst0, Func, map_kont m.it.funcs)) + Lwt.return (IK_Aggregate (inst0, Func, tick_map_kont m.it.funcs)) | IK_Type (inst0, tick) -> let+ tick = map_step tick (fun x -> x.it) in IK_Type (inst0, tick) - | IK_Aggregate (inst0, sec, tick) when map_completed tick -> + | IK_Aggregate (inst0, sec, tick) when tick_map_completed tick -> Lwt.return (IK_Aggregate_concat ( inst0, sec, - concat_kont (section_fetch_vec inst0 sec) tick.destination )) + concat_kont (section_fetch_vec inst0 sec) tick.map.destination )) | IK_Aggregate (inst0, sec, tick) -> - let+ tick = map_s_step tick (section_step module_reg self sec) in + let+ tick = + tick_map_step + (section_inner_kont sec) + (section_inner_completed sec) + (section_inner_step module_reg self sec) + tick + in IK_Aggregate (inst0, sec, tick) | IK_Aggregate_concat (inst0, sec, tick) when concat_completed tick -> let inst1 = section_set_vec inst0 sec tick.res in diff --git a/src/lib_webassembly/exec/eval.mli b/src/lib_webassembly/exec/eval.mli index 2e8e09da9915..ba3751a9c3fb 100644 --- a/src/lib_webassembly/exec/eval.mli +++ b/src/lib_webassembly/exec/eval.mli @@ -57,11 +57,20 @@ type 'a concat_kont = { type ('a, 'b) fold_left_kont = {origin : 'a Vector.t; acc : 'b; offset : int32} -type (_, _) init_section = - | Func : (Ast.func, func_inst) init_section - | Global : (Ast.global, global_inst) init_section - | Table : (Ast.table, table_inst) init_section - | Memory : (Ast.memory, memory_inst) init_section +type (_, _, _) init_section = + | Func : ((Ast.func, func_inst) Either.t, Ast.func, func_inst) 
init_section + | Global + : ( (Ast.global, global_inst) Either.t, + Ast.global, + global_inst ) + init_section + | Table + : ((Ast.table, table_inst) Either.t, Ast.table, table_inst) init_section + | Memory + : ( (Ast.memory, memory_inst) Either.t, + Ast.memory, + memory_inst ) + init_section type 'b join_kont = | J_Init of 'b Vector.t Vector.t @@ -72,15 +81,20 @@ type ('a, 'b) map_concat_kont = | MC_Map of ('a, 'b Vector.t) map_kont | MC_Join of 'b join_kont +type ('kont, 'a, 'b) tick_map_kont = { + tick : 'kont option; + map : ('a, 'b) map_kont; +} + type init_kont = | IK_Start (** Very first tick of the [init] function *) | IK_Add_import of (extern, Ast.import, module_inst) fold_right2_kont | IK_Type of module_inst * (Ast.type_, Types.func_type) map_kont | IK_Aggregate : - module_inst * ('a, 'b) init_section * ('a, 'b) map_kont + module_inst * ('kont, 'a, 'b) init_section * ('kont, 'a, 'b) tick_map_kont -> init_kont | IK_Aggregate_concat : - module_inst * ('a, 'b) init_section * 'b concat_kont + module_inst * ('kont, 'a, 'b) init_section * 'b concat_kont -> init_kont | IK_Exports of module_inst * (Ast.export, extern NameMap.t) fold_left_kont | IK_Elems of module_inst * (Ast.elem_segment, elem_inst) map_kont -- GitLab