Commit 97fea7c

perf(whir_zk): output-pruned NTT for f̂ opens

Replace full Reed-Solomon re-encode at the two [[f̂]] open sites
(ood_stir_and_rounds, gamma_check) with an output-pruned NTT that
materialises only the queried codeword rows. The full
`(num_cols × codeword_length)` codeword matrix is never resident: storage
for codeword rows at the IRS-coeff bottleneck drops by a factor of
`codeword_length / in_domain_samples` (≈ 4000× at m=20, k=127), though
each transform still uses one `codeword_length`-element scratch buffer
plus the plan's masks. The per-encode butterfly count drops from
O(N log N) to between O(N) and O(N log k), depending on how the low bits
of the queried indices overlap.

Algorithm: Sorensen-Burrus radix-2 DIT, walking the butterfly DAG
backwards from the query set to mark only the cone of butterflies
contributing to the requested outputs. Reuses the existing roots-of-unity
cache.

Reference: Sorensen & Burrus, "Efficient computation of the DFT with only
a subset of input or output points" (IEEE TSP 41, 1993). See doc comment
on `NttEngine::ntt_partial`.

API additions:
- `NttEngine::ntt_partial` + `ntt_partial_with_plan_into`
- `PartialNttPlan` (per-(size,indices) pruning plan, reusable across
  batched NTTs that share the same query set)
- `ntt::partial_interleaved_rs_encode` (mirrors `interleaved_rs_encode`
  but emits only the rows at `indices`)
- `irs_commit::Config::{open_from_coeffs, open_at_indices_from_coeffs}`
  (functionally identical transcripts to `open`/`open_at_indices`; do not
  require `witness.matrix` to be populated)

The blinding-poly re-encode in `prove()` is left untouched (small
codeword, negligible cost).

Tests:
- Randomised property tests vs full NTT across sizes 4..2^15, sparse and
  dense query subsets, zero-padded M<N inputs, and edge cases (empty,
  singletons, repeated indices, size=1).
- `partial_interleaved_rs_encode` byte-identity against
  `interleaved_rs_encode` + row extraction across four shapes spanning the
  regimes used in whir_zk (depth 1 vs 8, single vs multi-poly, rate-1/4
  blowup).
- All 155 existing whir tests still pass; fixed the pre-existing
  `test_rejects_g_claim_forgery_via_rho` to mirror the production open path
  (re-encode blinding_poly_witness before `prove_blinding_polynomial`; use
  new partial-encode opens for f̂).

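
The backward marking and masked forward pass named in the commit message can be sketched on a toy field. This is a minimal illustration, not the code from this commit: the modulus 17 and generator 3 are chosen here purely for the example, and the helpers `pow_mod`, `bit_reverse`, `build_masks`, `ntt_pruned`, and `naive_dft` are hypothetical names.

```rust
// Toy output-pruned radix-2 DIT NTT over the integers mod 17.
// All parameters and names are illustrative assumptions, not the commit's API.
const P: u64 = 17; // prime modulus; 3 generates the multiplicative group (order 16)

fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    base %= P;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

/// Reverse the low `bits` bits of `x`.
fn bit_reverse(x: usize, bits: usize) -> usize {
    (0..bits).fold(0, |acc, b| acc | (((x >> b) & 1) << (bits - 1 - b)))
}

/// Backward pass: `mask[s][p]` is true iff position `p` after `s` stages
/// feeds at least one queried output.
fn build_masks(size: usize, indices: &[usize]) -> Vec<Vec<bool>> {
    let log_n = size.trailing_zeros() as usize;
    let mut mask = vec![vec![false; size]; log_n + 1];
    for &i in indices {
        mask[log_n][i] = true;
    }
    for stage in (1..=log_n).rev() {
        let half = 1usize << (stage - 1);
        // `a` ranges over the upper legs of this stage's butterflies.
        for a in (0..size).filter(|a| (a & half) == 0) {
            if mask[stage][a] || mask[stage][a + half] {
                mask[stage - 1][a] = true;
                mask[stage - 1][a + half] = true;
            }
        }
    }
    mask
}

/// Forward pass: NTT of `values` (zero-padded to `size`) at `indices` only.
fn ntt_pruned(values: &[u64], size: usize, indices: &[usize]) -> Vec<u64> {
    assert!(size.is_power_of_two() && size <= 16 && values.len() <= size);
    let log_n = size.trailing_zeros() as usize;
    let mask = build_masks(size, indices);
    let omega = pow_mod(3, (P - 1) / size as u64); // primitive size-th root of unity
    let mut work = vec![0u64; size];
    for (j, &c) in values.iter().enumerate() {
        work[bit_reverse(j, log_n)] = c % P;
    }
    for stage in 1..=log_n {
        let m = 1usize << stage;
        let half = m / 2;
        let w_m = pow_mod(omega, (size / m) as u64); // primitive m-th root of unity
        for base in (0..size).step_by(m) {
            let mut w = 1;
            for k in 0..half {
                let (a, b) = (base + k, base + k + half);
                if mask[stage][a] || mask[stage][b] {
                    let t = work[b] * w % P;
                    let u = work[a];
                    work[a] = (u + t) % P;
                    work[b] = (u + P - t) % P;
                }
                w = w * w_m % P; // advance the twiddle even when the butterfly is skipped
            }
        }
    }
    indices.iter().map(|&i| work[i]).collect()
}

/// Reference: direct O(n^2) evaluation of the same transform.
fn naive_dft(values: &[u64], size: usize) -> Vec<u64> {
    let omega = pow_mod(3, (P - 1) / size as u64);
    (0..size)
        .map(|i| {
            values.iter().enumerate().fold(0, |acc, (j, &x)| {
                (acc + x % P * pow_mod(omega, (i * j) as u64)) % P
            })
        })
        .collect()
}
```

Calling `ntt_pruned(&x, 8, &[5, 0, 5])` returns the entries of `naive_dft(&x, 8)` at positions 5, 0, and 5, in that order; repeated indices simply read the same slot of the work buffer twice.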
1 parent fc3e614 commit 97fea7c

5 files changed

Lines changed: 516 additions & 58 deletions

src/algebra/ntt/cooley_tukey.rs

Lines changed: 247 additions & 0 deletions
@@ -356,6 +356,164 @@ impl<F: Field> NttEngine<F> {
             size => self.ntt_recurse(values, roots, size),
         }
     }
+
+    /// Output-pruned NTT (Sorensen-Burrus, radix-2 DIT).
+    ///
+    /// Computes the size-`size` NTT of `values` (zero-padded to `size` if
+    /// shorter) and returns the outputs at positions `indices`, in the order
+    /// of `indices`: output `j` equals the full NTT at position `indices[j]`.
+    ///
+    /// Walks the butterfly DAG backwards from `indices` to mark only the
+    /// cone of butterflies that contribute to the queried outputs, then
+    /// runs only the marked butterflies on the forward pass. That is
+    /// `O(size)` to `O(size * log(indices.len()))` butterflies, depending on
+    /// the indices' low bits, vs `O(size * log(size))` for a full NTT.
+    ///
+    /// `size` must be a power of two.
+    pub fn ntt_partial(&self, values: &[F], size: usize, indices: &[usize]) -> Vec<F> {
+        let plan = PartialNttPlan::new(size, indices);
+        let mut out = vec![F::ZERO; indices.len()];
+        self.ntt_partial_with_plan_into(values, &plan, &mut out, 1);
+        out
+    }
+
+    /// Run a pruned NTT using a precomputed plan and write outputs into
+    /// `out` at stride `stride` (so `out[j * stride]` holds the result for
+    /// `plan.indices[j]`). When `stride == 1`, output is contiguous.
+    ///
+    /// Sharing a single plan across many NTTs with the same `(size, indices)`
+    /// avoids re-running the O(size · log size) mask construction per call.
+    pub fn ntt_partial_with_plan_into(
+        &self,
+        values: &[F],
+        plan: &PartialNttPlan,
+        out: &mut [F],
+        stride: usize,
+    ) {
+        let size = plan.size;
+        let indices = &plan.indices;
+        assert!(values.len() <= size, "input longer than NTT size");
+        if indices.is_empty() {
+            return;
+        }
+        assert!(
+            out.len() >= (indices.len() - 1) * stride + 1,
+            "output buffer too small for stride"
+        );
+        if size == 1 {
+            let v = values.first().copied().unwrap_or(F::ZERO);
+            for j in 0..indices.len() {
+                out[j * stride] = v;
+            }
+            return;
+        }
+
+        let log_n = size.trailing_zeros() as usize;
+        let roots = self.roots_table(size);
+
+        // Load bit-reversed input into work buffer, gated by mask[0].
+        let mut work = vec![F::ZERO; size];
+        let shift = (usize::BITS as usize) - log_n;
+        for (j, &c) in values.iter().enumerate() {
+            let rev = j.reverse_bits() >> shift;
+            if plan.mask[0][rev] {
+                work[rev] = c;
+            }
+        }
+
+        // Forward DIT, skipping butterflies with no needed outputs.
+        // The shared roots table may hold roots at a larger order than `size`;
+        // `roots[k * twiddle_step]` retrieves ω_m^k regardless.
+        for stage in 1..=log_n {
+            let m = 1usize << stage;
+            let half = m >> 1;
+            let twiddle_step = roots.len() / m;
+            let cur = &plan.mask[stage];
+            let mut base = 0;
+            while base < size {
+                for k in 0..half {
+                    let a = base + k;
+                    let b = a + half;
+                    if cur[a] || cur[b] {
+                        let w = roots[k * twiddle_step];
+                        let t = work[b] * w;
+                        let u = work[a];
+                        work[a] = u + t;
+                        work[b] = u - t;
+                    }
+                }
+                base += m;
+            }
+        }
+
+        for (j, &i) in indices.iter().enumerate() {
+            out[j * stride] = work[i];
+        }
+    }
+}
+
+/// Pruning plan for an output-pruned NTT.
+///
+/// Holds the queried output indices and the precomputed per-stage
+/// "needed-position" masks used by [`NttEngine::ntt_partial_with_plan_into`].
+/// Construct once per `(size, indices)` and reuse across multiple NTTs of
+/// the same shape (e.g. all polynomials in an interleaved batch).
+#[derive(Debug, Clone)]
+pub struct PartialNttPlan {
+    size: usize,
+    indices: Vec<usize>,
+    /// `mask[stage][p]` is true iff position `p` after `stage` DIT stages
+    /// must be correct for the final outputs. `mask[log_n]` mirrors
+    /// `indices`; `mask[0]` selects the bit-reversed input positions that
+    /// must be loaded.
+    mask: Vec<Vec<bool>>,
+}
+
+impl PartialNttPlan {
+    pub fn new(size: usize, indices: &[usize]) -> Self {
+        assert!(size.is_power_of_two(), "size must be a power of two");
+        assert!(
+            indices.iter().all(|&i| i < size),
+            "query index out of range"
+        );
+        let log_n = size.trailing_zeros() as usize;
+        let mut mask: Vec<Vec<bool>> = vec![vec![false; size]; log_n + 1];
+        for &i in indices {
+            mask[log_n][i] = true;
+        }
+        for stage in (1..=log_n).rev() {
+            let m = 1usize << stage;
+            let half = m >> 1;
+            let (lo, hi) = mask.split_at_mut(stage);
+            let cur = &hi[0];
+            let prev = &mut lo[stage - 1];
+            let mut base = 0;
+            while base < size {
+                for k in 0..half {
+                    let a = base + k;
+                    let b = a + half;
+                    if cur[a] || cur[b] {
+                        prev[a] = true;
+                        prev[b] = true;
+                    }
+                }
+                base += m;
+            }
+        }
+        Self {
+            size,
+            indices: indices.to_vec(),
+            mask,
+        }
+    }
+
+    pub fn size(&self) -> usize {
+        self.size
+    }
+
+    pub fn indices(&self) -> &[usize] {
+        &self.indices
+    }
 }

 /// Applies twiddle factors to a slice of field elements in-place.
@@ -963,4 +1121,93 @@ mod tests {

         assert_eq!(values_ntt, expected_values);
     }
+
+    #[test]
+    fn test_ntt_partial_matches_full() {
+        use ark_std::{rand::Rng, UniformRand};
+
+        let engine = NttEngine::<Field64>::new_from_fftfield();
+        let mut rng = ark_std::test_rng();
+
+        for &size in &[4usize, 16, 64, 256, 1024, 1 << 15] {
+            for _ in 0..8 {
+                // Full NTT reference.
+                let coeffs: Vec<_> = (0..size).map(|_| Field64::rand(&mut rng)).collect();
+                let mut full = coeffs.clone();
+                engine.ntt_batch(&mut full, size);
+
+                // Random subset of varying size (cover dense + sparse).
+                let k = rng.gen_range(1..=size.min(64));
+                let mut perm: Vec<usize> = (0..size).collect();
+                for i in (1..size).rev() {
+                    perm.swap(i, rng.gen_range(0..=i));
+                }
+                let indices: Vec<usize> = perm.into_iter().take(k).collect();
+
+                let partial = engine.ntt_partial(&coeffs, size, &indices);
+                assert_eq!(partial.len(), indices.len());
+                for (j, &idx) in indices.iter().enumerate() {
+                    assert_eq!(partial[j], full[idx], "size={size} idx={idx}");
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_ntt_partial_zero_padded_input() {
+        // M < N: input is zero-padded. Partial NTT must agree with full NTT
+        // computed over the zero-padded coefficient vector.
+        use ark_std::UniformRand;
+
+        let engine = NttEngine::<Field64>::new_from_fftfield();
+        let mut rng = ark_std::test_rng();
+
+        for (m, size) in [(1usize, 4), (4, 16), (256, 1024), (1 << 13, 1 << 15)] {
+            let coeffs: Vec<_> = (0..m).map(|_| Field64::rand(&mut rng)).collect();
+            let mut padded = coeffs.clone();
+            padded.resize(size, Field64::ZERO);
+            engine.ntt_batch(&mut padded, size);
+
+            let stride = (size / 8).max(1);
+            let indices: Vec<usize> = (0..size).step_by(stride).take(8).collect();
+            let partial = engine.ntt_partial(&coeffs, size, &indices);
+            for (j, &idx) in indices.iter().enumerate() {
+                assert_eq!(partial[j], padded[idx], "m={m} size={size} idx={idx}");
+            }
+        }
+    }
+
+    #[test]
+    fn test_ntt_partial_edge_cases() {
+        use ark_std::UniformRand;
+
+        let engine = NttEngine::<Field64>::new_from_fftfield();
+        let mut rng = ark_std::test_rng();
+
+        // Empty index set.
+        let coeffs: Vec<_> = (0..16).map(|_| Field64::rand(&mut rng)).collect();
+        let out = engine.ntt_partial(&coeffs, 16, &[]);
+        assert!(out.is_empty());
+
+        // Singletons at the boundaries (0, N-1) and around the midpoint.
+        let coeffs: Vec<_> = (0..64).map(|_| Field64::rand(&mut rng)).collect();
+        let mut full = coeffs.clone();
+        engine.ntt_batch(&mut full, 64);
+        for idx in [0usize, 1, 31, 32, 63] {
+            let out = engine.ntt_partial(&coeffs, 64, &[idx]);
+            assert_eq!(out, vec![full[idx]], "idx={idx}");
+        }
+
+        // Repeated indices: each occurrence must yield the matching output.
+        let indices = vec![5usize, 5, 17, 5, 17];
+        let out = engine.ntt_partial(&coeffs, 64, &indices);
+        for (j, &idx) in indices.iter().enumerate() {
+            assert_eq!(out[j], full[idx]);
+        }
+
+        // size = 1: any indices must all return values[0].
+        let single = vec![Field64::from(42)];
+        let out = engine.ntt_partial(&single, 1, &[0, 0, 0]);
+        assert_eq!(out, vec![Field64::from(42); 3]);
+    }
 }
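
How much work the pruning saves depends on which outputs are queried, not only on how many. The sketch below counts the butterflies that would be marked, assuming masks built by the same backward rule as `PartialNttPlan::new`: a butterfly at offset `k` within its block is needed at a stage exactly when some queried index is congruent to `k` modulo that stage's half-width. The helpers `marked_butterflies` and `full_butterflies` are hypothetical and not part of this commit.

```rust
use std::collections::HashSet;

/// Count the butterflies whose outputs are needed, summed over all stages.
/// Hypothetical helper: mirrors the backward marking rule, but only counts.
fn marked_butterflies(size: usize, indices: &[usize]) -> usize {
    assert!(size.is_power_of_two());
    let log_n = size.trailing_zeros();
    (1..=log_n)
        .map(|stage| {
            let half = 1usize << (stage - 1);
            // Distinct butterfly offsets hit by the queried indices at this stage.
            let offsets: HashSet<usize> = indices.iter().map(|&i| i % half).collect();
            // Each offset occurs once in each of the `size / 2^stage` blocks.
            (size >> stage) * offsets.len()
        })
        .sum()
}

/// Butterfly count of an unpruned radix-2 NTT, for comparison.
fn full_butterflies(size: usize) -> usize {
    (size / 2) * size.trailing_zeros() as usize
}
```

For 16 points, querying `[0]` or `[0, 8]` marks 15 butterflies, querying `[0, 1]` marks 22, and querying all 16 outputs marks the full 32. Indices that agree in their low bits share most of their dependency cone, while indices that differ in their low bits saturate the early stages sooner.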

src/algebra/ntt/mod.rs

Lines changed: 104 additions & 2 deletions
@@ -21,9 +21,9 @@ use static_assertions::assert_obj_safe;
 #[cfg(feature = "tracing")]
 use tracing::instrument;

-use self::matrix::MatrixMut;
+use self::{cooley_tukey::NttEngine, matrix::MatrixMut};
 pub use self::{
-    cooley_tukey::{generator, intt, intt_batch, ntt, ntt_batch},
+    cooley_tukey::{generator, intt, intt_batch, ntt, ntt_batch, PartialNttPlan},
     transpose::transpose,
     wavelet::{inverse_wavelet_transform, wavelet_transform},
 };
@@ -93,6 +93,54 @@ pub fn interleaved_rs_encode<F: 'static>(
     engine.interleaved_encode(interleaved_coeffs, codeword_length, interleaving_depth)
 }

+/// Partial Reed-Solomon encode that materialises only the rows at `indices`.
+///
+/// Equivalent to taking [`interleaved_rs_encode`]'s output (a row-major
+/// `(codeword_length, num_polys * interleaving_depth)` matrix) and
+/// extracting the rows whose row index is in `indices`. Output layout is
+/// row-major `(indices.len(), num_polys * interleaving_depth)`, byte-exact
+/// against the full encode.
+///
+/// Uses an output-pruned NTT (see [`PartialNttPlan`]), so only the queried
+/// rows are stored; scratch is one `codeword_length`-element buffer per
+/// transform plus the plan's masks. The plan is built once for the index
+/// set and reused across every polynomial × interleaving slot.
+#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(coeffs, indices), fields(size = coeffs.len(), k = indices.len())))]
+pub fn partial_interleaved_rs_encode<F: FftField>(
+    coeffs: &[&[F]],
+    codeword_length: usize,
+    interleaving_depth: usize,
+    indices: &[usize],
+) -> Vec<F> {
+    if coeffs.is_empty() || indices.is_empty() {
+        return Vec::new();
+    }
+    let poly_size = coeffs[0].len();
+    for poly in coeffs {
+        assert_eq!(poly.len(), poly_size);
+    }
+    assert!(poly_size.is_multiple_of(interleaving_depth));
+    let message_length = poly_size / interleaving_depth;
+    assert!(codeword_length.is_multiple_of(message_length));
+
+    let num_polys = coeffs.len();
+    let num_cols = num_polys * interleaving_depth;
+    let k = indices.len();
+
+    let engine = NttEngine::<F>::new_from_cache();
+    let plan = PartialNttPlan::new(codeword_length, indices);
+
+    let mut out = vec![F::ZERO; k * num_cols];
+    for (poly_idx, poly) in coeffs.iter().enumerate() {
+        for slot_idx in 0..interleaving_depth {
+            let col = poly_idx * interleaving_depth + slot_idx;
+            let block = &poly[slot_idx * message_length..(slot_idx + 1) * message_length];
+            engine.ntt_partial_with_plan_into(block, &plan, &mut out[col..], num_cols);
+        }
+    }
+    out
+}
+
 ///
 /// RS encode coefficients grouped in `interleaving_depth` contiguous blocks
 /// at the rate 1/`expansion`, then interleave the evaluations per point.
@@ -350,4 +398,58 @@ mod tests {
             interleaved_rs_encode(&[poly.as_slice()], codeword_length, 1 << folding_factor);
         assert_eq!(expected, interleaved_ntt);
     }
+
+    #[test]
+    fn test_partial_interleaved_rs_encode_matches_full() {
+        use ark_std::{rand::Rng, UniformRand};
+
+        let mut rng = ark_std::test_rng();
+
+        // Span several (num_polys, interleaving_depth, M, N) shapes covering
+        // the regimes that actually appear in whir_zk (single witness with
+        // depth 1 and 8, two witnesses with depth 4, M = N/4 blowup).
+        let cases = [
+            (1usize, 1usize, 64usize, 256usize),
+            (1, 8, 16, 64),
+            (2, 4, 32, 128),
+            (1, 8, 1 << 10, 1 << 12),
+        ];
+
+        for (num_polys, interleaving_depth, message_length, codeword_length) in cases {
+            let poly_size = message_length * interleaving_depth;
+            let polys: Vec<Vec<Field64>> = (0..num_polys)
+                .map(|_| (0..poly_size).map(|_| Field64::rand(&mut rng)).collect())
+                .collect();
+            let poly_slices: Vec<&[Field64]> = polys.iter().map(Vec::as_slice).collect();
+
+            let full = interleaved_rs_encode(&poly_slices, codeword_length, interleaving_depth);
+            let num_cols = num_polys * interleaving_depth;
+            assert_eq!(full.len(), codeword_length * num_cols);
+
+            // Random subset of up to 16 distinct rows (prefix of a shuffle).
+            let k = rng.gen_range(1..=codeword_length.min(16));
+            let mut perm: Vec<usize> = (0..codeword_length).collect();
+            for i in (1..codeword_length).rev() {
+                perm.swap(i, rng.gen_range(0..=i));
+            }
+            let indices: Vec<usize> = perm.into_iter().take(k).collect();
+
+            let partial = partial_interleaved_rs_encode(
+                &poly_slices,
+                codeword_length,
+                interleaving_depth,
+                &indices,
+            );
+            assert_eq!(partial.len(), k * num_cols);
+
+            for (row, &idx) in indices.iter().enumerate() {
+                let full_row = &full[idx * num_cols..(idx + 1) * num_cols];
+                let partial_row = &partial[row * num_cols..(row + 1) * num_cols];
+                assert_eq!(
+                    partial_row, full_row,
+                    "shape=({num_polys},{interleaving_depth},{message_length},{codeword_length}) row idx={idx}"
+                );
+            }
+        }
+    }
 }
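
The strided write in `partial_interleaved_rs_encode` (passing `&mut out[col..]` with stride `num_cols`) is what turns per-column results into a row-major `(indices.len(), num_cols)` matrix. A minimal sketch of that layout follows, with plain `u64` values standing in for field elements and a hypothetical `scatter_column` standing in for the pruned NTT call; none of these helper names come from the commit.

```rust
/// Smallest buffer length that can hold `count` results written at `stride`.
fn min_strided_len(count: usize, stride: usize) -> usize {
    if count == 0 {
        0
    } else {
        (count - 1) * stride + 1
    }
}

/// Write `column[j]` to `out[j * stride]`. Stand-in for a transform that
/// emits its j-th queried output at that position.
fn scatter_column(column: &[u64], out: &mut [u64], stride: usize) {
    assert!(
        out.len() >= min_strided_len(column.len(), stride),
        "output buffer too small for stride"
    );
    for (j, &v) in column.iter().enumerate() {
        out[j * stride] = v;
    }
}

/// Assemble equal-length columns into a row-major `(k, num_cols)` matrix.
fn interleave_columns(columns: &[Vec<u64>]) -> Vec<u64> {
    let num_cols = columns.len();
    let k = columns.first().map_or(0, Vec::len);
    assert!(
        columns.iter().all(|c| c.len() == k),
        "columns must have equal length"
    );
    if k == 0 {
        // Nothing to write; returning early also avoids slicing an empty buffer.
        return Vec::new();
    }
    let mut out = vec![0u64; k * num_cols];
    for (col, column) in columns.iter().enumerate() {
        // The tail starting at `col` is exactly long enough for the last column.
        scatter_column(column, &mut out[col..], num_cols);
    }
    out
}
```

With columns `[1, 2, 3]` and `[10, 20, 30]` the result is `[1, 10, 2, 20, 3, 30]`: row `j` holds the `j`-th queried value of every column. The early return for an empty result plays the same role as the `indices.is_empty()` check in the function above.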
