diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 978330001..64d3733f5 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -881,6 +881,105 @@ fn eval_eq_with_packed_output, const INITIALIZED } } +#[inline] +fn eval_eq_with_packed_output_dual>( + eval_a: &[EF], + eval_b: &[EF], + out: &mut [EF::ExtensionPacking], + scalar_a: EF::ExtensionPacking, + scalar_b: EF::ExtensionPacking, +) { + debug_assert_eq!(eval_a.len(), eval_b.len()); + debug_assert_eq!(out.len(), 1 << eval_a.len()); + + match eval_a.len() { + 0 => { + out[0] = scalar_a + scalar_b; + } + 1 => { + let [a0, a1] = eval_eq_1(eval_a, scalar_a); + let [b0, b1] = eval_eq_1(eval_b, scalar_b); + out[0] = a0 + b0; + out[1] = a1 + b1; + } + 2 => { + let eq_a = eval_eq_2(eval_a, scalar_a); + let eq_b = eval_eq_2(eval_b, scalar_b); + for i in 0..4 { + out[i] = eq_a[i] + eq_b[i]; + } + } + 3 => { + let eq_a = eval_eq_3(eval_a, scalar_a); + let eq_b = eval_eq_3(eval_b, scalar_b); + for i in 0..8 { + out[i] = eq_a[i] + eq_b[i]; + } + } + _ => { + let (low, high) = out.split_at_mut(out.len() / 2); + let sa1 = scalar_a * eval_a[0]; + let sa0 = scalar_a - sa1; + let sb1 = scalar_b * eval_b[0]; + let sb0 = scalar_b - sb1; + eval_eq_with_packed_output_dual::(&eval_a[1..], &eval_b[1..], low, sa0, sb0); + eval_eq_with_packed_output_dual::(&eval_a[1..], &eval_b[1..], high, sa1, sb1); + } + } +} + +pub fn compute_eval_eq_packed_dual( + eval_a: &[EF], + eval_b: &[EF], + out: &mut [EF::ExtensionPacking], + scalar_a: EF, + scalar_b: EF, +) where + EF: ExtensionField>, +{ + let packing_width = packing_width::(); + let log_packing_width = log2_strict_usize(packing_width); + + assert_eq!(eval_a.len(), eval_b.len()); + assert!(log_packing_width <= eval_a.len()); + assert_eq!(out.len(), 1 << (eval_a.len() - log_packing_width)); + + if eval_a.len() <= log_packing_width + 1 + LOG_NUM_THREADS { + let mut output_no_packing = EF::zero_vec(1 << eval_a.len()); + eval_eq_basic::<_, _, _, false>(eval_a, &mut output_no_packing, scalar_a); + eval_eq_basic::<_, _, _, true>(eval_b, &mut output_no_packing, scalar_b); + out.par_iter_mut() + .zip(output_no_packing.par_chunks_exact(packing_width)) + .for_each(|(out_elem, chunk)| { + *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); + }); + } else { + let eval_len_min_packing = eval_a.len() - log_packing_width; + + let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); + let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); + let out_chunk_size = out.len() / NUM_THREADS_PADDED; + + parallel_buffer_a[0] = packed_eq_poly(&eval_a[eval_len_min_packing..], scalar_a); + fill_buffer(eval_a[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_a); + + parallel_buffer_b[0] = packed_eq_poly(&eval_b[eval_len_min_packing..], scalar_b); + fill_buffer(eval_b[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_b); + + out.par_chunks_exact_mut(out_chunk_size) + .enumerate() + .for_each(|(i, out_chunk)| { + eval_eq_with_packed_output_dual::, EF>( + &eval_a[LOG_NUM_THREADS..eval_len_min_packing], + &eval_b[LOG_NUM_THREADS..eval_len_min_packing], + out_chunk, + parallel_buffer_a[i], + parallel_buffer_b[i], + ); + }); + } +} + /// Computes the equality polynomial evaluations via a simple recursive algorithm. /// /// Unlike [`eval_eq_basic`], this function makes heavy use of packed values to speed up computations. @@ -968,10 +1067,19 @@ fn base_eval_eq_packed_with_packed_output( F: Field, EF: ExtensionField, { + // Ensure that the output buffer size is correct: + // It should be of size `2^n`, where `n` is the number of variables. + let width = F::Packing::WIDTH; + let log_packing_width = log2_strict_usize(width); debug_assert_eq!(out.len(), 1 << eval_points.len()); + debug_assert!(log_packing_width <= eval_points.len()); match eval_points.len() { - 0 => unreachable!(), + 0 => { + debug_assert_eq!(F::Packing::WIDTH, 1); + let base_vals = F::Packing::pack_slice(eq_evals.as_slice()); + scale_and_add_pf::(out, base_vals, packed_scalar); + } 1 => { let eq_evaluations = eval_eq_1(eval_points, eq_evals); scale_and_add_pf::(out, eq_evaluations.as_slice(), packed_scalar); @@ -1248,4 +1356,28 @@ mod tests { } } } + + #[test] + fn test_compute_eval_eq_packed_dual() { + let packing_width = ::Packing::WIDTH; + let log_packing_width = log2_strict_usize(packing_width); + let mut rng = StdRng::seed_from_u64(42); + + for n_vars in log_packing_width..22 { + let eval_a: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let eval_b: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let scalar_a: EF = rng.random(); + let scalar_b: EF = rng.random(); + + let packed_len = 1 << (n_vars - log_packing_width); + let mut out_dual = EFPacking::::zero_vec(packed_len); + compute_eval_eq_packed_dual::(&eval_a, &eval_b, &mut out_dual, scalar_a, scalar_b); + + let mut out_separate = EFPacking::::zero_vec(packed_len); + compute_eval_eq_packed::(&eval_a, &mut out_separate, scalar_a); + compute_eval_eq_packed::(&eval_b, &mut out_separate, scalar_b); + + assert_eq!(out_dual, out_separate, "Mismatch at n_vars={}", n_vars); + } + } } diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index ecce379fb..2828af039 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -45,11 +45,7 @@ pub fn run_product_sumcheck>>( assert!(n_rounds >= 1); let first_sumcheck_poly = match (pol_a, pol_b) { (MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => { - if EF::DIMENSION == 5 { - compute_product_sumcheck_polynomial_base_ext_packed::<5, _, _, _, EF>(evals, weights, sum) - } else { - unimplemented!() - } + compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::::to_ext_iter([e]).collect()) } (MleRef::ExtensionPacked(evals), MleRef::ExtensionPacked(weights)) => { compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::::to_ext_iter([e]).collect()) diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 21629e3d7..bd068fde0 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -89,7 +89,7 @@ fn compile_main_program_self_referential() -> Bytecode { if actual_log_size == log_size_guess { return bytecode; } - println!( + eprintln!( "Wrong guess at `compile_main_program_self_referential` (log_size {log_size_guess}->{actual_log_size})" ); log_size_guess = actual_log_size; diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 8f45a3494..27afd58f7 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -291,25 +291,18 @@ pub(super) fn run_phase2_sumcheck>>( mut sum: EF, mut mmf: EF, ) -> (Vec, [EF; 4]) { + let eq_prefix_init = &remaining_eq[..remaining_eq.len().saturating_sub(1)]; + let mut eq_table = eval_eq(eq_prefix_init); + for _round in 0..remaining_eq.len() { let eq_alpha = *remaining_eq.last().unwrap(); - let eq_prefix = &remaining_eq[..remaining_eq.len() - 1]; - let eq_table = eval_eq(eq_prefix); let active_l = num_l.len(); let active_r = num_r.len(); let active_pairs = active_l.div_ceil(2); let fully_active = active_r / 2; - let pair = |arr: &[EF], idx: usize, pad: EF| { - ( - arr.get(idx).copied().unwrap_or(pad), - arr.get(idx + 1).copied().unwrap_or(pad), - ) - }; - - let mut acc = RoundCoeffs::::zero(); - for j in 0..active_pairs { + let term = |j: usize| -> RoundCoeffs { let coeffs = if j < fully_active { pair_coeffs::( (num_l[2 * j], num_l[2 * j + 1]), @@ -318,16 +311,32 @@ pub(super) fn run_phase2_sumcheck>>( (den_r[2 * j], den_r[2 * j + 1]), ) } else { + let get_pair = |arr: &[EF], idx: usize, pad: EF| { + ( + arr.get(idx).copied().unwrap_or(pad), + arr.get(idx + 1).copied().unwrap_or(pad), + ) + }; pair_coeffs::( - pair(&num_l, 2 * j, EF::ZERO), - pair(&num_r, 2 * j, EF::ZERO), - pair(&den_l, 2 * j, EF::ONE), - pair(&den_r, 2 * j, EF::ONE), + get_pair(&num_l, 2 * j, EF::ZERO), + get_pair(&num_r, 2 * j, EF::ZERO), + get_pair(&den_l, 2 * j, EF::ONE), + get_pair(&den_r, 2 * j, EF::ONE), ) }; - acc += coeffs * eq_table[j]; - } + coeffs * eq_table[j] + }; + + let acc: RoundCoeffs = if active_pairs > PARALLEL_THRESHOLD { + (0..active_pairs) + .into_par_iter() + .map(term) + .reduce(RoundCoeffs::zero, Add::add) + } else { + (0..active_pairs).map(term).fold(RoundCoeffs::::zero(), Add::add) + }; + let eq_prefix = &remaining_eq[..remaining_eq.len() - 1]; let padding_sum = alpha * mle_of_zeros_then_ones(active_pairs, eq_prefix); let bare = build_bare_from_coeffs( @@ -349,6 +358,16 @@ pub(super) fn run_phase2_sumcheck>>( den_l = fold_normal_with_padding(&den_l, r, EF::ONE); den_r = fold_normal_with_padding(&den_r, r, EF::ONE); + let new_eq_len = eq_table.len() / 2; + if new_eq_len > 0 { + let fold_eq = |i: usize| eq_table[2 * i] + eq_table[2 * i + 1]; + eq_table = if new_eq_len >= PARALLEL_THRESHOLD { + (0..new_eq_len).into_par_iter().map(fold_eq).collect() + } else { + (0..new_eq_len).map(fold_eq).collect() + }; + } + q_natural.push(r); remaining_eq.pop(); } diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index f9634c918..6636b77c7 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -522,12 +522,41 @@ where let num_variables = statements[0].total_num_variables; assert!(statements.iter().all(|e| e.total_num_variables == num_variables)); - let mut combined_weights = EFPacking::::zero_vec(1 << (num_variables - packing_log_width::())); + let out_len = 1 << (num_variables - packing_log_width::()); + let is_full = |s: &SparseStatement| { + !s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables + }; + + let mut combined_weights: Vec>; let mut combined_sum = EF::ZERO; let mut gamma_pow = EF::ONE; - for smt in statements { + let start_idx = match statements { + [a, b, ..] if is_full(a) && is_full(b) => { + combined_weights = unsafe { uninitialized_vec(out_len) }; + let sa = gamma_pow; + let sb = gamma_pow * gamma; + combined_sum = a.values[0].value * sa + b.values[0].value * sb; + gamma_pow = sb * gamma; + compute_eval_eq_packed_dual::(&a.point.0, &b.point.0, &mut combined_weights, sa, sb); + 2 + } + [a, ..] if is_full(a) => { + combined_weights = unsafe { uninitialized_vec(out_len) }; + let sa = gamma_pow; + combined_sum = a.values[0].value * sa; + gamma_pow *= gamma; + compute_eval_eq_packed::(&a.point.0, &mut combined_weights, sa); + 1 + } + _ => { + combined_weights = EFPacking::::zero_vec(out_len); + 0 + } + }; + + for smt in &statements[start_idx..] { if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) { for evaluation in &smt.values { compute_sparse_eval_eq_packed::(evaluation.selector, &smt.point, &mut combined_weights, gamma_pow);