Skip to content

Commit fc3e614

Browse files
perf(whir_zk): hold IRS coefficients, re-encode codeword on demand
The initial IRS commit witnesses (f_hat and blinding_poly) previously held their full Reed-Solomon encoded codewords resident from commit through the entire whir_zk::prove. The codeword is only consumed at open time (Merkle path generation + queried row extraction); the coefficients are smaller by the blowup factor (e.g. 4x at rate 1/4) and already retained for other protocol uses. Drop matrix immediately after commit. Re-encode transiently around each open and drop again after. Three encodes per whir_zk::prove call: one for each of f_hat's two opens (ood_stir_and_rounds, gamma_check) and one for blinding_poly's open in prove_blinding_polynomial. Measured on complete_age_check (m=20, N=5 interleaved): - peak: 805 -> 706 MB (-99 MB / -12.3%) - wall (median): 3500 -> 4220 ms (+20.6%, +720 ms) - allocs: 3.56M -> 3.61M (+50k) Combined with linear_forms drop (c183108) versus unoptimised v1: - peak: 880 -> 706 MB (-174 MB / -19.8%) Protocol-equivalent. Prove + verify roundtrip passes byte-identically. Re-encoded codeword matches the original since interleaved_rs_encode is deterministic.
1 parent c183108 commit fc3e614

2 files changed

Lines changed: 104 additions & 21 deletions

File tree

src/protocols/whir_zk/committer.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,14 @@ impl<F: FftField> Config<F> {
103103

104104
// Step 1b: Commit [[f̂]] via first WHIR instance.
105105
let f_hat_refs: Vec<&[F]> = f_hat_polys.iter().map(|p| p.as_slice()).collect();
106-
let f_hat_witness = self.blinded_polynomial.commit(prover_state, &f_hat_refs);
106+
let mut f_hat_witness = self.blinded_polynomial.commit(prover_state, &f_hat_refs);
107+
108+
// Drop the encoded codeword; will be re-encoded immediately before each
109+
// open in prove_blinded_polynomial (Steps 4 and 6). This keeps the
110+
// ~codeword_length × interleaving_depth field elements out of the
111+
// resident set during the prepare_and_sumcheck rounds where global peak
112+
// hits.
113+
f_hat_witness.matrix = Vec::new();
107114

108115
// Step 1c: Sample ν + 1 random ℓ-variate blinding polynomials ĝ₀..ĝ_ν.
109116
let num_blinding_polys = dims.num_g_polys();
@@ -138,10 +145,17 @@ impl<F: FftField> Config<F> {
138145
}
139146
let blinding_refs: Vec<&[F]> = blinding_vectors.iter().map(|v| v.as_slice()).collect();
140147

141-
let blinding_poly_witness = self
148+
let mut blinding_poly_witness = self
142149
.blinding_polynomial
143150
.commit(prover_state, &blinding_refs);
144151

152+
// The encoded codeword is only needed when [[M, ĝ]] is opened in
153+
// Step 7. Until then it is dead weight (held resident through all of
154+
// prove_blinded_polynomial, where global peak hits). Drop the matrix
155+
// here; the prover re-encodes from `secrets.blinding_vectors` just
156+
// before calling `blinding_polynomial.prove`.
157+
blinding_poly_witness.matrix = Vec::new();
158+
145159
Witness {
146160
f_hat_witness,
147161
blinding_poly_witness,

src/protocols/whir_zk/prover.rs

Lines changed: 88 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ use crate::{
2222
embedding::Identity,
2323
geometric_sequence,
2424
linear_form::{Covector, Evaluate, LinearForm, UnivariateEvaluation},
25-
multilinear_extend, univariate_evaluate, MultilinearPoint,
25+
multilinear_extend,
26+
ntt::interleaved_rs_encode,
27+
univariate_evaluate, MultilinearPoint,
2628
},
2729
hash::Hash,
2830
protocols::{
@@ -314,17 +316,20 @@ where
314316

315317
/// Step 5: OOD/STIR queries, STIR constraint accumulation, and remaining WHIR rounds.
316318
///
317-
/// Takes ownership of `f_hat_polys` so it can be freed after OOD evaluations,
318-
/// before the memory-intensive WHIR rounds begin.
319+
/// Borrows `f_hat_polys` so it remains available for re-encoding the
320+
/// f_hat codeword in Step 6 (`gamma_check`). The codeword in
321+
/// `f_hat_witness.matrix` is re-encoded just before its open and cleared
322+
/// immediately after, to keep it out of the resident set during the
323+
/// memory-intensive sumcheck rounds.
319324
#[allow(clippy::too_many_arguments)]
320325
fn ood_stir_and_rounds(
321326
&mut self,
322327
state: &mut whir::rounds::SumcheckState<'_, F>,
323328
alpha_coeffs: &[F],
324329
rho: F,
325330
folding_randomness: MultilinearPoint<F>,
326-
f_hat_witness: &irs_commit::Witness<F, F>,
327-
f_hat_polys: Vec<Vec<F>>,
331+
f_hat_witness: &mut irs_commit::Witness<F, F>,
332+
f_hat_polys: &[Vec<F>],
328333
masking_polys: &[Vec<F>],
329334
g_polys: &[Vec<F>],
330335
) -> OodStirResult<F> {
@@ -336,11 +341,27 @@ where
336341
.irs_committer
337342
.commit(self.prover_state, &[state.vector.as_slice()]);
338343
round_config.pow.prove(self.prover_state);
344+
345+
// Re-encode f_hat codeword for the upcoming open, then drop it again.
346+
let f_hat_refs: Vec<&[F]> = f_hat_polys.iter().map(|p| p.as_slice()).collect();
347+
f_hat_witness.matrix = interleaved_rs_encode(
348+
&f_hat_refs,
349+
self.config
350+
.blinded_polynomial
351+
.initial_committer
352+
.codeword_length,
353+
self.config
354+
.blinded_polynomial
355+
.initial_committer
356+
.interleaving_depth,
357+
);
358+
drop(f_hat_refs);
339359
let in_domain = self
340360
.config
341361
.blinded_polynomial
342362
.initial_committer
343-
.open(self.prover_state, &[f_hat_witness]);
363+
.open(self.prover_state, &[&*f_hat_witness]);
364+
f_hat_witness.matrix = Vec::new();
344365

345366
let r_bar = folding_randomness.0;
346367
let eq_weights = compute_eq_weights(&r_bar);
@@ -385,9 +406,9 @@ where
385406
lambda_z_points.push(z);
386407
}
387408

388-
// Release f̂ data before WHIR rounds.
409+
// Release f̂_combined before WHIR rounds. f_hat_polys is borrowed
410+
// from the caller (still needed for re-encoding in gamma_check).
389411
drop(f_hat_combined);
390-
drop(f_hat_polys);
391412

392413
// --- STIR responses ---
393414
for &z in &in_domain.points {
@@ -436,24 +457,45 @@ where
436457
/// Step 6: Γ consistency check.
437458
///
438459
/// Opens [[f̂]] at Γ indices and sends blinding evaluations for each γ ∈ Γ.
460+
/// Re-encodes the f_hat codeword into `f_hat_witness.matrix` before the
461+
/// open and clears it after, mirroring the pattern in `ood_stir_and_rounds`.
439462
fn gamma_check(
440463
&mut self,
441-
f_hat_witness: &irs_commit::Witness<F, F>,
464+
f_hat_witness: &mut irs_commit::Witness<F, F>,
465+
f_hat_polys: &[Vec<F>],
442466
masking_coeffs_all: &[Vec<F>],
443467
g_i_coeffs: &[Vec<F>],
444468
gamma_points: &[F],
445469
lambda_z_points: &mut Vec<F>,
446470
) {
447471
let gamma_f_hat_indices = gamma_to_f_hat_indices(gamma_points, self.config);
448472

473+
// Re-encode f_hat codeword for the open at Γ indices.
474+
let f_hat_refs: Vec<&[F]> = f_hat_polys.iter().map(|p| p.as_slice()).collect();
475+
f_hat_witness.matrix = interleaved_rs_encode(
476+
&f_hat_refs,
477+
self.config
478+
.blinded_polynomial
479+
.initial_committer
480+
.codeword_length,
481+
self.config
482+
.blinded_polynomial
483+
.initial_committer
484+
.interleaving_depth,
485+
);
486+
drop(f_hat_refs);
487+
449488
// Writes [[f̂]] openings at Γ indices to the transcript.
450489
// The verifier uses these to reconstruct fold(r̄, [[f̂]])(γ).
451490
// Return value (Evaluations) is unused: the prover already knows the values.
452491
let _f_hat_openings = self
453492
.config
454493
.blinded_polynomial
455494
.initial_committer
456-
.open_at_indices(self.prover_state, &[f_hat_witness], &gamma_f_hat_indices);
495+
.open_at_indices(self.prover_state, &[&*f_hat_witness], &gamma_f_hat_indices);
496+
497+
// Drop the codeword again; nothing else in this protocol needs it.
498+
f_hat_witness.matrix = Vec::new();
457499

458500
for &gamma in gamma_points {
459501
send_blinding_evals(self.prover_state, gamma, masking_coeffs_all, g_i_coeffs);
@@ -465,16 +507,18 @@ where
465507
impl<F: FftField> Config<F> {
466508
/// Steps 2-6: Prove the blinded polynomial instance.
467509
///
468-
/// `f_hat_polys` is taken by value and freed during OOD evaluations (Step 5),
469-
/// before the memory-intensive WHIR rounds begin.
470-
/// Other witness fields are borrowed; the caller frees them before Step 7.
510+
/// `f_hat_witness.matrix` is empty on entry (cleared at commit time); it
511+
/// is re-encoded transiently around each open and cleared afterwards to
512+
/// keep the codeword out of the resident set during sumcheck rounds.
513+
/// `f_hat_polys` is borrowed (needed for re-encoding in both Step 5
514+
/// `ood_stir_and_rounds` and Step 6 `gamma_check`).
471515
#[allow(clippy::too_many_arguments)]
472516
fn prove_blinded_polynomial<H, R>(
473517
&self,
474518
prover_state: &mut ProverState<H, R>,
475519
vectors: Vec<Cow<'_, [F]>>,
476-
f_hat_witness: &irs_commit::Witness<F, F>,
477-
f_hat_polys: Vec<Vec<F>>,
520+
f_hat_witness: &mut irs_commit::Witness<F, F>,
521+
f_hat_polys: &[Vec<F>],
478522
masking_polys: &[Vec<F>],
479523
g_polys: &[Vec<F>],
480524
linear_forms: Vec<Box<dyn LinearForm<F>>>,
@@ -550,6 +594,7 @@ impl<F: FftField> Config<F> {
550594

551595
ctx.gamma_check(
552596
f_hat_witness,
597+
f_hat_polys,
553598
&masking_coeffs_all,
554599
&g_i_coeffs,
555600
&gamma_points,
@@ -670,18 +715,25 @@ impl<F: FftField> Config<F> {
670715
Hash: ProverMessage<[H::U]>,
671716
{
672717
let Witness {
673-
f_hat_witness,
674-
blinding_poly_witness,
718+
mut f_hat_witness,
719+
mut blinding_poly_witness,
675720
f_hat_polys,
676721
secrets,
677722
} = witness;
678723

679724
// Steps 2-6: blinded polynomial proof.
725+
// Both `f_hat_witness.matrix` and `blinding_poly_witness.matrix` are
726+
// empty here (cleared at commit time). The blinded prover re-encodes
727+
// f_hat transiently around each of its two opens; blinding_poly is
728+
// re-encoded just before Step 7 below. This keeps both codewords
729+
// (~codeword_length × interleaving_depth field elements each) out of
730+
// the resident set during the prepare_and_sumcheck rounds where global
731+
// peak hits.
680732
let blinded = self.prove_blinded_polynomial(
681733
prover_state,
682734
vectors,
683-
&f_hat_witness,
684-
f_hat_polys,
735+
&mut f_hat_witness,
736+
&f_hat_polys,
685737
&secrets.masking_polys,
686738
&secrets.g_polys,
687739
linear_forms,
@@ -690,6 +742,23 @@ impl<F: FftField> Config<F> {
690742

691743
// Free fields only needed during Steps 2-6, before Step 7.
692744
drop(f_hat_witness);
745+
drop(f_hat_polys);
746+
747+
// Re-encode the [[M, ĝ]] codeword, which was dropped at commit time
748+
// to keep the resident set small through Step 6.
749+
let blinding_refs: Vec<&[F]> = secrets
750+
.blinding_vectors
751+
.iter()
752+
.map(|v| v.as_slice())
753+
.collect();
754+
blinding_poly_witness.matrix = interleaved_rs_encode(
755+
&blinding_refs,
756+
self.blinding_polynomial.initial_committer.codeword_length,
757+
self.blinding_polynomial
758+
.initial_committer
759+
.interleaving_depth,
760+
);
761+
drop(blinding_refs);
693762

694763
// Step 7: batched blinding polynomial proof.
695764
self.prove_blinding_polynomial(

0 commit comments

Comments
 (0)