From 2e9ff54c3e7cb87f51415ebe833c763b0aa92295 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 26 May 2026 02:26:19 +0400 Subject: [PATCH 1/6] attempt to replicate parts of https://github.com/leanEthereum/leanMultisig/pull/234 Co-Authored-By: Barnadrot --- crates/backend/poly/src/eq_mle.rs | 136 +++++++++++++++++- .../sumcheck/src/product_computation.rs | 8 +- crates/rec_aggregation/src/compilation.rs | 2 +- crates/sub_protocols/src/logup.rs | 8 +- .../src/quotient_gkr/sumcheck_utils.rs | 104 ++++++++++---- crates/whir/src/open.rs | 46 +++++- 6 files changed, 259 insertions(+), 45 deletions(-) diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 978330001..64aef66ec 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -881,6 +881,105 @@ fn eval_eq_with_packed_output, const INITIALIZED } } +#[inline] +fn eval_eq_with_packed_output_dual>( + eval_a: &[EF], + eval_b: &[EF], + out: &mut [EF::ExtensionPacking], + scalar_a: EF::ExtensionPacking, + scalar_b: EF::ExtensionPacking, +) { + debug_assert_eq!(eval_a.len(), eval_b.len()); + debug_assert_eq!(out.len(), 1 << eval_a.len()); + + match eval_a.len() { + 0 => { + out[0] = scalar_a + scalar_b; + } + 1 => { + let [a0, a1] = eval_eq_1(eval_a, scalar_a); + let [b0, b1] = eval_eq_1(eval_b, scalar_b); + out[0] = a0 + b0; + out[1] = a1 + b1; + } + 2 => { + let eq_a = eval_eq_2(eval_a, scalar_a); + let eq_b = eval_eq_2(eval_b, scalar_b); + for i in 0..4 { + out[i] = eq_a[i] + eq_b[i]; + } + } + 3 => { + let eq_a = eval_eq_3(eval_a, scalar_a); + let eq_b = eval_eq_3(eval_b, scalar_b); + for i in 0..8 { + out[i] = eq_a[i] + eq_b[i]; + } + } + _ => { + let (low, high) = out.split_at_mut(out.len() / 2); + let sa1 = scalar_a * eval_a[0]; + let sa0 = scalar_a - sa1; + let sb1 = scalar_b * eval_b[0]; + let sb0 = scalar_b - sb1; + eval_eq_with_packed_output_dual::(&eval_a[1..], &eval_b[1..], low, sa0, sb0); + eval_eq_with_packed_output_dual::(&eval_a[1..], &eval_b[1..], high, sa1, sb1); + } + } +} + +pub fn compute_eval_eq_packed_dual( + eval_a: &[EF], + eval_b: &[EF], + out: &mut [EF::ExtensionPacking], + scalar_a: EF, + scalar_b: EF, +) where + EF: ExtensionField>, +{ + let packing_width = packing_width::(); + let log_packing_width = log2_strict_usize(packing_width); + + assert_eq!(eval_a.len(), eval_b.len()); + assert!(log_packing_width <= eval_a.len()); + assert_eq!(out.len(), 1 << (eval_a.len() - log_packing_width)); + + if eval_a.len() <= log_packing_width + 1 + LOG_NUM_THREADS { + let mut output_no_packing = EF::zero_vec(1 << eval_a.len()); + eval_eq_basic::<_, _, _, false>(eval_a, &mut output_no_packing, scalar_a); + eval_eq_basic::<_, _, _, true>(eval_b, &mut output_no_packing, scalar_b); + out.par_iter_mut() + .zip(output_no_packing.par_chunks_exact(packing_width)) + .for_each(|(out_elem, chunk)| { + *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); + }); + } else { + let eval_len_min_packing = eval_a.len() - log_packing_width; + + let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); + let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); + let out_chunk_size = out.len() / NUM_THREADS_PADDED; + + parallel_buffer_a[0] = packed_eq_poly(&eval_a[eval_len_min_packing..], scalar_a); + fill_buffer(eval_a[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_a); + + parallel_buffer_b[0] = packed_eq_poly(&eval_b[eval_len_min_packing..], scalar_b); + fill_buffer(eval_b[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_b); + + out.par_chunks_exact_mut(out_chunk_size) + .enumerate() + .for_each(|(i, out_chunk)| { + eval_eq_with_packed_output_dual::, EF>( + &eval_a[LOG_NUM_THREADS..eval_len_min_packing], + &eval_b[LOG_NUM_THREADS..eval_len_min_packing], + out_chunk, + parallel_buffer_a[i], + parallel_buffer_b[i], + ); + }); + } +} + /// Computes the equality polynomial evaluations via a simple recursive algorithm. /// /// Unlike [`eval_eq_basic`], this function makes heavy use of packed values to speed up computations. @@ -968,10 +1067,19 @@ fn base_eval_eq_packed_with_packed_output( F: Field, EF: ExtensionField, { + // Ensure that the output buffer size is correct: + // It should be of size `2^n`, where `n` is the number of variables. + let width = F::Packing::WIDTH; + let log_packing_width = log2_strict_usize(width); debug_assert_eq!(out.len(), 1 << eval_points.len()); + debug_assert!(log_packing_width <= eval_points.len()); match eval_points.len() { - 0 => unreachable!(), + 0 => { + debug_assert_eq!(F::Packing::WIDTH, 1); + let base_vals = F::Packing::pack_slice(eq_evals.as_slice()); + scale_and_add_pf::(out, base_vals, packed_scalar); + } 1 => { let eq_evaluations = eval_eq_1(eval_points, eq_evals); scale_and_add_pf::(out, eq_evaluations.as_slice(), packed_scalar); @@ -1248,4 +1356,28 @@ mod tests { } } } -} + + #[test] + fn test_compute_eval_eq_packed_dual() { + let packing_width = ::Packing::WIDTH; + let log_packing_width = log2_strict_usize(packing_width); + let mut rng = StdRng::seed_from_u64(42); + + for n_vars in log_packing_width..22 { + let eval_a: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let eval_b: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let scalar_a: EF = rng.random(); + let scalar_b: EF = rng.random(); + + let packed_len = 1 << (n_vars - log_packing_width); + let mut out_dual = EFPacking::::zero_vec(packed_len); + compute_eval_eq_packed_dual::(&eval_a, &eval_b, &mut out_dual, scalar_a, scalar_b); + + let mut out_separate = EFPacking::::zero_vec(packed_len); + compute_eval_eq_packed::(&eval_a, &mut out_separate, scalar_a); + compute_eval_eq_packed::(&eval_b, &mut out_separate, scalar_b); + + assert_eq!(out_dual, out_separate, "Mismatch at n_vars={}", n_vars); + } + } +} \ No newline at end of file diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index ecce379fb..0ad7bed2c 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -45,11 +45,7 @@ pub fn run_product_sumcheck>>( assert!(n_rounds >= 1); let first_sumcheck_poly = match (pol_a, pol_b) { (MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => { - if EF::DIMENSION == 5 { - compute_product_sumcheck_polynomial_base_ext_packed::<5, _, _, _, EF>(evals, weights, sum) - } else { - unimplemented!() - } + compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::::to_ext_iter([e]).collect()) } (MleRef::ExtensionPacked(evals), MleRef::ExtensionPacked(weights)) => { compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::::to_ext_iter([e]).collect()) @@ -312,4 +308,4 @@ where let constant = y_0 * x_0; let quadratic = (y_1 - y_0) * (x_1 - x_0); (constant, quadratic) -} +} \ No newline at end of file diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 4fa257843..8ddd41677 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -89,7 +89,7 @@ fn compile_main_program_self_referential() -> Bytecode { if actual_log_size == log_size_guess { return bytecode; } - println!( + eprintln!( "Wrong guess at `compile_main_program_self_referential` (log_size {log_size_guess}->{actual_log_size})" ); log_size_guess = actual_log_size; diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 55af0a320..25fb569e4 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -56,10 +56,8 @@ pub fn prove_generic_logup( let memory_domainsep_packed = PFPacking::::from(F::from_usize(LOGUP_MEMORY_DOMAINSEP)); let bytecode_domainsep_packed = PFPacking::::from(F::from_usize(LOGUP_BYTECODE_DOMAINSEP)); - let min_section_log = log_bytecode.min(tables_log_heights_sorted.last().unwrap().1); - if min_section_log < ENDIANNESS_PIVOT_GKR { - tracing::info!("TODO: suboptimal GKR pivot (could be improved)."); - } + let log_bytecode_section = log_bytecode.max(tables_log_heights_sorted[0].1); + let min_section_log = log_bytecode_section.min(tables_log_heights_sorted.last().unwrap().1); let pivot = ENDIANNESS_PIVOT_GKR.min(min_section_log); let chunk_size = 1usize << pivot; let chunk_shift = usize::BITS as usize - pivot; @@ -533,4 +531,4 @@ where Build: Fn(usize) -> EFPacking + Sync, { dst.par_iter_mut().enumerate().for_each(|(p, slot)| *slot = build(p)); -} +} \ No newline at end of file diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 8f45a3494..14cf6f529 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -291,43 +291,77 @@ pub(super) fn run_phase2_sumcheck>>( mut sum: EF, mut mmf: EF, ) -> (Vec, [EF; 4]) { + let eq_prefix_init = &remaining_eq[..remaining_eq.len().saturating_sub(1)]; + let mut eq_table = eval_eq(eq_prefix_init); + for _round in 0..remaining_eq.len() { let eq_alpha = *remaining_eq.last().unwrap(); - let eq_prefix = &remaining_eq[..remaining_eq.len() - 1]; - let eq_table = eval_eq(eq_prefix); let active_l = num_l.len(); let active_r = num_r.len(); let active_pairs = active_l.div_ceil(2); let fully_active = active_r / 2; - let pair = |arr: &[EF], idx: usize, pad: EF| { - ( - arr.get(idx).copied().unwrap_or(pad), - arr.get(idx + 1).copied().unwrap_or(pad), - ) + let acc = if active_pairs >= PARALLEL_THRESHOLD { + let fa = fully_active; + (0..active_pairs) + .into_par_iter() + .fold(RoundCoeffs::zero, |mut acc, j| { + let coeffs = if j < fa { + pair_coeffs::( + (num_l[2 * j], num_l[2 * j + 1]), + (num_r[2 * j], num_r[2 * j + 1]), + (den_l[2 * j], den_l[2 * j + 1]), + (den_r[2 * j], den_r[2 * j + 1]), + ) + } else { + let get_pair = |arr: &[EF], idx: usize, pad: EF| { + ( + arr.get(idx).copied().unwrap_or(pad), + arr.get(idx + 1).copied().unwrap_or(pad), + ) + }; + pair_coeffs::( + get_pair(&num_l, 2 * j, EF::ZERO), + get_pair(&num_r, 2 * j, EF::ZERO), + get_pair(&den_l, 2 * j, EF::ONE), + get_pair(&den_r, 2 * j, EF::ONE), + ) + }; + acc += coeffs * eq_table[j]; + acc + }) + .reduce(RoundCoeffs::zero, Add::add) + } else { + let mut acc = RoundCoeffs::::zero(); + for j in 0..active_pairs { + let coeffs = if j < fully_active { + pair_coeffs::( + (num_l[2 * j], num_l[2 * j + 1]), + (num_r[2 * j], num_r[2 * j + 1]), + (den_l[2 * j], den_l[2 * j + 1]), + (den_r[2 * j], den_r[2 * j + 1]), + ) + } else { + let get_pair = |arr: &[EF], idx: usize, pad: EF| { + ( + arr.get(idx).copied().unwrap_or(pad), + arr.get(idx + 1).copied().unwrap_or(pad), + ) + }; + pair_coeffs::( + get_pair(&num_l, 2 * j, EF::ZERO), + get_pair(&num_r, 2 * j, EF::ZERO), + get_pair(&den_l, 2 * j, EF::ONE), + get_pair(&den_r, 2 * j, EF::ONE), + ) + }; + acc += coeffs * eq_table[j]; + } + acc }; - let mut acc = RoundCoeffs::::zero(); - for j in 0..active_pairs { - let coeffs = if j < fully_active { - pair_coeffs::( - (num_l[2 * j], num_l[2 * j + 1]), - (num_r[2 * j], num_r[2 * j + 1]), - (den_l[2 * j], den_l[2 * j + 1]), - (den_r[2 * j], den_r[2 * j + 1]), - ) - } else { - pair_coeffs::( - pair(&num_l, 2 * j, EF::ZERO), - pair(&num_r, 2 * j, EF::ZERO), - pair(&den_l, 2 * j, EF::ONE), - pair(&den_r, 2 * j, EF::ONE), - ) - }; - acc += coeffs * eq_table[j]; - } - + let eq_prefix = &remaining_eq[..remaining_eq.len() - 1]; let padding_sum = alpha * mle_of_zeros_then_ones(active_pairs, eq_prefix); let bare = build_bare_from_coeffs( @@ -349,6 +383,20 @@ pub(super) fn run_phase2_sumcheck>>( den_l = fold_normal_with_padding(&den_l, r, EF::ONE); den_r = fold_normal_with_padding(&den_r, r, EF::ONE); + let new_eq_len = eq_table.len() / 2; + if new_eq_len > 0 { + let mut new_eq = unsafe { uninitialized_vec(new_eq_len) }; + let fold_eq = |(i, slot): (usize, &mut EF)| { + *slot = eq_table[2 * i] + eq_table[2 * i + 1]; + }; + if new_eq_len >= PARALLEL_THRESHOLD { + new_eq.par_iter_mut().enumerate().for_each(fold_eq); + } else { + new_eq.iter_mut().enumerate().for_each(fold_eq); + } + eq_table = new_eq; + } + q_natural.push(r); remaining_eq.pop(); } @@ -500,4 +548,4 @@ fn build_bare_from_coeffs>>( let h1_mmf = (sum - (EF::ONE - eq_alpha) * c0_mmf) / eq_alpha; let c1_mmf = h1_mmf - c0_mmf - c2_mmf; DensePolynomial::new(vec![c0_mmf, c1_mmf, c2_mmf]) -} +} \ No newline at end of file diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index f9634c918..38209111c 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -522,12 +522,52 @@ where let num_variables = statements[0].total_num_variables; assert!(statements.iter().all(|e| e.total_num_variables == num_variables)); - let mut combined_weights = EFPacking::::zero_vec(1 << (num_variables - packing_log_width::())); + let out_len = 1 << (num_variables - packing_log_width::()); + let first = &statements[0]; + let first_is_full_initializer = !first.is_next + && first.values.len() == 1 + && first.values[0].selector == 0 + && first.inner_num_variables() == num_variables; + + let mut combined_weights: Vec>; let mut combined_sum = EF::ZERO; let mut gamma_pow = EF::ONE; + let start_idx; + + if first_is_full_initializer { + combined_weights = unsafe { uninitialized_vec(out_len) }; + let first_scalar = gamma_pow; + combined_sum += first.values[0].value * gamma_pow; + gamma_pow *= gamma; + + let second = statements.get(1); + let second_is_full_domain = second.is_some_and(|s| { + !s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables + }); - for smt in statements { + if second_is_full_domain { + let second = &statements[1]; + compute_eval_eq_packed_dual::( + &first.point.0, + &second.point.0, + &mut combined_weights, + first_scalar, + gamma_pow, + ); + combined_sum += second.values[0].value * gamma_pow; + gamma_pow *= gamma; + start_idx = 2; + } else { + compute_eval_eq_packed::(&first.point.0, &mut combined_weights, first_scalar); + start_idx = 1; + } + } else { + combined_weights = EFPacking::::zero_vec(out_len); + start_idx = 0; + } + + for smt in &statements[start_idx..] { if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) { for evaluation in &smt.values { compute_sparse_eval_eq_packed::(evaluation.selector, &smt.point, &mut combined_weights, gamma_pow); @@ -581,4 +621,4 @@ where } (combined_weights, combined_sum) -} +} \ No newline at end of file From f7cf9095f5d4f4f5072fe8469f9370f04691b132 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 26 May 2026 13:54:55 +0200 Subject: [PATCH 2/6] undo GKR pivot change --- crates/sub_protocols/src/logup.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 25fb569e4..55af0a320 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -56,8 +56,10 @@ pub fn prove_generic_logup( let memory_domainsep_packed = PFPacking::::from(F::from_usize(LOGUP_MEMORY_DOMAINSEP)); let bytecode_domainsep_packed = PFPacking::::from(F::from_usize(LOGUP_BYTECODE_DOMAINSEP)); - let log_bytecode_section = log_bytecode.max(tables_log_heights_sorted[0].1); - let min_section_log = log_bytecode_section.min(tables_log_heights_sorted.last().unwrap().1); + let min_section_log = log_bytecode.min(tables_log_heights_sorted.last().unwrap().1); + if min_section_log < ENDIANNESS_PIVOT_GKR { + tracing::info!("TODO: suboptimal GKR pivot (could be improved)."); + } let pivot = ENDIANNESS_PIVOT_GKR.min(min_section_log); let chunk_size = 1usize << pivot; let chunk_shift = usize::BITS as usize - pivot; @@ -531,4 +533,4 @@ where Build: Fn(usize) -> EFPacking + Sync, { dst.par_iter_mut().enumerate().for_each(|(p, slot)| *slot = build(p)); -} \ No newline at end of file +} From ef12ae1541e0c820bd5a9ca80a856261a050a89e Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 26 May 2026 13:57:56 +0200 Subject: [PATCH 3/6] simplify sumcheck_utils --- .../src/quotient_gkr/sumcheck_utils.rs | 96 +++++++------------ 1 file changed, 32 insertions(+), 64 deletions(-) diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 14cf6f529..5319454e8 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -302,63 +302,35 @@ pub(super) fn run_phase2_sumcheck>>( let active_pairs = active_l.div_ceil(2); let fully_active = active_r / 2; - let acc = if active_pairs >= PARALLEL_THRESHOLD { - let fa = fully_active; - (0..active_pairs) - .into_par_iter() - .fold(RoundCoeffs::zero, |mut acc, j| { - let coeffs = if j < fa { - pair_coeffs::( - (num_l[2 * j], num_l[2 * j + 1]), - (num_r[2 * j], num_r[2 * j + 1]), - (den_l[2 * j], den_l[2 * j + 1]), - (den_r[2 * j], den_r[2 * j + 1]), - ) - } else { - let get_pair = |arr: &[EF], idx: usize, pad: EF| { - ( - arr.get(idx).copied().unwrap_or(pad), - arr.get(idx + 1).copied().unwrap_or(pad), - ) - }; - pair_coeffs::( - get_pair(&num_l, 2 * j, EF::ZERO), - get_pair(&num_r, 2 * j, EF::ZERO), - get_pair(&den_l, 2 * j, EF::ONE), - get_pair(&den_r, 2 * j, EF::ONE), - ) - }; - acc += coeffs * eq_table[j]; - acc - }) - .reduce(RoundCoeffs::zero, Add::add) - } else { - let mut acc = RoundCoeffs::::zero(); - for j in 0..active_pairs { - let coeffs = if j < fully_active { - pair_coeffs::( - (num_l[2 * j], num_l[2 * j + 1]), - (num_r[2 * j], num_r[2 * j + 1]), - (den_l[2 * j], den_l[2 * j + 1]), - (den_r[2 * j], den_r[2 * j + 1]), - ) - } else { - let get_pair = |arr: &[EF], idx: usize, pad: EF| { - ( - arr.get(idx).copied().unwrap_or(pad), - arr.get(idx + 1).copied().unwrap_or(pad), - ) - }; - pair_coeffs::( - get_pair(&num_l, 2 * j, EF::ZERO), - get_pair(&num_r, 2 * j, EF::ZERO), - get_pair(&den_l, 2 * j, EF::ONE), - get_pair(&den_r, 2 * j, EF::ONE), + let term = |j: usize| -> RoundCoeffs { + let coeffs = if j < fully_active { + pair_coeffs::( + (num_l[2 * j], num_l[2 * j + 1]), + (num_r[2 * j], num_r[2 * j + 1]), + (den_l[2 * j], den_l[2 * j + 1]), + (den_r[2 * j], den_r[2 * j + 1]), + ) + } else { + let get_pair = |arr: &[EF], idx: usize, pad: EF| { + ( + arr.get(idx).copied().unwrap_or(pad), + arr.get(idx + 1).copied().unwrap_or(pad), ) }; - acc += coeffs * eq_table[j]; - } - acc + pair_coeffs::( + get_pair(&num_l, 2 * j, EF::ZERO), + get_pair(&num_r, 2 * j, EF::ZERO), + get_pair(&den_l, 2 * j, EF::ONE), + get_pair(&den_r, 2 * j, EF::ONE), + ) + }; + coeffs * eq_table[j] + }; + + let acc: RoundCoeffs = if active_pairs >= PARALLEL_THRESHOLD { + (0..active_pairs).into_par_iter().map(term).reduce(RoundCoeffs::zero, Add::add) + } else { + (0..active_pairs).map(term).fold(RoundCoeffs::::zero(), Add::add) }; let eq_prefix = &remaining_eq[..remaining_eq.len() - 1]; @@ -385,16 +357,12 @@ pub(super) fn run_phase2_sumcheck>>( let new_eq_len = eq_table.len() / 2; if new_eq_len > 0 { - let mut new_eq = unsafe { uninitialized_vec(new_eq_len) }; - let fold_eq = |(i, slot): (usize, &mut EF)| { - *slot = eq_table[2 * i] + eq_table[2 * i + 1]; - }; - if new_eq_len >= PARALLEL_THRESHOLD { - new_eq.par_iter_mut().enumerate().for_each(fold_eq); + let fold_eq = |i: usize| eq_table[2 * i] + eq_table[2 * i + 1]; + eq_table = if new_eq_len >= PARALLEL_THRESHOLD { + (0..new_eq_len).into_par_iter().map(fold_eq).collect() } else { - new_eq.iter_mut().enumerate().for_each(fold_eq); - } - eq_table = new_eq; + (0..new_eq_len).map(fold_eq).collect() + }; } q_natural.push(r); From c3b1947494fc8badff5fb3aac36c434d8cfbd270 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 26 May 2026 14:14:01 +0200 Subject: [PATCH 4/6] simplify open.rs --- crates/whir/src/open.rs | 59 +++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 38209111c..73bff2ad2 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -524,48 +524,37 @@ where let out_len = 1 << (num_variables - packing_log_width::()); - let first = &statements[0]; - let first_is_full_initializer = !first.is_next - && first.values.len() == 1 - && first.values[0].selector == 0 - && first.inner_num_variables() == num_variables; + let is_full = |s: &SparseStatement| { + !s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables + }; let mut combined_weights: Vec>; let mut combined_sum = EF::ZERO; let mut gamma_pow = EF::ONE; - let start_idx; - if first_is_full_initializer { - combined_weights = unsafe { uninitialized_vec(out_len) }; - let first_scalar = gamma_pow; - combined_sum += first.values[0].value * gamma_pow; - gamma_pow *= gamma; - - let second = statements.get(1); - let second_is_full_domain = second.is_some_and(|s| { - !s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables - }); - - if second_is_full_domain { - let second = &statements[1]; - compute_eval_eq_packed_dual::( - &first.point.0, - &second.point.0, - &mut combined_weights, - first_scalar, - gamma_pow, - ); - combined_sum += second.values[0].value * gamma_pow; + let start_idx = match statements { + [a, b, ..] if is_full(a) && is_full(b) => { + combined_weights = unsafe { uninitialized_vec(out_len) }; + let sa = gamma_pow; + let sb = gamma_pow * gamma; + combined_sum = a.values[0].value * sa + b.values[0].value * sb; + gamma_pow = sb * gamma; + compute_eval_eq_packed_dual::(&a.point.0, &b.point.0, &mut combined_weights, sa, sb); + 2 + } + [a, ..] if is_full(a) => { + combined_weights = unsafe { uninitialized_vec(out_len) }; + let sa = gamma_pow; + combined_sum = a.values[0].value * sa; gamma_pow *= gamma; - start_idx = 2; - } else { - compute_eval_eq_packed::(&first.point.0, &mut combined_weights, first_scalar); - start_idx = 1; + compute_eval_eq_packed::(&a.point.0, &mut combined_weights, sa); + 1 } - } else { - combined_weights = EFPacking::::zero_vec(out_len); - start_idx = 0; - } + _ => { + combined_weights = EFPacking::::zero_vec(out_len); + 0 + } + }; for smt in &statements[start_idx..] { if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) { From 2260dacdc866d473ffb9022b1d552741dcc6368d Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 26 May 2026 14:30:55 +0200 Subject: [PATCH 5/6] fmt --- crates/backend/poly/src/eq_mle.rs | 2 +- crates/backend/sumcheck/src/product_computation.rs | 2 +- crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs | 7 +++++-- crates/whir/src/open.rs | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 64aef66ec..64d3733f5 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -1380,4 +1380,4 @@ mod tests { assert_eq!(out_dual, out_separate, "Mismatch at n_vars={}", n_vars); } } -} \ No newline at end of file +} diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index 0ad7bed2c..2828af039 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -308,4 +308,4 @@ where let constant = y_0 * x_0; let quadratic = (y_1 - y_0) * (x_1 - x_0); (constant, quadratic) -} \ No newline at end of file +} diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 5319454e8..22751493d 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -328,7 +328,10 @@ pub(super) fn run_phase2_sumcheck>>( }; let acc: RoundCoeffs = if active_pairs >= PARALLEL_THRESHOLD { - (0..active_pairs).into_par_iter().map(term).reduce(RoundCoeffs::zero, Add::add) + (0..active_pairs) + .into_par_iter() + .map(term) + .reduce(RoundCoeffs::zero, Add::add) } else { (0..active_pairs).map(term).fold(RoundCoeffs::::zero(), Add::add) }; @@ -516,4 +519,4 @@ fn build_bare_from_coeffs>>( let h1_mmf = (sum - (EF::ONE - eq_alpha) * c0_mmf) / eq_alpha; let c1_mmf = h1_mmf - c0_mmf - c2_mmf; DensePolynomial::new(vec![c0_mmf, c1_mmf, c2_mmf]) -} \ No newline at end of file +} diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 73bff2ad2..6636b77c7 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -610,4 +610,4 @@ where } (combined_weights, combined_sum) -} \ No newline at end of file +} From 065d7c7ac060c691b2eb61c04fa1e05c198ec91c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 26 May 2026 16:51:46 +0400 Subject: [PATCH 6/6] > instead of >= for parallele threshold in sumcheck_utils.rs (faster on M4 MAx, slightly slower on ax42u) --- crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 22751493d..27afd58f7 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -327,7 +327,7 @@ pub(super) fn run_phase2_sumcheck>>( coeffs * eq_table[j] }; - let acc: RoundCoeffs = if active_pairs >= PARALLEL_THRESHOLD { + let acc: RoundCoeffs = if active_pairs > PARALLEL_THRESHOLD { (0..active_pairs) .into_par_iter() .map(term)