diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3e..adcf69b35 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -98,7 +98,7 @@ pub fn merkle_verify bool where - F: Default + Copy + PartialEq, + F: field::PrimeCharacteristicRing + PartialEq, Comp: Compression<[F; WIDTH]>, { if opening_proof.len() != log_height { diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index ebea80a9e..da3d4ee12 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -1,44 +1,44 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). use crate::Compression; +use field::PrimeCharacteristicRing; -// IV should have been added to data when necessary (typically: when the length of the data beeing hashed is not constant). Maybe we should re-add IV all the time for simplicity? -// assumes data length is a multiple of RATE (= 8 in practice). +/// Absorbs `data` RTL into an IV state `[data.len(), 0, ..., 0]` in RATE-sized chunks. +/// assumes data length is a multiple of RATE (= 8 in practice). pub fn hash_slice(comp: &Comp, data: &[T]) -> [T; OUT] where - T: Default + Copy, + T: PrimeCharacteristicRing, Comp: Compression<[T; WIDTH]>, { debug_assert!(RATE == OUT); debug_assert!(WIDTH == OUT + RATE); debug_assert!(data.len().is_multiple_of(RATE)); - let n_chunks = data.len() / RATE; - debug_assert!(n_chunks >= 2); - let mut state: [T; WIDTH] = data[data.len() - WIDTH..].try_into().unwrap(); - comp.compress_mut(&mut state); - for chunk_idx in (0..n_chunks - 2).rev() { - let offset = chunk_idx * RATE; - state[WIDTH - RATE..].copy_from_slice(&data[offset..offset + RATE]); + let mut state = [T::default(); WIDTH]; + state[0] = T::from_usize(data.len()); + for chunk in data.chunks_exact(RATE).rev() { + state[WIDTH - RATE..].copy_from_slice(chunk); comp.compress_mut(&mut state); } state[..OUT].try_into().unwrap() } -/// Precompute sponge state after absorbing `n_zero_chunks` all-zero RATE-chunks. +/// Precompute sponge state after absorbing `n_zero_chunks` all-zero RATE-chunks +/// into an IV state `[iv_first, 0, ..., 0]`. Caller provides `iv_first` (typically +/// the length, in field elements, of the full slice that will eventually be hashed). pub fn precompute_zero_suffix_state( comp: &Comp, + iv_first: T, n_zero_chunks: usize, ) -> [T; WIDTH] where - T: Default + Copy, + T: PrimeCharacteristicRing, Comp: Compression<[T; WIDTH]>, { debug_assert!(RATE == OUT); debug_assert!(WIDTH == OUT + RATE); - debug_assert!(n_zero_chunks >= 2); let mut state = [T::default(); WIDTH]; - comp.compress_mut(&mut state); - for _ in 0..n_zero_chunks - 2 { + state[0] = iv_first; + for _ in 0..n_zero_chunks { for s in &mut state[WIDTH - RATE..] { *s = T::default(); } @@ -47,29 +47,7 @@ where state } -/// RTL = Right-to-left -#[inline(always)] -pub fn hash_rtl_iter( - comp: &Comp, - rtl_iter: I, -) -> [T; OUT] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, - I: IntoIterator, -{ - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); - let mut state = [T::default(); WIDTH]; - let mut iter = rtl_iter.into_iter(); - for pos in (0..WIDTH).rev() { - state[pos] = iter.next().unwrap(); - } - comp.compress_mut(&mut state); - absorb_rtl_chunks::(comp, &mut state, &mut iter) -} - -/// RTL = Right-to-left +/// RTL = Right-to-left. Absorbs starting from the provided `initial_state` in RATE-sized chunks. #[inline(always)] pub fn hash_rtl_iter_with_initial_state( comp: &Comp, @@ -81,28 +59,15 @@ where Comp: Compression<[T; WIDTH]>, I: Iterator, { + debug_assert!(RATE == OUT); + debug_assert!(WIDTH == OUT + RATE); let mut state = *initial_state; - absorb_rtl_chunks::(comp, &mut state, &mut iter) -} - -/// RTL = Right-to-left -#[inline(always)] -fn absorb_rtl_chunks( - comp: &Comp, - state: &mut [T; WIDTH], - iter: &mut I, -) -> [T; OUT] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, - I: Iterator, -{ while let Some(elem) = iter.next() { state[WIDTH - 1] = elem; for pos in (WIDTH - RATE..WIDTH - 1).rev() { state[pos] = iter.next().unwrap(); } - comp.compress_mut(state); + comp.compress_mut(&mut state); } state[..OUT].try_into().unwrap() } diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 4bc9c316d..1812197d5 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -161,7 +161,7 @@ pub fn compile_to_low_level_bytecode( pc_to_location.push(current_location); } - let hash = poseidon_compress_slice(&instructions_multilinear, true); + let hash = poseidon_compress_slice(&instructions_multilinear); let code: Vec<_> = instructions .into_iter() diff --git a/crates/rec_aggregation/src/bytecode_claims.rs b/crates/rec_aggregation/src/bytecode_claims.rs index 8e001fab4..148d82aea 100644 --- a/crates/rec_aggregation/src/bytecode_claims.rs +++ b/crates/rec_aggregation/src/bytecode_claims.rs @@ -124,7 +124,7 @@ pub(crate) fn hash_bytecode_claims(claims: &[Evaluation]) -> [F; DIGEST_LEN] let mut data = flatten_scalars_to_base::(&ef_data); data.resize(data.len().next_multiple_of(DIGEST_LEN), F::ZERO); - let claim_hash = poseidon_compress_slice(&data, false); + let claim_hash = poseidon_compress_slice(&data); running_hash = poseidon16_compress_pair(&running_hash, &claim_hash); } running_hash diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 30e678103..80db685d3 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -28,7 +28,7 @@ pub struct InnerVerified { } pub(crate) fn verify_inner(input_data: Vec, proof: Proof) -> Result { - let input_data_hash = poseidon_compress_slice(&input_data, true); + let input_data_hash = poseidon_compress_slice(&input_data); let bytecode = get_aggregation_bytecode(); let (verif, raw_proof) = verify_execution(bytecode, &input_data_hash, proof)?; Ok(InnerVerified { diff --git a/crates/rec_aggregation/src/type_1_aggregation.rs b/crates/rec_aggregation/src/type_1_aggregation.rs index 4ca4dc16c..50f2ad88b 100644 --- a/crates/rec_aggregation/src/type_1_aggregation.rs +++ b/crates/rec_aggregation/src/type_1_aggregation.rs @@ -31,8 +31,6 @@ pub(crate) const N_TWEAKS: usize = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME; pub(crate) const TWEAK_SLOT_SIZE: usize = 4; pub(crate) const TWEAK_TABLE_SIZE_FE_PADDED: usize = (N_TWEAKS * TWEAK_SLOT_SIZE).next_multiple_of(DIGEST_LEN); -pub(crate) const TWEAKS_HASHING_USE_IV: bool = false; // fixed size → no IV needed - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub(crate) struct Digest(pub [F; DIGEST_LEN]); @@ -100,7 +98,7 @@ impl TypeOneInfo { pub(crate) fn build_input_data(&self) -> Vec { let tweak_table = compute_tweak_table(self.slot); - let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV); + let tweaks_hash = poseidon_compress_slice(&tweak_table); build_type1_input_data( self.pubkeys.len(), &hash_pubkeys(&self.pubkeys), @@ -115,7 +113,7 @@ impl TypeOneInfo { pub(crate) fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> [F; DIGEST_LEN] { let flat: Vec = pub_keys.iter().flat_map(|pk| pk.flaten().into_iter()).collect(); - poseidon_compress_slice(&flat, true) + poseidon_compress_slice(&flat) } /// Tweak slots are 4-FE [tw[0], tw[1], 0, 0] @@ -262,7 +260,7 @@ pub(crate) fn aggregate_type_1_with_min_padding( assert!(n_sigs <= MAX_XMSS_AGGREGATED); let tweak_table = compute_tweak_table(slot); - let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV); + let tweaks_hash = poseidon_compress_slice(&tweak_table); let reduced_claims = reduce_bytecode_claims(&verified_children); @@ -275,7 +273,7 @@ pub(crate) fn aggregate_type_1_with_min_padding( &reduced_claims.final_claim_flat(), bytecode, ); - let public_input = poseidon_compress_slice(&pub_input_data, true).to_vec(); + let public_input = poseidon_compress_slice(&pub_input_data).to_vec(); let mut claimed: HashSet = HashSet::new(); let mut dup_pub_keys: Vec = Vec::new(); diff --git a/crates/rec_aggregation/src/type_2_aggregation.rs b/crates/rec_aggregation/src/type_2_aggregation.rs index 9f3fee572..39c2c04ab 100644 --- a/crates/rec_aggregation/src/type_2_aggregation.rs +++ b/crates/rec_aggregation/src/type_2_aggregation.rs @@ -110,7 +110,7 @@ pub fn merge_many_type_1( let digests: Vec<[F; DIGEST_LEN]> = verified_children.iter().map(|v| v.input_data_hash).collect(); let pub_input_data = build_type2_input_data(&digests, &reduced_claims.final_claim_flat()); - let public_input_digest = poseidon_compress_slice(&pub_input_data, true).to_vec(); + let public_input_digest = poseidon_compress_slice(&pub_input_data).to_vec(); let bytecode_value_hint_blobs: Vec> = verified_children .iter() @@ -173,7 +173,7 @@ pub fn verify_type_2(sig: &TypeTwoMultiSignature) -> Result>(); let input_data = build_type2_input_data(&digests, &sig.bytecode_claim_flat()); verify_inner(input_data, sig.proof.proof.clone()) @@ -222,7 +222,7 @@ pub fn split_type_2( let mut outer_type_1 = type_2.info[index].clone(); outer_type_1.bytecode_claim = reduced_claims.final_claim.clone(); let ourer_input_data = outer_type_1.build_input_data(); - let outer_digest = poseidon_compress_slice(&ourer_input_data, true); + let outer_digest = poseidon_compress_slice(&ourer_input_data); let inner_input_data: Vec = type_2.info[index].build_input_data(); diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index d2414cf6b..7e740024b 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -16,7 +16,7 @@ PREAMBLE_MEMORY_LEN = PREAMBLE_MEMORY_END - PUBLIC_INPUT_LEN -def batch_hash_slice_rtl(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks): +def batch_hash_slice_rtl_with_iv(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks): if num_chunks == DIM * 2: batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM * 2) return @@ -43,39 +43,49 @@ def batch_hash_slice_rtl(num_queries, all_data_to_hash, all_resulting_hashes, nu def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks: Const): + iv = build_iv(num_chunks * DIGEST_LEN) for i in range(0, num_queries): data = all_data_to_hash[i] - res = slice_hash_rtl(data, num_chunks) + res = slice_hash_rtl(data, num_chunks, iv) all_resulting_hashes[i] = res return +# IV for the sponge: [slice length in field elements, 0, 0, ..., 0] @inline -def slice_hash_rtl(data, num_chunks): - states = Array((num_chunks - 1) * DIGEST_LEN) +def build_iv(length): + iv = Array(DIGEST_LEN) + iv[0] = length + for k in unroll(1, DIGEST_LEN): + iv[k] = 0 + return iv - poseidon16_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) - for j in unroll(1, num_chunks - 1): + +@inline +def slice_hash_rtl(data, num_chunks, iv): + debug_assert(1 <= num_chunks) + states = Array(num_chunks * DIGEST_LEN) + poseidon16_compress(iv, data + (num_chunks - 1) * DIGEST_LEN, states) + for j in unroll(1, num_chunks): poseidon16_compress( - states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN + states + (j - 1) * DIGEST_LEN, data + (num_chunks - 1 - j) * DIGEST_LEN, states + j * DIGEST_LEN ) - return states + (num_chunks - 2) * DIGEST_LEN + return states + (num_chunks - 1) * DIGEST_LEN @inline -def slice_hash(data, num_chunks): - states = Array((num_chunks - 1) * DIGEST_LEN) - poseidon16_compress(data, data + DIGEST_LEN, states) - for j in unroll(1, num_chunks - 1): - poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (j + 1) * DIGEST_LEN, states + j * DIGEST_LEN) - return states + (num_chunks - 2) * DIGEST_LEN +def slice_hash_ret(data, num_chunks): + res = Array(DIGEST_LEN) + slice_hash(data, num_chunks, res) + return res -def slice_hash_with_iv_range(data, num_chunks, dest): +def slice_hash_range(data, num_chunks, dest): debug_assert(0 < num_chunks) debug_assert(2 < num_chunks) + iv = build_iv(num_chunks * DIGEST_LEN) states = Array((num_chunks - 1) * DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, states) + poseidon16_compress(iv, data, states) for j in range(1, num_chunks - 1): poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + j * DIGEST_LEN, states + j * DIGEST_LEN) poseidon16_compress(states + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, dest) @@ -83,27 +93,30 @@ def slice_hash_with_iv_range(data, num_chunks, dest): @inline -def slice_hash_with_iv(data, num_chunks, dest): +def slice_hash(data, num_chunks, dest): debug_assert(2 <= num_chunks) + iv = build_iv(num_chunks * DIGEST_LEN) states = Array(num_chunks * DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, states) + poseidon16_compress(iv, data, states) for j in unroll(1, num_chunks - 1): poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + j * DIGEST_LEN, states + j * DIGEST_LEN) poseidon16_compress(states + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, dest) return -def slice_hash_with_iv_dynamic_unroll(data, num_chunks, num_chunks_bits: Const): +def slice_hash_dynamic_unroll(data, num_chunks, num_chunks_bits: Const): debug_assert(num_chunks != 0) debug_assert(num_chunks < 2**num_chunks_bits) + iv = build_iv(num_chunks * DIGEST_LEN) + if num_chunks == 1: result = Array(DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, result) + poseidon16_compress(iv, data, result) return result states = Array(num_chunks * DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, states) + poseidon16_compress(iv, data, states) n_iters = num_chunks - 1 state_ptr: Mut = states data_ptr: Mut = data + DIGEST_LEN diff --git a/crates/rec_aggregation/zkdsl_implem/main.py b/crates/rec_aggregation/zkdsl_implem/main.py index e5e9ca253..9f071db60 100644 --- a/crates/rec_aggregation/zkdsl_implem/main.py +++ b/crates/rec_aggregation/zkdsl_implem/main.py @@ -61,14 +61,14 @@ def main(): inner_type1_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) hint_witness("component_layout", inner_type1_buf) ensure_well_formed_input_data(inner_type1_buf, bytecode_hash_domsep, TYPE_1_FLAG) - slice_hash_with_iv(inner_type1_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, component_digest) + slice_hash(inner_type1_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, component_digest) bytecode_claims[2 * c] = inner_type1_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[2 * c + 1] = recursion(component_digest, bytecode_hash_domsep) reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_output) - slice_hash_with_iv_range(data_buf, n_components + TYPE_2_BASE_NUM_CHUNKS, pub_mem) + slice_hash_range(data_buf, n_components + TYPE_2_BASE_NUM_CHUNKS, pub_mem) return assert discriminator == TYPE_1_FLAG @@ -97,15 +97,15 @@ def main(): copy_32(data_buf + COMPONENT_DATA_OFFSET, kept_type1_buff + COMPONENT_DATA_OFFSET) ensure_well_formed_input_data(kept_type1_buff, bytecode_hash_domsep, TYPE_1_FLAG) digest_kept = type2_digests + type2_kept_index * DIGEST_LEN - slice_hash_with_iv(kept_type1_buff, TYPE_1_INPUT_DATA_NUM_CHUNKS, digest_kept) + slice_hash(kept_type1_buff, TYPE_1_INPUT_DATA_NUM_CHUNKS, digest_kept) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - slice_hash_with_iv_range(type2_data_buf, type2_num_chunks, inner_pub_mem) + slice_hash_range(type2_data_buf, type2_num_chunks, inner_pub_mem) bytecode_claims = Array(2) bytecode_claims[0] = type2_data_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[1] = recursion(inner_pub_mem, bytecode_hash_domsep) reduce_bytecode_claims(bytecode_claims, 2, bytecode_claim_output) - slice_hash_with_iv(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) + slice_hash(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) return # ============ Standard type-1: single (message, slot) aggregation ============ @@ -139,7 +139,7 @@ def main(): aggregate_sizes = Array(n_recursions) hint_witness("aggregate_sizes", aggregate_sizes) - computed_tweaks_hash = slice_hash(tweak_table, TWEAK_TABLE_SIZE_FE_PADDED / DIGEST_LEN) + computed_tweaks_hash = slice_hash_ret(tweak_table, TWEAK_TABLE_SIZE_FE_PADDED / DIGEST_LEN) copy_8(computed_tweaks_hash, tweaks_hash_expected) # 1->1 optimization: a single recursive type-1 child, no raw signatures, no duplicates. @@ -152,16 +152,16 @@ def main(): hint_witness("inner_bytecode_claim", type1_data_buf + BYTECODE_CLAIM_OFFSET) ensure_well_formed_input_data(type1_data_buf, bytecode_hash_domsep, TYPE_1_FLAG) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - slice_hash_with_iv(type1_data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, inner_pub_mem) + slice_hash(type1_data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, inner_pub_mem) bytecode_claims = Array(2) bytecode_claims[0] = type1_data_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[1] = recursion(inner_pub_mem, bytecode_hash_domsep) reduce_bytecode_claims(bytecode_claims, 2, bytecode_claim_output) - slice_hash_with_iv(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) + slice_hash(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) return # General path - computed_pubkeys_hash = slice_hash_with_iv_dynamic_unroll(all_pubkeys, n_sigs, log2_ceil(MAX_N_SIGS)) + computed_pubkeys_hash = slice_hash_dynamic_unroll(all_pubkeys, n_sigs, log2_ceil(MAX_N_SIGS)) copy_8(computed_pubkeys_hash, pubkeys_hash_expected) # Buffer for partition verification @@ -193,7 +193,8 @@ def main(): counter += 1 pk0 = all_pubkeys + idx0 * PUB_KEY_SIZE running_hash: Mut = Array(DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, pk0, running_hash) + iv = build_iv(n_sub * PUB_KEY_SIZE) + poseidon16_compress(iv, pk0, running_hash) for j in dynamic_unroll(1, n_sub, log2_ceil(MAX_N_SIGS)): idx = sub_indices_arr[j] @@ -218,7 +219,7 @@ def main(): hint_witness("inner_bytecode_claim", type1_data_buf + BYTECODE_CLAIM_OFFSET) ensure_well_formed_input_data(type1_data_buf, bytecode_hash_domsep, TYPE_1_FLAG) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - slice_hash_with_iv(type1_data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, inner_pub_mem) + slice_hash(type1_data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, inner_pub_mem) bytecode_claims[2 * rec_idx] = type1_data_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[2 * rec_idx + 1] = recursion(inner_pub_mem, bytecode_hash_domsep) @@ -234,7 +235,7 @@ def main(): else: reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_output) - slice_hash_with_iv(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) + slice_hash(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) return @@ -244,7 +245,7 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou claim_ptr = bytecode_claims[i] for k in unroll(BYTECODE_CLAIM_SIZE, BYTECODE_CLAIM_SIZE_PADDED): assert claim_ptr[k] == 0 - claim_hash = slice_hash(claim_ptr, BYTECODE_CLAIM_SIZE_PADDED / DIGEST_LEN) + claim_hash = slice_hash_ret(claim_ptr, BYTECODE_CLAIM_SIZE_PADDED / DIGEST_LEN) new_hash = Array(DIGEST_LEN) poseidon16_compress(bytecode_claims_hash, claim_hash, new_hash) bytecode_claims_hash = new_hash diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index 639d45127..fe3d6e7ce 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -559,7 +559,7 @@ def whir_1_merkle_step_and_pow(v, state_in, path_chunk, state_out, power_shift): @inline -def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): +def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks, leaf_iv): nibbles = Array(6) hint_decompose_bits_merkle_whir(nibbles, a, 4) @@ -580,7 +580,7 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): leaf_data = Array(num_chunks * DIGEST_LEN) hint_witness("merkle_leaf", leaf_data) - leaf_hash = slice_hash_rtl(leaf_data, num_chunks) + leaf_hash = slice_hash_rtl(leaf_data, num_chunks, leaf_iv) merkle_path = Array(domain_size * DIGEST_LEN) hint_witness("merkle_path", merkle_path) diff --git a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py index 42886a92d..3124f2534 100644 --- a/crates/rec_aggregation/zkdsl_implem/whir.py +++ b/crates/rec_aggregation/zkdsl_implem/whir.py @@ -276,8 +276,11 @@ def decompose_and_verify_merkle_batch_with_height( def decompose_and_verify_merkle_batch_const( num_queries, sampled, root, height: Const, num_chunks: Const, circle_values, merkle_leaves ): + leaf_iv = build_iv(num_chunks * DIGEST_LEN) for i in range(0, num_queries): - merkle_leaves[i], circle_values[i] = decompose_and_verify_merkle_query(sampled[i], height, root, num_chunks) + merkle_leaves[i], circle_values[i] = decompose_and_verify_merkle_query( + sampled[i], height, root, num_chunks, leaf_iv + ) return diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index f8c09e226..dcd3fa85a 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -37,33 +37,17 @@ pub fn poseidon16_compress_pair(left: &[KoalaBear; 8], right: &[KoalaBear; 8]) - poseidon16_compress(input) } -/// If `use_iv` is false, the length of the slice must be constant (not malleable). -pub fn poseidon_compress_slice(data: &[KoalaBear], use_iv: bool) -> [KoalaBear; 8] { +/// Absorbs `data` in rate-mode chunks of 8, starting from the IV `[data.len(), 0, ..., 0]`. +pub fn poseidon_compress_slice(data: &[KoalaBear]) -> [KoalaBear; 8] { assert!(!data.is_empty()); assert!(data.len().is_multiple_of(8)); - if use_iv { - let mut hash = [KoalaBear::default(); 8]; - for chunk in data.chunks(8) { - let mut block = [KoalaBear::default(); 16]; - block[..8].copy_from_slice(&hash); - block[8..8 + chunk.len()].copy_from_slice(chunk); - hash = poseidon16_compress(block); - } - hash - } else { - let len = data.len(); - if len <= 16 { - let mut padded = [KoalaBear::default(); 16]; - padded[..len].copy_from_slice(data); - return poseidon16_compress(padded); - } - let mut hash = poseidon16_compress(data[0..16].try_into().unwrap()); - for chunk in data[16..].chunks(8) { - let mut block = [KoalaBear::default(); 16]; - block[..8].copy_from_slice(&hash); - block[8..8 + chunk.len()].copy_from_slice(chunk); - hash = poseidon16_compress(block); - } - hash + let mut hash = [KoalaBear::default(); 8]; + hash[0] = KoalaBear::from_usize(data.len()); + for chunk in data.chunks(8) { + let mut block = [KoalaBear::default(); 16]; + block[..8].copy_from_slice(&hash); + block[8..].copy_from_slice(chunk); + hash = poseidon16_compress(block); } + hash } diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index b5517cd09..43446714a 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -8,6 +8,7 @@ use field::BasedVectorSpace; use field::ExtensionField; use field::Field; use field::PackedValue; +use field::PrimeCharacteristicRing; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; @@ -63,22 +64,20 @@ fn build_merkle_tree_koalabear( ) -> RoundMerkleTree { let perm = default_koalabear_poseidon1_16(); let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 8; - let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( - &perm, - n_zero_suffix_rate_chunks, - ); - let packed_state: [PFPacking; 16] = - std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); - first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, 16, 8>( - &perm, - &leaf, - &packed_state, - effective_base_width, - ) - } else { - first_digest_layer::, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, full_base_width) - }; + let iv_first = KoalaBear::from_usize(full_base_width); + let scalar_state = symetric::precompute_zero_suffix_state::( + &perm, + iv_first, + n_zero_suffix_rate_chunks, + ); + let packed_state: [PFPacking; 16] = + std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); + let first_layer = first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, 16, 8>( + &perm, + &leaf, + &packed_state, + effective_base_width, + ); let tree = symetric::merkle::MerkleTree::from_first_layer::, _, 16>(&perm, first_layer); WhirMerkleTree { leaf, @@ -159,7 +158,7 @@ pub struct WhirMerkleTree { full_leaf_base_width: usize, } -impl, const DIGEST_ELEMS: usize> +impl, const DIGEST_ELEMS: usize> WhirMerkleTree { #[instrument(name = "build merkle tree", skip_all)] @@ -174,21 +173,19 @@ impl, const DIGEST_ELEMS: Perm: Compression<[F; WIDTH]> + Compression<[P; WIDTH]>, { let n_zero_suffix_rate_chunks = (full_leaf_base_width - effective_base_width) / RATE; - let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( - perm, - n_zero_suffix_rate_chunks, - ); - let packed_state: [P; WIDTH] = std::array::from_fn(|i| P::from_fn(|_| scalar_state[i])); - first_digest_layer_with_initial_state::( - perm, - &leaf, - &packed_state, - effective_base_width, - ) - } else { - first_digest_layer::(perm, &leaf, full_leaf_base_width) - }; + let iv_first = F::from_usize(full_leaf_base_width); + let scalar_state = symetric::precompute_zero_suffix_state::( + perm, + iv_first, + n_zero_suffix_rate_chunks, + ); + let packed_state: [P; WIDTH] = std::array::from_fn(|i| P::from_fn(|_| scalar_state[i])); + let first_layer = first_digest_layer_with_initial_state::( + perm, + &leaf, + &packed_state, + effective_base_width, + ); let tree = symetric::merkle::MerkleTree::from_first_layer::(perm, first_layer); Self { leaf, @@ -212,42 +209,6 @@ impl, const DIGEST_ELEMS: } #[instrument(name = "first digest layer", level = "debug", skip_all)] -fn first_digest_layer( - perm: &Perm, - matrix: &M, - full_width: usize, -) -> Vec<[P::Value; DIGEST_ELEMS]> -where - P: PackedValue + Default, - P::Value: Default + Copy, - Perm: Compression<[P::Value; WIDTH]> + Compression<[P; WIDTH]>, - M: Matrix, -{ - let width = P::WIDTH; - let height = matrix.height(); - assert!(height.is_multiple_of(width)); - let matrix_width = matrix.width(); - let n_trailing_zeros = full_width - matrix_width; - - let mut digests = unsafe { uninitialized_vec(height) }; - - digests - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, matrix_width, n_trailing_zeros); - let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, rtl_iter); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); - - digests -} - -#[instrument(skip_all)] fn first_digest_layer_with_initial_state( perm: &Perm, matrix: &M,