Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion crates/backend/poly/src/eq_mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,105 @@ fn eval_eq_with_packed_output<F: Field, EF: ExtensionField<F>, const INITIALIZED
}
}

#[inline]
fn eval_eq_with_packed_output_dual<F: Field, EF: ExtensionField<F>>(
eval_a: &[EF],
eval_b: &[EF],
out: &mut [EF::ExtensionPacking],
scalar_a: EF::ExtensionPacking,
scalar_b: EF::ExtensionPacking,
) {
debug_assert_eq!(eval_a.len(), eval_b.len());
debug_assert_eq!(out.len(), 1 << eval_a.len());

match eval_a.len() {
0 => {
out[0] = scalar_a + scalar_b;
}
1 => {
let [a0, a1] = eval_eq_1(eval_a, scalar_a);
let [b0, b1] = eval_eq_1(eval_b, scalar_b);
out[0] = a0 + b0;
out[1] = a1 + b1;
}
2 => {
let eq_a = eval_eq_2(eval_a, scalar_a);
let eq_b = eval_eq_2(eval_b, scalar_b);
for i in 0..4 {
out[i] = eq_a[i] + eq_b[i];
}
}
3 => {
let eq_a = eval_eq_3(eval_a, scalar_a);
let eq_b = eval_eq_3(eval_b, scalar_b);
for i in 0..8 {
out[i] = eq_a[i] + eq_b[i];
}
}
_ => {
let (low, high) = out.split_at_mut(out.len() / 2);
let sa1 = scalar_a * eval_a[0];
let sa0 = scalar_a - sa1;
let sb1 = scalar_b * eval_b[0];
let sb0 = scalar_b - sb1;
eval_eq_with_packed_output_dual::<F, EF>(&eval_a[1..], &eval_b[1..], low, sa0, sb0);
eval_eq_with_packed_output_dual::<F, EF>(&eval_a[1..], &eval_b[1..], high, sa1, sb1);
}
}
}

pub fn compute_eval_eq_packed_dual<EF>(
eval_a: &[EF],
eval_b: &[EF],
out: &mut [EF::ExtensionPacking],
scalar_a: EF,
scalar_b: EF,
) where
EF: ExtensionField<PF<EF>>,
{
let packing_width = packing_width::<EF>();
let log_packing_width = log2_strict_usize(packing_width);

assert_eq!(eval_a.len(), eval_b.len());
assert!(log_packing_width <= eval_a.len());
assert_eq!(out.len(), 1 << (eval_a.len() - log_packing_width));

if eval_a.len() <= log_packing_width + 1 + LOG_NUM_THREADS {
let mut output_no_packing = EF::zero_vec(1 << eval_a.len());
eval_eq_basic::<_, _, _, false>(eval_a, &mut output_no_packing, scalar_a);
eval_eq_basic::<_, _, _, true>(eval_b, &mut output_no_packing, scalar_b);
out.par_iter_mut()
.zip(output_no_packing.par_chunks_exact(packing_width))
.for_each(|(out_elem, chunk)| {
*out_elem = EF::ExtensionPacking::from_ext_slice(chunk);
});
} else {
let eval_len_min_packing = eval_a.len() - log_packing_width;

let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED);
let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED);
let out_chunk_size = out.len() / NUM_THREADS_PADDED;

parallel_buffer_a[0] = packed_eq_poly(&eval_a[eval_len_min_packing..], scalar_a);
fill_buffer(eval_a[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_a);

parallel_buffer_b[0] = packed_eq_poly(&eval_b[eval_len_min_packing..], scalar_b);
fill_buffer(eval_b[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_b);

out.par_chunks_exact_mut(out_chunk_size)
.enumerate()
.for_each(|(i, out_chunk)| {
eval_eq_with_packed_output_dual::<PF<EF>, EF>(
&eval_a[LOG_NUM_THREADS..eval_len_min_packing],
&eval_b[LOG_NUM_THREADS..eval_len_min_packing],
out_chunk,
parallel_buffer_a[i],
parallel_buffer_b[i],
);
});
}
}

/// Computes the equality polynomial evaluations via a simple recursive algorithm.
///
/// Unlike [`eval_eq_basic`], this function makes heavy use of packed values to speed up computations.
Expand Down Expand Up @@ -968,10 +1067,19 @@ fn base_eval_eq_packed_with_packed_output<F, EF, const INITIALIZED: bool>(
F: Field,
EF: ExtensionField<F>,
{
// Ensure that the output buffer size is correct:
// It should be of size `2^n`, where `n` is the number of variables.
let width = F::Packing::WIDTH;
let log_packing_width = log2_strict_usize(width);
debug_assert_eq!(out.len(), 1 << eval_points.len());
debug_assert!(log_packing_width <= eval_points.len());

match eval_points.len() {
0 => unreachable!(),
0 => {
debug_assert_eq!(F::Packing::WIDTH, 1);
let base_vals = F::Packing::pack_slice(eq_evals.as_slice());
scale_and_add_pf::<F, EF, INITIALIZED>(out, base_vals, packed_scalar);
}
1 => {
let eq_evaluations = eval_eq_1(eval_points, eq_evals);
scale_and_add_pf::<F, EF, INITIALIZED>(out, eq_evaluations.as_slice(), packed_scalar);
Expand Down Expand Up @@ -1248,4 +1356,28 @@ mod tests {
}
}
}

#[test]
fn test_compute_eval_eq_packed_dual() {
let packing_width = <F as Field>::Packing::WIDTH;
let log_packing_width = log2_strict_usize(packing_width);
let mut rng = StdRng::seed_from_u64(42);

for n_vars in log_packing_width..22 {
let eval_a: Vec<EF> = (0..n_vars).map(|_| rng.random()).collect();
let eval_b: Vec<EF> = (0..n_vars).map(|_| rng.random()).collect();
let scalar_a: EF = rng.random();
let scalar_b: EF = rng.random();

let packed_len = 1 << (n_vars - log_packing_width);
let mut out_dual = EFPacking::<EF>::zero_vec(packed_len);
compute_eval_eq_packed_dual::<EF>(&eval_a, &eval_b, &mut out_dual, scalar_a, scalar_b);

let mut out_separate = EFPacking::<EF>::zero_vec(packed_len);
compute_eval_eq_packed::<EF, false>(&eval_a, &mut out_separate, scalar_a);
compute_eval_eq_packed::<EF, true>(&eval_b, &mut out_separate, scalar_b);

assert_eq!(out_dual, out_separate, "Mismatch at n_vars={}", n_vars);
}
}
}
6 changes: 1 addition & 5 deletions crates/backend/sumcheck/src/product_computation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ pub fn run_product_sumcheck<EF: ExtensionField<PF<EF>>>(
assert!(n_rounds >= 1);
let first_sumcheck_poly = match (pol_a, pol_b) {
(MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => {
if EF::DIMENSION == 5 {
compute_product_sumcheck_polynomial_base_ext_packed::<5, _, _, _, EF>(evals, weights, sum)
} else {
unimplemented!()
}
compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::<EF>::to_ext_iter([e]).collect())
}
(MleRef::ExtensionPacked(evals), MleRef::ExtensionPacked(weights)) => {
compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::<EF>::to_ext_iter([e]).collect())
Expand Down
2 changes: 1 addition & 1 deletion crates/rec_aggregation/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn compile_main_program_self_referential() -> Bytecode {
if actual_log_size == log_size_guess {
return bytecode;
}
println!(
eprintln!(
"Wrong guess at `compile_main_program_self_referential` (log_size {log_size_guess}->{actual_log_size})"
);
log_size_guess = actual_log_size;
Expand Down
53 changes: 36 additions & 17 deletions crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,25 +291,18 @@ pub(super) fn run_phase2_sumcheck<EF: ExtensionField<PF<EF>>>(
mut sum: EF,
mut mmf: EF,
) -> (Vec<EF>, [EF; 4]) {
let eq_prefix_init = &remaining_eq[..remaining_eq.len().saturating_sub(1)];
let mut eq_table = eval_eq(eq_prefix_init);

for _round in 0..remaining_eq.len() {
let eq_alpha = *remaining_eq.last().unwrap();
let eq_prefix = &remaining_eq[..remaining_eq.len() - 1];
let eq_table = eval_eq(eq_prefix);

let active_l = num_l.len();
let active_r = num_r.len();
let active_pairs = active_l.div_ceil(2);
let fully_active = active_r / 2;

let pair = |arr: &[EF], idx: usize, pad: EF| {
(
arr.get(idx).copied().unwrap_or(pad),
arr.get(idx + 1).copied().unwrap_or(pad),
)
};

let mut acc = RoundCoeffs::<EF>::zero();
for j in 0..active_pairs {
let term = |j: usize| -> RoundCoeffs<EF> {
let coeffs = if j < fully_active {
pair_coeffs::<EF, EF>(
(num_l[2 * j], num_l[2 * j + 1]),
Expand All @@ -318,16 +311,32 @@ pub(super) fn run_phase2_sumcheck<EF: ExtensionField<PF<EF>>>(
(den_r[2 * j], den_r[2 * j + 1]),
)
} else {
let get_pair = |arr: &[EF], idx: usize, pad: EF| {
(
arr.get(idx).copied().unwrap_or(pad),
arr.get(idx + 1).copied().unwrap_or(pad),
)
};
pair_coeffs::<EF, EF>(
pair(&num_l, 2 * j, EF::ZERO),
pair(&num_r, 2 * j, EF::ZERO),
pair(&den_l, 2 * j, EF::ONE),
pair(&den_r, 2 * j, EF::ONE),
get_pair(&num_l, 2 * j, EF::ZERO),
get_pair(&num_r, 2 * j, EF::ZERO),
get_pair(&den_l, 2 * j, EF::ONE),
get_pair(&den_r, 2 * j, EF::ONE),
)
};
acc += coeffs * eq_table[j];
}
coeffs * eq_table[j]
};

let acc: RoundCoeffs<EF> = if active_pairs > PARALLEL_THRESHOLD {
(0..active_pairs)
.into_par_iter()
.map(term)
.reduce(RoundCoeffs::zero, Add::add)
} else {
(0..active_pairs).map(term).fold(RoundCoeffs::<EF>::zero(), Add::add)
};

let eq_prefix = &remaining_eq[..remaining_eq.len() - 1];
let padding_sum = alpha * mle_of_zeros_then_ones(active_pairs, eq_prefix);

let bare = build_bare_from_coeffs(
Expand All @@ -349,6 +358,16 @@ pub(super) fn run_phase2_sumcheck<EF: ExtensionField<PF<EF>>>(
den_l = fold_normal_with_padding(&den_l, r, EF::ONE);
den_r = fold_normal_with_padding(&den_r, r, EF::ONE);

let new_eq_len = eq_table.len() / 2;
if new_eq_len > 0 {
let fold_eq = |i: usize| eq_table[2 * i] + eq_table[2 * i + 1];
eq_table = if new_eq_len >= PARALLEL_THRESHOLD {
(0..new_eq_len).into_par_iter().map(fold_eq).collect()
} else {
(0..new_eq_len).map(fold_eq).collect()
};
}

q_natural.push(r);
remaining_eq.pop();
}
Expand Down
33 changes: 31 additions & 2 deletions crates/whir/src/open.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,12 +522,41 @@ where
let num_variables = statements[0].total_num_variables;
assert!(statements.iter().all(|e| e.total_num_variables == num_variables));

let mut combined_weights = EFPacking::<EF>::zero_vec(1 << (num_variables - packing_log_width::<EF>()));
let out_len = 1 << (num_variables - packing_log_width::<EF>());

let is_full = |s: &SparseStatement<EF>| {
!s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables
};

let mut combined_weights: Vec<EFPacking<EF>>;
let mut combined_sum = EF::ZERO;
let mut gamma_pow = EF::ONE;

for smt in statements {
let start_idx = match statements {
[a, b, ..] if is_full(a) && is_full(b) => {
combined_weights = unsafe { uninitialized_vec(out_len) };
let sa = gamma_pow;
let sb = gamma_pow * gamma;
combined_sum = a.values[0].value * sa + b.values[0].value * sb;
gamma_pow = sb * gamma;
compute_eval_eq_packed_dual::<EF>(&a.point.0, &b.point.0, &mut combined_weights, sa, sb);
2
}
[a, ..] if is_full(a) => {
combined_weights = unsafe { uninitialized_vec(out_len) };
let sa = gamma_pow;
combined_sum = a.values[0].value * sa;
gamma_pow *= gamma;
compute_eval_eq_packed::<EF, false>(&a.point.0, &mut combined_weights, sa);
1
}
_ => {
combined_weights = EFPacking::<EF>::zero_vec(out_len);
0
}
};

for smt in &statements[start_idx..] {
if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::<EF>()) {
for evaluation in &smt.values {
compute_sparse_eval_eq_packed::<EF>(evaluation.selector, &smt.point, &mut combined_weights, gamma_pow);
Expand Down
Loading