diff --git a/src/riscv/lib/src/state_backend/commitment_layout.rs b/src/riscv/lib/src/state_backend/commitment_layout.rs index 1c245800ab3727bb222a1a74cb174a235f7bbc91..e58f90242f9dafa66c2158516de141d4051c0391 100644 --- a/src/riscv/lib/src/state_backend/commitment_layout.rs +++ b/src/riscv/lib/src/state_backend/commitment_layout.rs @@ -17,6 +17,7 @@ use super::proof_backend::merkle::MERKLE_ARITY; use super::proof_backend::merkle::MERKLE_LEAF_SIZE; use super::proof_backend::merkle::chunks_to_writer; use crate::default::ConstDefault; +use crate::state_backend::hash::build_custom_merkle_hash; /// [`Layouts`] which may be used for commitments /// @@ -147,7 +148,11 @@ where T: CommitmentLayout, { fn state_hash(state: AllocatedOf) -> Result { - iter_state_hash::<_, T, M, LEN>(state) + let hashes: Vec = state + .into_iter() + .map(T::state_hash) + .collect::, _>>()?; + Hash::combine(&hashes) } } @@ -156,20 +161,10 @@ where T: CommitmentLayout, { fn state_hash(state: AllocatedOf) -> Result { - iter_state_hash::<_, T, M, LEN>(state) + let nodes: Vec = state + .into_iter() + .map(T::state_hash) + .collect::, _>>()?; + build_custom_merkle_hash(MERKLE_ARITY, nodes) } } - -fn iter_state_hash(iter: I) -> Result -where - M: ManagerSerialise, - I: IntoIterator>, - T: CommitmentLayout, -{ - let hashes: Vec = iter - .into_iter() - .map(T::state_hash) - .collect::, _>>()?; - - Hash::combine(&hashes) -} diff --git a/src/riscv/lib/src/state_backend/proof_backend/merkle.rs b/src/riscv/lib/src/state_backend/proof_backend/merkle.rs index fa479806b51315a3f80d0c39234103b5e7555096..53050d1123929fc7b38da9e29a542ff4ad25430e 100644 --- a/src/riscv/lib/src/state_backend/proof_backend/merkle.rs +++ b/src/riscv/lib/src/state_backend/proof_backend/merkle.rs @@ -306,22 +306,7 @@ impl MerkleWriter { self.flush_buffer()?; } - if self.leaves.is_empty() { - return Err(HashError::NonEmptyBufferExpected); - } - - let mut next_level = Vec::with_capacity(self.leaves.len().div_ceil(self.arity)); - - while self.leaves.len() > 1 { - for chunk in self.leaves.chunks(self.arity) { - next_level.push(MerkleTree::make_merkle_node(chunk.to_vec())?) - } - - std::mem::swap(&mut self.leaves, &mut next_level); - next_level.truncate(0); - } - - Ok(self.leaves[0].clone()) + build_custom_merkle_tree(self.arity, self.leaves) } } @@ -382,6 +367,37 @@ impl MerkleTree { } } +/// Build a Merkle tree whose leaves are the elements of `nodes` and in which +/// each node has the given `arity`. +pub(crate) fn build_custom_merkle_tree( + arity: usize, + mut nodes: Vec, +) -> Result { + if nodes.is_empty() { + return Err(HashError::NonEmptyBufferExpected); + } + + let mut next_level = Vec::with_capacity(nodes.len().div_ceil(arity)); + + while nodes.len() > 1 { + for chunk in nodes.chunks(arity) { + next_level.push(MerkleTree::make_merkle_node(chunk.to_vec())?) + } + + std::mem::swap(&mut nodes, &mut next_level); + next_level.truncate(0); + } + + Ok(nodes.pop().unwrap_or_else(|| { + unreachable!( + "After the loop, `nodes` could only have 0 or 1 elements. It had \ + more than 1 element at the beginning of the last iteration of the \ + loop and exactly one element was pushed to it because `nodes.chunks` \ + could not have resulted in 0 chunks for a non-empty vector." + ) + })) +} + #[cfg(test)] mod tests { use std::io::Cursor; diff --git a/src/riscv/lib/src/state_backend/proof_layout.rs b/src/riscv/lib/src/state_backend/proof_layout.rs index 9592768c3efc6cfc695653d94f21b83dad8653a9..700a2959a669d7d293185d9e1715092a3ee106f4 100644 --- a/src/riscv/lib/src/state_backend/proof_layout.rs +++ b/src/riscv/lib/src/state_backend/proof_layout.rs @@ -20,6 +20,7 @@ use super::proof_backend::merkle::MERKLE_ARITY; use super::proof_backend::merkle::MERKLE_LEAF_SIZE; use super::proof_backend::merkle::MerkleTree; use super::proof_backend::merkle::MerkleWriter; +use super::proof_backend::merkle::build_custom_merkle_tree; use super::proof_backend::merkle::chunks_to_writer; use super::proof_backend::proof::MerkleProof; use super::proof_backend::proof::MerkleProofLeaf; @@ -185,9 +186,9 @@ impl<'a> ProofTree<'a> { /// If the proof tree is absent, return absent branches and no proof hash. fn into_branches_with_hash( self, - ) -> Result<(Vec>, Option), PartialHashError> { + ) -> Result<(Box<[ProofTree<'a>; LEN]>, Option), PartialHashError> { let ProofTree::Present(proof) = self else { - return Ok((vec![ProofTree::Absent; LEN], None)); + return Ok((boxed_array![ProofTree::Absent; LEN], None)); }; match proof { @@ -198,11 +199,19 @@ impl<'a> ProofTree<'a> { }, )), Tree::Node(branches) => Ok(( - branches.iter().map(ProofTree::Present).collect::>(), + branches + .iter() + .map(ProofTree::Present) + .collect::>() + .into_boxed_slice() + .try_into() + .map_err(|_| PartialHashError::Fatal)?, None, )), Tree::Leaf(leaf) => match leaf { - MerkleProofLeaf::Blind(hash) => Ok((vec![ProofTree::Absent; LEN], Some(*hash))), + MerkleProofLeaf::Blind(hash) => { + Ok((boxed_array![ProofTree::Absent; LEN], Some(*hash))) + } _ => Err(FromProofError::UnexpectedLeaf)?, }, } @@ -356,20 +365,15 @@ impl ProofLayout for DynArray { // Expecting a branching point. // TODO RV-463: Nodes with fewer than `MERKLE_ARITY` children should also be accepted. let branches = tree.into_branches::<{ MERKLE_ARITY }>()?; - let branch_max_length = length.div_ceil(MERKLE_ARITY); - - let mut branch_start = start; - let mut length_left = length; - for branch in branches.into_iter() { - let this_branch_length = branch_max_length.min(length_left); - if this_branch_length > 0 { - pipeline.push((branch_start, this_branch_length, branch)); - } - - branch_start = branch_start.saturating_add(this_branch_length); - length_left = length_left.saturating_sub(this_branch_length); - } + push_work_items_for_branches( + start, + length, + branches.as_slice(), + |branch_start, branch_length, branch| { + pipeline.push((branch_start, branch_length, branch)); + }, + ); } } @@ -418,24 +422,14 @@ impl ProofLayout for DynArray { let (branches, proof_hash) = tree.into_branches_with_hash::<{ MERKLE_ARITY }>()?; - let branch_max_length = length.div_ceil(MERKLE_ARITY); - let mut branch_start = start; - let mut length_left = length; - - for branch in branches.into_iter() { - let this_branch_length = branch_max_length.min(length_left); - - if this_branch_length > 0 { - queue.push_back(Event::Span( - branch_start, - this_branch_length, - branch, - )); - } - - branch_start = branch_start.saturating_add(this_branch_length); - length_left = length_left.saturating_sub(this_branch_length); - } + push_work_items_for_branches( + start, + length, + branches.as_ref(), + |branch_start, branch_length, branch| { + queue.push_back(Event::Span(branch_start, branch_length, branch)); + }, + ); queue.push_back(Event::Node(proof_hash)); } @@ -683,7 +677,12 @@ where T: ProofLayout, { fn to_merkle_tree(state: RefProofGenOwnedAlloc) -> Result { - iter_to_proof::<_, T>(state) + let children = state + .into_iter() + .map(T::to_merkle_tree) + .collect::, _>>()?; + + MerkleTree::make_merkle_node(children) } fn from_proof(proof: ProofTree) -> FromProofResult { @@ -718,29 +717,124 @@ where T: ProofLayout, { fn to_merkle_tree(state: RefProofGenOwnedAlloc) -> Result { - iter_to_proof::<_, T>(state) + let leaves = state + .into_iter() + .map(T::to_merkle_tree) + .collect::, _>>()?; + + build_custom_merkle_tree(MERKLE_ARITY, leaves) } fn from_proof(proof: ProofTree) -> FromProofResult { - proof - .into_branches::()? - .iter() - .copied() - .map(T::from_proof) - .collect::, _>>() + let mut pipeline = vec![(0usize, LEN, proof)]; + + let mut data = Vec::with_capacity(LEN); + for _ in 0..LEN { + data.push(T::from_proof(ProofTree::Absent)?); + } + + while let Some((start, length, tree)) = pipeline.pop() { + if length == 1 { + data[start] = T::from_proof(tree)?; + } else { + // Expecting a branching point. + // TODO RV-463: Nodes with fewer than `MERKLE_ARITY` children should also be accepted. + let branches = tree.into_branches::<{ MERKLE_ARITY }>()?; + + push_work_items_for_branches( + start, + length, + branches.as_slice(), + |branch_start, branch_length, branch| { + pipeline.push((branch_start, branch_length, branch)); + }, + ); + } + } + Ok(data) } fn partial_state_hash( state: RefVerifierAlloc, proof: ProofTree, ) -> Result { - let (branches, proof_hash) = proof.into_branches_with_hash::()?; - let hashes = state - .into_iter() - .zip(branches.iter()) - .map(|(state, proof)| T::partial_state_hash(state, *proof)) - .collect::>>(); - combine_partial_hashes(hashes, proof_hash) + enum Event<'a> { + Span(usize, usize, ProofTree<'a>), + Node(Option), + } + + // `T::partial_state_hash` needs to take ownership of the elements of `state`. + // Given that `T` is not `Copy`, in order to take ownership of arbitrary elements + // of `state` we'd first need to duplicate it and wrap each element in a type + // which supports taking ownership. + // However, in practice, we compute the hash of each element sequentially, meaning + // that we can simply iterate over the state directly when calling `T::partial_state_hash`. + let mut state = state.into_iter(); + let mut next_vec_index = 0; + + let mut queue = VecDeque::new(); + queue.push_back(Event::Span(0usize, LEN, proof)); + + let mut hashes: Vec> = Vec::new(); + + while let Some(event) = queue.pop_front() { + match event { + Event::Span(start, length, tree) => { + if length == 1 { + // Check that iterating over the state is equivalent to calling `state[start]` + debug_assert_eq!(start, next_vec_index); + next_vec_index += 1; + hashes.push(T::partial_state_hash( + state.next().ok_or(PartialHashError::Fatal)?, + tree, + )) + } else { + // TODO RV-463: Nodes with fewer than `MERKLE_ARITY` children should also be accepted. + // The span's size is that of a node, produce `Event::Span` work items for each of its + // children and add them to the work queue, followed by an `Event::Node`. + let (branches, proof_hash) = + tree.into_branches_with_hash::<{ MERKLE_ARITY }>()?; + + push_work_items_for_branches( + start, + length, + branches.as_ref(), + |branch_start, branch_length, branch| { + queue.push_back(Event::Span(branch_start, branch_length, branch)); + }, + ); + + queue.push_back(Event::Node(proof_hash)); + } + } + Event::Node(proof_hash) => { + if hashes.is_empty() { + // The hashes which need to be combined have not yet been computed because + // their processing resulted in more `Event::Span` items. Push to the back + // of the work queue. + queue.push_back(Event::Node(proof_hash)); + continue; + } + if hashes.len() < MERKLE_ARITY { + return Err(PartialHashError::Fatal); + }; + + // Take `MERKLE_ARITY` children hashes, compute their parent's hash, and + // push it to the `hashes` stack. + let node_hashes: Vec<_> = hashes.drain(hashes.len() - MERKLE_ARITY..).collect(); + hashes.push(combine_partial_hashes(node_hashes, proof_hash)) + } + } + } + + // Check that we iterated over all the elements of the state + debug_assert_eq!(next_vec_index, LEN); + + if hashes.len() == 1 { + hashes.pop().unwrap() + } else { + Err(PartialHashError::Fatal) + } } } @@ -781,18 +875,24 @@ fn combine_partial_hashes( proof_hash.ok_or(PartialHashError::PotentiallyRecoverable) } -fn iter_to_proof<'a, 'b, I, T>(iter: I) -> Result -where - I: IntoIterator>, - T: ProofLayout, - 'b: 'a, -{ - let children = iter - .into_iter() - .map(T::to_merkle_tree) - .collect::, _>>()?; +fn push_work_items_for_branches<'a>( + mut branch_start: usize, + mut length_left: usize, + branches: &'_ [ProofTree<'a>], + mut push: impl FnMut(usize, usize, ProofTree<'a>), +) { + let branch_max_length = length_left.div_ceil(MERKLE_ARITY); + + for branch in branches.iter() { + let this_branch_length = branch_max_length.min(length_left); - MerkleTree::make_merkle_node(children) + if this_branch_length > 0 { + push(branch_start, this_branch_length, *branch); + } + + branch_start = branch_start.saturating_add(this_branch_length); + length_left = length_left.saturating_sub(this_branch_length); + } } #[cfg(test)] diff --git a/src/riscv/lib/tests/expected/dummy/state_hash_final b/src/riscv/lib/tests/expected/dummy/state_hash_final index 819403b38d8d41a26cc8e65657e34f726c812c1e..0a7fbc4540702e30fd887edc8b11731f35fdb0f4 100644 --- a/src/riscv/lib/tests/expected/dummy/state_hash_final +++ b/src/riscv/lib/tests/expected/dummy/state_hash_final @@ -1 +1 @@ -637d30afdf0c9e65dbb9c58e1ded986cfcfacd58b89f9bf4f3566bf0f7efbcef +540906319b1fdc92677d4bc9f56637defbbd22f29a8e9915bab7ba02853a3915 diff --git a/src/riscv/lib/tests/expected/dummy/state_hash_initial b/src/riscv/lib/tests/expected/dummy/state_hash_initial index 44d789a2fce17e8f49a561021798031b4a6e589b..4b79fe97953ddfb4fe0a4c8b472e64d478b99e7f 100644 --- a/src/riscv/lib/tests/expected/dummy/state_hash_initial +++ b/src/riscv/lib/tests/expected/dummy/state_hash_initial @@ -1 +1 @@ -f7c75ecd81076935b46b240c2e159e648b99e63677abaebedbb19c2ee21008e5 +47c5121565670ac039a85ddddad2229a0c146467f114fc01c49de8df36bc75cd diff --git a/src/riscv/lib/tests/expected/jstz/state_hash_final b/src/riscv/lib/tests/expected/jstz/state_hash_final index 74ca032767bf334878721484d13e6b67643c6ece..13b93f3b660b700354a3dd0e5d130d963efee8e0 100644 --- a/src/riscv/lib/tests/expected/jstz/state_hash_final +++ b/src/riscv/lib/tests/expected/jstz/state_hash_final @@ -1 +1 @@ -f45f769187a482e10f435bd075924cf74067a737c951b475ebfef82dffa3e08d +c8919f548c0441164597ad77e242d9468e08e624883f5832c89e9bf7c2ee96a4 diff --git a/src/riscv/lib/tests/expected/jstz/state_hash_initial b/src/riscv/lib/tests/expected/jstz/state_hash_initial index e0861cd5e1e69e9bc4d4ac16e2b828a5f3f4d6e4..810d89e0ff3d09ff1f264f41221f22f15a429846 100644 --- a/src/riscv/lib/tests/expected/jstz/state_hash_initial +++ b/src/riscv/lib/tests/expected/jstz/state_hash_initial @@ -1 +1 @@ -ef98cc742841825fd13bcbf2f241d727872267362881ba32dab15abfdba49cca +5335b927ee0f5a648200e8d03bf63e438cf5e120f2b3b2bb891804004ba25129 diff --git a/src/riscv/lib/tests/test_proofs.rs b/src/riscv/lib/tests/test_proofs.rs index 20bd5fd11ed4521417b919d48834bf09dfd2f09f..0b08aca3f1eb984bb511d8b0fd3fcc5904764625 100644 --- a/src/riscv/lib/tests/test_proofs.rs +++ b/src/riscv/lib/tests/test_proofs.rs @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2024 Nomadic Labs +// SPDX-FileCopyrightText: 2024-2025 Nomadic Labs // SPDX-FileCopyrightText: 2024 TriliTech // // SPDX-License-Identifier: MIT @@ -6,11 +6,13 @@ mod common; use std::ops::Bound; +use std::time::Instant; use common::*; use octez_riscv::machine_state::DefaultCacheLayouts; use octez_riscv::machine_state::memory::M64M; use octez_riscv::state_backend::hash; +use octez_riscv::state_backend::proof_backend::proof::serialise_proof; use octez_riscv::stepper::Stepper; use octez_riscv::stepper::StepperStatus; use octez_riscv::stepper::pvm::PvmStepper; @@ -70,7 +72,15 @@ where assert!(matches!(result, StepperStatus::Running { .. })); eprintln!("> Producing proof ..."); + let start = Instant::now(); let proof = stepper.produce_proof().unwrap(); + let time = start.elapsed(); + let serialisation: Vec = serialise_proof(&proof).collect(); + eprintln!( + "> Proof of size {} KiB produced in {:?}", + serialisation.len() / 1024, + time + ); eprintln!("> Checking initial proof hash ..."); assert_eq!(proof.initial_state_hash(), stepper.hash());