Skip to content

Commit d4ad1fb

Browse files
perf(whir_zk): parallelise partial_interleaved_rs_encode batches
Each (poly_idx, slot_idx) NTT in the partial encode is independent. Switch to a batch-major intermediate (`(num_cols, k)`) populated via `par_chunks_exact_mut` and transpose to the row-major output. Brings the partial encode in line with the parallel batching the existing `ntt_batch` performs inside the full encode.
1 parent 97fea7c commit d4ad1fb

1 file changed

Lines changed: 32 additions & 8 deletions

File tree

src/algebra/ntt/mod.rs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use std::{
1717
};
1818

1919
use ark_ff::{FftField, Field};
20+
#[cfg(feature = "parallel")]
21+
use rayon::prelude::*;
2022
use static_assertions::assert_obj_safe;
2123
#[cfg(feature = "tracing")]
2224
use tracing::instrument;
@@ -130,15 +132,37 @@ pub fn partial_interleaved_rs_encode<F: FftField>(
130132
let engine = NttEngine::<F>::new_from_cache();
131133
let plan = PartialNttPlan::new(codeword_length, indices);
132134

133-
let mut out = vec![F::ZERO; k * num_cols];
134-
for (poly_idx, poly) in coeffs.iter().enumerate() {
135-
for slot_idx in 0..interleaving_depth {
136-
let col = poly_idx * interleaving_depth + slot_idx;
137-
let block = &poly[slot_idx * message_length..(slot_idx + 1) * message_length];
138-
engine.ntt_partial_with_plan_into(block, &plan, &mut out[col..], num_cols);
139-
}
135+
// Build the submatrix in batch-major layout (`(num_cols, k)`): each
136+
// contiguous k-chunk is one NTT's outputs. Batches are independent, so
137+
// populate in parallel across (poly_idx, slot_idx). Final transpose
138+
// converts to the row-major `(k, num_cols)` layout that
139+
// `irs_commit::open_inner_from_coeffs` expects.
140+
let mut batch_major = vec![F::ZERO; num_cols * k];
141+
142+
#[cfg(feature = "parallel")]
143+
{
144+
batch_major
145+
.par_chunks_exact_mut(k)
146+
.enumerate()
147+
.for_each(|(col, dst)| {
148+
let poly_idx = col / interleaving_depth;
149+
let slot_idx = col % interleaving_depth;
150+
let block = &coeffs[poly_idx]
151+
[slot_idx * message_length..(slot_idx + 1) * message_length];
152+
engine.ntt_partial_with_plan_into(block, &plan, dst, 1);
153+
});
154+
}
155+
#[cfg(not(feature = "parallel"))]
156+
for (col, dst) in batch_major.chunks_exact_mut(k).enumerate() {
157+
let poly_idx = col / interleaving_depth;
158+
let slot_idx = col % interleaving_depth;
159+
let block =
160+
&coeffs[poly_idx][slot_idx * message_length..(slot_idx + 1) * message_length];
161+
engine.ntt_partial_with_plan_into(block, &plan, dst, 1);
140162
}
141-
out
163+
164+
transpose(&mut batch_major, num_cols, k);
165+
batch_major
142166
}
143167

144168
///

0 commit comments

Comments
 (0)