diff --git a/src/lib_benchmark/codegen.ml b/src/lib_benchmark/codegen.ml index ccda7170317ac9a6d548b478427f7b3c6d7213c3..2378d4dd8073652ed14c0f5d79d268b40b75c4f4 100644 --- a/src/lib_benchmark/codegen.ml +++ b/src/lib_benchmark/codegen.ml @@ -133,20 +133,10 @@ module Lift_then_print = Costlang.Let_lift (Codegen) type solution = float Free_variable.Map.t let load_solution (fn : string) : solution = - let infile = open_in fn in - try - let res = Marshal.from_channel infile in - close_in infile ; - res - with exn -> - close_in infile ; - Format.eprintf "Codegen.load_solution: could not load %s@." fn ; - raise exn + In_channel.with_open_bin fn Marshal.from_channel let save_solution (s : solution) (fn : string) = - let outfile = open_out fn in - Marshal.to_channel outfile s [] ; - close_out outfile + Out_channel.with_open_bin fn @@ fun outfile -> Marshal.to_channel outfile s [] (* ------------------------------------------------------------------------- *) diff --git a/src/lib_benchmark/csv.ml b/src/lib_benchmark/csv.ml index af998fc60705ccf4ef7396d2f969a573c538d916..8fd58ab95c8decacc5cbf827808258486814e9e6 100644 --- a/src/lib_benchmark/csv.ml +++ b/src/lib_benchmark/csv.ml @@ -31,20 +31,33 @@ let all_equal (l : int list) = in match l with [] -> true | hd :: tl -> loop tl hd +module String_set = Set.Make (String) + +let disjoint_headers (csv1 : csv) (csv2 : csv) = + match (csv1, csv2) with + | [], _ | _, [] -> true + | header1 :: _, header2 :: _ -> + let header1 = String_set.of_list header1 in + let header2 = String_set.of_list header2 in + String_set.disjoint header1 header2 + (* Horizontally concat CSVs *) -let concat (csv1 : csv) (csv2 : csv) : csv = +let concat ?(check_disjoint_headers = true) (csv1 : csv) (csv2 : csv) : csv = (* Check that both CSVs have the same number of lines. *) if Compare.List_lengths.(csv1 <> csv2) then Stdlib.failwith "Csv.concat: CSVs have different length" else - (* Check that each CSV has the same number of *) + (* Check that each line has the same number of columns *) let lengths1 = List.map List.length csv1 in - let lengths2 = List.map List.length csv1 in + let lengths2 = List.map List.length csv2 in if not (all_equal lengths1) then let msg = "Csv.concat: first argument has uneven # of lines" in Stdlib.failwith msg else if not (all_equal lengths2) then - let msg = "Csv.concat: first argument has uneven # of lines" in + let msg = "Csv.concat: second argument has uneven # of lines" in + Stdlib.failwith msg + else if check_disjoint_headers && not (disjoint_headers csv1 csv2) then + let msg = "Csv.concat: headers are not disjoint" in Stdlib.failwith msg else (* see top if condition *) @@ -57,7 +70,7 @@ let concat (csv1 : csv) (csv2 : csv) : csv = let export ~filename ?(separator = ',') ?(linebreak = '\n') (data : csv) = Format.eprintf "Exporting to %s@." filename ; let sep_str = String.make 1 separator in - let outfile = open_out filename in + Out_channel.with_open_text filename @@ fun outfile -> let fmtr = Format.formatter_of_out_channel outfile in List.iter (fun line -> @@ -66,20 +79,14 @@ let export ~filename ?(separator = ',') ?(linebreak = '\n') (data : csv) = | _ -> let s = String.concat sep_str line in Format.fprintf fmtr "%s%c@?" s linebreak) - data ; - close_out outfile + data -(* shamelessly stolen from - https://stackoverflow.com/questions/5774934/how-do-i-read-in-lines-from-a-text-file-in-ocaml *) let read_lines name : string list = - let ic = open_in name in - let try_read () = try Some (input_line ic) with End_of_file -> None in + In_channel.with_open_text name @@ fun ic -> let rec loop acc = - match try_read () with + match In_channel.input_line ic with | Some s -> loop (s :: acc) - | None -> - close_in ic ; - List.rev acc + | None -> List.rev acc in loop []