diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 00000000..860ad42f --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,70 @@ +name: Benchmarks + +on: + push: + branches: [main] + pull_request: + branches: [main] + +permissions: + contents: write + deployments: write + +jobs: + benchmark: + name: Criterion benchmarks + # Use self-hosted for stable hardware (no noisy-neighbor variance). + # To use self-hosted: change to ["self-hosted", "linux", "x64"]. + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: bench-${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + + - name: Detect CPU features + id: cpu + run: | + if grep -q avx512ifma /proc/cpuinfo 2>/dev/null; then + echo "rustflags=-C target-feature=+avx512ifma" >> "$GITHUB_OUTPUT" + echo "label=avx512" >> "$GITHUB_OUTPUT" + echo "::notice::AVX-512 IFMA detected — SIMD benchmarks enabled" + elif [ "$(uname -m)" = "aarch64" ]; then + echo "rustflags=" >> "$GITHUB_OUTPUT" + echo "label=neon" >> "$GITHUB_OUTPUT" + echo "::notice::aarch64 detected — NEON benchmarks enabled" + else + echo "rustflags=" >> "$GITHUB_OUTPUT" + echo "label=scalar" >> "$GITHUB_OUTPUT" + echo "::notice::No SIMD target features detected — scalar benchmarks" + fi + + - name: Run benchmarks + env: + RUSTFLAGS: ${{ steps.cpu.outputs.rustflags }} + run: cargo bench --bench sumcheck -- --output-format bencher | tee output.txt + + - name: Store benchmark results + uses: benchmark-action/github-action-benchmark@v1 + with: + name: "Sumcheck Benchmarks (${{ steps.cpu.outputs.label }})" + tool: cargo + output-file-path: output.txt + # On push to main: commit results to gh-pages for trendline tracking. + # On PR: compare against main baseline and comment. + auto-push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + github-token: ${{ secrets.GITHUB_TOKEN }} + # Alert thresholds. + alert-threshold: "115%" + comment-on-alert: true + fail-on-alert: false + benchmark-data-dir-path: "dev/bench-${{ steps.cpu.outputs.label }}" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..2444a5d9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,40 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +env: + CARGO_TERM_COLOR: always + +jobs: + check: + name: Build + Test + Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ci-${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + + - name: Build + run: cargo build --release + + - name: Clippy + run: cargo clippy --release -- -D warnings + + - name: Tests + run: cargo test --release diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b9bcf337..96785d77 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -50,10 +50,10 @@ jobs: - name: Run tests run: cargo test --verbose - test_with_no_default_features: + build_with_no_default_features: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Run tests - run: cargo test --verbose --no-default-features \ No newline at end of file + - name: Build without arkworks + run: cargo build --verbose --no-default-features \ No newline at end of file diff --git a/.gitignore b/.gitignore index e186dd4a..b7bfef53 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ **/lag-poly-benches/target/ .vscode .DS_Store +.claude/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 990164f8..26b76b3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,11 +2,38 @@ All notable changes to this project will be documented in this file. -## [Unreleased] +## [Unreleased] — Canonical Rewrite + +Major revision: unified API, one verifier, one proof type, 7 provers, SIMD acceleration. + +### Breaking + +- **Package renamed** from `efficient-sumcheck` to `effsc`. +- **Single verifier** — `sumcheck_verify()` returns `SumcheckResult { challenges, final_claim }`. The oracle check is the caller's responsibility ([Thaler Remark 4.2](https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf)). Removed `inner_product_sumcheck_verify`, `multilinear_sumcheck_verify`, and `coefficient_sumcheck::sumcheck_verify`. +- **Single proof type** — `SumcheckProof` replaces `Sumcheck` and `ProductSumcheck`. +- **Transcript redesigned** — `send()`/`receive()`/`challenge()` replace `read()`/`write()`. +- **Legacy entry points demoted** — use `runner::sumcheck()` with a prover type. ### Added -- **Base/Extension field support**: `multilinear_sumcheck` and `inner_product_sumcheck` now take two type parameters `` — base field for evaluations, extension field for challenges. Set `EF = BF` when no extension is needed. -- `pairwise::cross_field_reduce` — parallel helper for folding `BF` evaluations with an `EF` challenge. + +- **`SumcheckProver` trait** — single extension point for all polynomial shapes. +- **7 concrete provers** — `MultilinearProver`, `InnerProductProver`, `CoefficientProver` (each with MSB + LSB variants), `GkrProver`. +- **`SumcheckField` trait** — generic field interface; blanket impl for `ark_ff::Field` behind `feature = "arkworks"`. +- **`SimdRepr` trait** — safe SIMD opt-in with `zerocopy` layout verification. +- **`runner::sumcheck()`** — single runner with partial execution and per-round hooks. +- **Eq polynomial utilities** — `eq_poly`, `eq_poly_non_binary`, O(2^v) incremental `compute_hypercube_eq_evals`. +- **Adversarial verifier tests** — corrupted proofs, wrong sums, wrong final values across all prover types. +- **`no_std` support** — core library works without `arkworks` feature. +- **SIMD** — transparent 8-wide AVX-512 IFMA, 2-wide NEON acceleration. + +### Integrations + +- [WHIR](https://github.com/WizardOfMenlo/whir) ([PR #250](https://github.com/WizardOfMenlo/whir/pull/250)) +- [WARP](https://github.com/compsec-epfl/warp) ([PR #24](https://github.com/compsec-epfl/warp/pull/24)) + +### Removed + +- **~4,500 lines of legacy code** — old `Prover` trait, `TimeProver`/`SpaceProver`/`BlendyProver`, `OrderStrategy`, `messages/`, `interpolation/`, `simd_ops`. ## [0.0.2] - 2026-02-11 diff --git a/Cargo.toml b/Cargo.toml index 7f278d32..623fa01d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,35 +1,53 @@ [package] -name = "efficient-sumcheck" +name = "effsc" version = "0.0.2" authors = ["arkworks maintainers"] description = "TODO" repository = "https://github.com/compsec-epfl/space-efficient-sumcheck" include = ["Cargo.toml", "src", "README.md", "LICENSE-APACHE", "LICENSE-MIT"] edition = "2021" +rust-version = "1.81" [dependencies] -ark-ff = "0.5.0" -ark-poly = "0.5.0" -ark-serialize = "0.5.0" -ark-std ="0.5.0" +ark-ff = { version = "0.5.0", optional = true } +ark-poly = { version = "0.5.0", optional = true } +ark-serialize = { version = "0.5.0", optional = true } +ark-std = { version = "0.5.0", optional = true } memmap2 = "0.9.5" nohash-hasher = "0.2.0" rayon = { version = "1.10", optional = true } -spongefish = { git = "https://github.com/arkworks-rs/spongefish", branch = "main", features = ["ark-ff"] } +spongefish = { git = "https://github.com/z-tech/spongefish.git", branch = "smallfp-support", features = ["ark-ff"], optional = true } +zerocopy = { version = "0.8", features = ["derive"] } [dev-dependencies] criterion = "0.8" +p3-field = "0.5" +p3-goldilocks = "0.5" [features] -default = ["parallel"] +default = ["arkworks", "parallel", "simd"] +arkworks = [ + "dep:ark-ff", + "dep:ark-poly", + "dep:ark-serialize", + "dep:ark-std", + "dep:spongefish", +] +simd = [] parallel = [ "dep:rayon", - "ark-ff/parallel", - "ark-poly/parallel", - "ark-std/parallel", + "ark-ff?/parallel", + "ark-poly?/parallel", + "ark-std?/parallel", ] [[bench]] -name = "provers" -path = "benches/provers.rs" +name = "sumcheck" +path = "benches/sumcheck.rs" harness = false + +[patch.crates-io] +ark-ff = { git = "https://github.com/arkworks-rs/algebra.git", branch = "master" } +ark-poly = { git = "https://github.com/arkworks-rs/algebra.git", branch = "master" } +ark-serialize = { git = "https://github.com/arkworks-rs/algebra.git", branch = "master" } +spongefish = { git = "https://github.com/z-tech/spongefish.git", branch = "smallfp-support" } diff --git a/README.md b/README.md index a64be6ef..f03281ce 100644 --- a/README.md +++ b/README.md @@ -1,162 +1,158 @@

Efficient Sumcheck

-Efficient, streaming capable, sumcheck with **Fiat–Shamir** support via [SpongeFish](https://github.com/arkworks-rs/spongefish). +A high-performance sumcheck library with [correctness-fuzzing](#correctness) against a verified oracle. -**Security note:** This library has not undergone a formal security audit. +- **Efficient** — transparent SIMD acceleration (8-wide AVX-512, 2-wide NEON) +- **Streaming-capable** — optional sublinear memory via sequential evaluation +- **Complete** — built-in Fiat-Shamir, partial execution, per-round hooks -## General Use +Built using [arkworks](https://github.com/arkworks-rs). Compatible with any ecosystem — see [`docs/compatibility.md`](docs/compatibility.md). Research-grade; not yet audited — see [`SECURITY.md`](SECURITY.md). -This library exposes three high-level functions: -1) [`multilinear_sumcheck`](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/multilinear_sumcheck.rs#L123), -2) [`inner_product_sumcheck`](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/inner_product_sumcheck.rs#L166), and -3) [`coefficient_sumcheck`](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/coefficient_sumcheck.rs#L17). +## Quick Start -The first two are parameterized by two field types: `BF` (base field, of the evaluations) and `EF` (extension field, of the challenges). When no extension field is needed, set `EF = BF`. +### Multilinear Sumcheck -Using [SpongeFish](https://github.com/arkworks-rs/spongefish) (or similar Fiat-Shamir interface) simply call the functions with the Spongefish transcript: +Proves $H = \displaystyle\sum_{x \in \lbrace 0,1 \rbrace^v} p(x)$ where $p$ is a multilinear polynomial. -### Multilinear Sumcheck -```math -claim = \sum_{x \in \{0,1\}^n} p(x) -``` ```rust -use efficient_sumcheck::{multilinear_sumcheck, Sumcheck}; -use efficient_sumcheck::transcript::SanityTranscript; - -let mut evals_p_01n: Vec = /* ... */; -let mut prover_state = SanityTranscript::new(&mut rng); -let sumcheck_transcript: Sumcheck = multilinear_sumcheck::( - &mut evals_p_01n, - &mut prover_state +use effsc::{noop_hook, runner::sumcheck}; +use effsc::provers::multilinear::MultilinearProver; + +let mut prover = MultilinearProver::new(evals); +let proof = sumcheck( + &mut prover, + num_vars, + &mut transcript, + noop_hook, ); ``` ### Inner Product Sumcheck -```math -claim = \sum_{x \in \{0,1\}^n} f(x) \cdot g(x) -``` + +Proves $H = \displaystyle\sum_{x \in \lbrace 0,1 \rbrace^v} f(x) \cdot g(x)$ for two multilinear polynomials. Degree-2 round polynomials. ```rust -use efficient_sumcheck::{inner_product_sumcheck, ProductSumcheck}; -use efficient_sumcheck::transcript::SanityTranscript; - -let mut evals_f_01n: Vec = /* ... */; -let mut evals_g_01n: Vec = /* ... */; -let mut prover_state = SanityTranscript::new(&mut rng); -let sumcheck_transcript: ProductSumcheck = inner_product_sumcheck::( - &mut evals_f_01n, - &mut evals_g_01n, - &mut prover_state +use effsc::{noop_hook, runner::sumcheck}; +use effsc::provers::inner_product::InnerProductProver; + +let mut prover = InnerProductProver::new(a, b); +let proof = sumcheck( + &mut prover, + num_vars, + &mut transcript, + noop_hook, ); ``` ### Coefficient Sumcheck -```math -claim = \sum_{x \in \{0,1\}^n} p(x), \quad \deg_{x_i}(p) \leq d -``` -Unlike the multilinear and inner product variants where `p` is multilinear (degree 1 in each variable, yielding degree-1 round polynomials), `coefficient_sumcheck` handles polynomials with arbitrary per-variable degree `d`, producing degree-`d` round polynomials. The user supplies a closure `compute_round_poly` that computes each round polynomial; the library handles transcript interaction and table reductions (both pairwise and tablewise) automatically. +Proves $H = \displaystyle\sum_{x \in \lbrace 0,1 \rbrace^v} p(x)$ where $\deg_{x_i}(p) \leq d$. The user implements `RoundPolyEvaluator` to define per-pair round polynomial contributions; the library handles iteration, parallelism, and reductions. ```rust -use efficient_sumcheck::coefficient_sumcheck::{coefficient_sumcheck, CoefficientSumcheck}; -use efficient_sumcheck::transcript::SanityTranscript; -use ark_poly::univariate::DensePolynomial; - -let mut tablewise: Vec>> = /* multi-column tables */; -let mut pairwise: Vec> = /* flat evaluation vectors */; -let mut transcript = SanityTranscript::new(&mut rng); - -let result: CoefficientSumcheck = coefficient_sumcheck( - |tablewise, pairwise| { - // Compute h(X) as a DensePolynomial from current table state. - // Return coefficients in ascending order: [c0, c1, ..., cd]. - DensePolynomial::from_coefficients_vec(vec![/* ... */]) - }, - &mut tablewise, - &mut pairwise, - n_rounds, - &mut transcript, +use effsc::{noop_hook, runner::sumcheck}; +use effsc::provers::coefficient::CoefficientProver; + +let mut prover = CoefficientProver::new( + &evaluator, + tablewise, + pairwise, +); +let proof = sumcheck( + &mut prover, + num_rounds, + &mut transcript, + noop_hook, ); ``` -The closure receives immutable references to the current tables; after each round the library automatically reduces all pairwise and tablewise entries by folding with the verifier challenge. - -## Examples - -### 1) WARP - Multilinear Constraint Batching +### Verification -Before integration, [WARP](https://github.com/compsec-epfl/warp) used 200+ lines of sumcheck related code including calls to SpongeFish, pair- and table-wise reductions, as well as sparse-map foldings ([PR #14](https://github.com/compsec-epfl/warp/pull/14), [PR #12](https://github.com/compsec-epfl/warp/pull/12/changes#diff-904f410986c619441fb8554f4840cb36613f2de354b41ca991d381dec78959b0L34)). - -Using Efficient Sumcheck this reduces to six lines of code and brings parallelization via Rayon (and soon vectorization via SIMD): +One verifier for any degree $d$. Returns `SumcheckResult { challenges, final_claim }` — ⚠️ the caller is responsible for the oracle check ([Thaler Remark 4.2](https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf)). ```rust -use efficient_sumcheck::{inner_product_sumcheck, batched_constraint_poly}; +use effsc::{noop_hook_verify, verifier::sumcheck_verify}; -let alpha = inner_product_sumcheck::( - &mut f, - &mut batched_constraint_poly(&dense_evals, &sparse_evals), +let result = sumcheck_verify( + claimed_sum, + degree, + num_rounds, &mut transcript, -).verifier_messages; + noop_hook_verify, +)?; + +// Standalone: compare against the prover's claimed final value. +assert_eq!(result.final_claim, proof.final_value); + +// Composed (WHIR, GKR): pass final_claim to the next layer. +next_layer_claim = result.final_claim; ``` -Here, `batched_constraint_poly` merges dense evaluation vectors (out-of-domain samples) with sparse map-represented polynomials (in-domain queries) into a single constraint polynomial, ready for the inner product sumcheck. +## Variable Ordering + +Each prover comes in two variants: + +- **MSB** (half-split) — optimal memory layout in most cases. Used by WHIR and WARP. +- **LSB** (pair-split) — optimal for streaming applications where evaluations arrive sequentially. + +| MSB | LSB | +|-----|-----| +| `MultilinearProver` | `MultilinearProverLSB` | +| `InnerProductProver` | `InnerProductProverLSB` | +| `CoefficientProver` | `CoefficientProverLSB` | +| `GkrProver` | — | + +See [`docs/design.md`](docs/design.md) for details. -### 2) WARP - Twin Constraint Batching +## Partial Execution and Hooks -[WARP](https://github.com/compsec-epfl/warp) also uses `coefficient_sumcheck` with `folding::protogalaxy::fold` to batch a codeword check and an R1CS constraint check into a single sumcheck. The codewords, witness vectors, and folding coefficients are stored as tablewise tables and the equality polynomial evaluations as a pairwise vector: +The `sumcheck()` runner supports partial execution (`num_rounds < v`) and per-round hooks for composed protocols: ```rust -use efficient_sumcheck::coefficient_sumcheck::coefficient_sumcheck; -use efficient_sumcheck::folding::protogalaxy; - -let mut tablewise = [codewords, z_vecs, alpha_vecs, beta_vecs]; -let mut pairwise = [tau_eq_evals]; - -let sc = coefficient_sumcheck( - |tw, pw| { - let (u, z, a, b) = (&tw[0], &tw[1], &tw[2], &tw[3]); - let tau = &pw[0]; - - let f = protogalaxy::fold(/* ... */, /* codeword polys */); - let p = protogalaxy::fold(/* ... */, /* constraint polys */); - let t = linear_poly(tau[0], tau[1]); - - // h(X) = (f(X) + ω·p(X)) · t(X) - (f + p * omega).naive_mul(&t) - }, - &mut tablewise, - &mut pairwise, - log_l, - &mut prover_state, +// WHIR: partial sumcheck with proof-of-work grinding +let proof = sumcheck( + &mut prover, + folding_factor, // num_rounds < v + &mut transcript, + |_, t| round_pow.prove(t), // per-round hook ); -let gamma = sc.verifier_messages; ``` -After each round `coefficient_sumcheck` reduces all four tablewise tables and the pairwise equality evaluations by folding with the verifier's challenge. +## SIMD Acceleration + +All provers transparently auto-dispatch to SIMD backends. Supported fields: + +- [x] Goldilocks ($p = 2^{64} - 2^{32} + 1$) and degree-2/3 extensions +- [ ] M31 ($p = 2^{31} - 1$) and extensions +- [ ] BabyBear ($p = 2^{31} - 2^{27} + 1$) and extensions +- [ ] KoalaBear ($p = 2^{31} - 2^{24} + 1$) and extensions + +| Backend | Width | Platform | +|---------|-------|----------| +| NEON | 2-wide | aarch64 (Apple M-series, Graviton) | +| AVX-512 IFMA | 8-wide | x86_64 (Sapphire Rapids) | -## Advanced Usage +Falls back to scalar for other fields. See [`SECURITY.md`](SECURITY.md#unsafe-code) for `unsafe` scope. -Supporting the high-level interfaces are raw implementations of sumcheck [[LFKN92](#references)] using three proving algorithms: +## Integrations -- The quasi-linear time and logarithmic space algorithm of [[CTY11](#references)] - - [SpaceProver](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/multilinear/provers/space/core.rs#L8) - - [SpaceProductProver](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/multilinear_product/provers/space/core.rs#L11) +Integrated into [WHIR](https://github.com/WizardOfMenlo/whir) ([PR](https://github.com/WizardOfMenlo/whir/pull/250)) and [WARP](https://github.com/compsec-epfl/warp) ([PR](https://github.com/compsec-epfl/warp/pull/24)) with measured performance improvements. Integration capability for streaming contexts like [Jolt](https://github.com/a16z/jolt) is described in [`docs/design.md`](docs/design.md). -- The linear time and linear space algorithm of [[VSBW13](#references)] - - [TimeProver](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/multilinear/provers/time/core.rs#L7) - - [TimeProductProver](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/multilinear_product/provers/time/core.rs#L16) +## Correctness -- The linear time and sublinear space algorithms of [[CFFZ24](#references)] and [[BCFFMMZ25](#references)] respectively - - [BlendyProver](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/multilinear/provers/blendy/core.rs#L14) - - [BlendyProductProver](https://github.com/compsec-epfl/efficient-sumcheck/blob/main/src/multilinear_product/provers/blendy/core.rs#L13) +🚧 Undergoing fuzzing over randomized inputs against [z-tech/sumcheck-lean4](https://github.com/z-tech/sumcheck-lean4), an oracle with machine-checked proofs of completeness and soundness. Findings to follow. ## References -[[LFNK92](https://dl.acm.org/doi/pdf/10.1145/146585.146605)]: Carsten Lund, Lance Fortnow, Howard J. Karloff, and Noam Nisan. “Algebraic Methods for Interactive Proof Systems”. In: Journal of the ACM 39.4 (1992). -[[CTY11](https://arxiv.org/pdf/1109.6882.pdf)]: Graham Cormode, Justin Thaler, and Ke Yi. “Verifying computations with streaming interactive proofs”. In: Proceedings of the VLDB Endowment 5.1 (2011), pp. 25–36. +[[LFKN92](https://dl.acm.org/doi/pdf/10.1145/146585.146605)]: Carsten Lund, Lance Fortnow, Howard J. Karloff, and Noam Nisan. "Algebraic Methods for Interactive Proof Systems". In: Journal of the ACM 39.4 (1992). -[[VSBW13](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6547112)]: Victor Vu, Srinath Setty, Andrew J. Blumberg, and Michael Walfish. “A hybrid architecture for interactive verifiable computation”. In: Proceedings of the 34th IEEE Symposium on Security and Privacy. Oakland ’13. 2013, pp. 223–237. +[[CTY11](https://arxiv.org/pdf/1109.6882.pdf)]: Graham Cormode, Justin Thaler, and Ke Yi. "Verifying computations with streaming interactive proofs". In: Proceedings of the VLDB Endowment 5.1 (2011), pp. 25-36. + +[[VSBW13](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6547112)]: Victor Vu, Srinath Setty, Andrew J. Blumberg, and Michael Walfish. "A hybrid architecture for interactive verifiable computation". In: Proceedings of the 34th IEEE Symposium on Security and Privacy. Oakland '13. 2013, pp. 223-237. [[CFFZ24](https://eprint.iacr.org/2024/524.pdf)]: Alessandro Chiesa, Elisabetta Fedele, Giacomo Fenzi, Andrew Zitek-Estrada. "A time-space tradeoff for the sumcheck prover". In: Cryptology ePrint Archive. -[[BCFFMMZ25](https://eprint.iacr.org/2025/1473.pdf)]: Anubhav Bawejal, Alessandro Chiesa, Elisabetta Fedele, Giacomo Fenzi, Pratyush Mishra, Tushar Mopuri, and Andrew Zitek-Estrada. "Time-Space Trade-Offs for Sumcheck". In: TCC Theory of Cryptography: 23rd International Conference, pp. 37. +[[BCFFMMZ25](https://eprint.iacr.org/2025/1473.pdf)]: Anubhav Baweja, Alessandro Chiesa, Elisabetta Fedele, Giacomo Fenzi, Pratyush Mishra, Tushar Mopuri, and Andrew Zitek-Estrada. "Time-Space Trade-Offs for Sumcheck". In: TCC Theory of Cryptography: 23rd International Conference, pp. 37. + +[[Thaler23](https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf)]: Justin Thaler. "Proofs, Arguments, and Zero-Knowledge". Chapter 4: Interactive Proofs. July 2023. + +[[BDDT25](https://eprint.iacr.org/2025/1117.pdf)]: Aarushi Bagad, Quang Dao, Yuri Domb, and Justin Thaler. "Speeding Up Sum-Check Proving". Cryptology ePrint Archive, 2025/1117. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..ede67c79 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,82 @@ +# Security Policy + +## Audit status + +This library has **not undergone a formal security audit**. It is research-grade +software under active development. + +## Threat model + +The sumcheck protocol is **public-coin**: the prover's computation depends only +on public polynomial evaluations and verifier challenges. No secret values flow +through the prover's arithmetic, so timing side channels in the field operations +do not leak private information. + +If zero-knowledge sumcheck (blinded/masked variants) is added in the future, a +timing analysis of the field arithmetic layer would be warranted. The +fixed-size Montgomery multiplication used for Goldilocks is inherently +data-independent, but this property has not been formally verified. + +## Oracle check responsibility + +`sumcheck_verify` checks round consistency and returns +`SumcheckResult { challenges, final_claim }`. It does **not** verify +that `final_claim == g(r_1, ..., r_v)` — this oracle check ([Thaler Remark 4.2](https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf)) is the caller's responsibility. + +**Forgetting the oracle check is a soundness bug.** A malicious prover +can craft round polynomials that pass all consistency checks but reduce +to an arbitrary final claim. Without the oracle check, the verifier +accepts. + +Correct usage depends on the protocol context: + +| Context | What the caller must do | +|---------|------------------------| +| Standalone | `assert_eq!(result.final_claim, proof.final_value)` | +| Composed (WHIR, GKR) | Pass `result.final_claim` to the next layer, which checks it | +| Custom (WARP) | Compute expected value from `result.challenges` and compare | + +The library intentionally does not bundle the oracle check into the +verifier because every real-world caller handles it differently — and +a closure-based design that most callers bypass with a no-op provides +false safety. Returning `final_claim` directly makes the caller's +obligation explicit. + +## `unsafe` code + +Outside of the SIMD path, the library contains **no `unsafe` code**. Within the +SIMD subsystem, `unsafe` is confined to two categories: + +1. **SIMD intrinsics** (`core::arch`) — `_mm512_loadu_si512`, `vld1q_u64`, etc. + These are `unsafe` by definition in Rust; there is no safe alternative. They + appear exclusively in the backend kernels (`avx512.rs`, `neon.rs`) and the + evaluate/reduce loops that call them. + +2. **Field ↔ `u64` reinterpretation** — arkworks field types don't derive + `zerocopy`, so the blanket `SumcheckField` impl for `ark_ff::Field` uses + `transmute_copy` and `from_raw_parts` to reinterpret Goldilocks elements as + their underlying Montgomery-form `u64` values. These are centralized in five + trait methods (`_to_raw_u64`, `_from_raw_u64`, `_as_u64_slice`, + `_as_u64_slice_mut`, `_from_u64_components`) in `field.rs`, each with a + SAFETY comment documenting the invariant. The dispatch layer itself contains + no `unsafe`. + +Non-arkworks types using `SimdRepr` avoid category 2 entirely — the `zerocopy` +bounds (`IntoBytes + FromBytes + Immutable`) provide compile-time layout +verification, so no `unsafe` reinterpretation is needed. + +The scalar (non-SIMD) code path uses no `unsafe` at all. If `F` is not a +recognised Goldilocks field, SIMD dispatch is skipped and the entire protocol +runs in safe Rust. + +## Reporting a vulnerability + +If you discover a security issue, please report it responsibly via one of: + +- **GitHub:** Use [private vulnerability reporting](https://github.com/compsec-epfl/space-efficient-sumcheck/security/advisories/new) on this repository +- **Email:** andrew.zitek@epfl.ch (subject: `[effsc] Security vulnerability report`) + +Please include a description of the issue, its potential impact, and steps to +reproduce if applicable. You will receive an acknowledgement within 72 hours. + +**Do not** open a public GitHub issue for security vulnerabilities. diff --git a/benches/explanation.rs b/benches/explanation.rs deleted file mode 100644 index 5323a06e..00000000 --- a/benches/explanation.rs +++ /dev/null @@ -1,5 +0,0 @@ -fn main() { - eprintln!("Error: This project uses a custom benchmarking workflow."); - eprintln!("Please navigate to the appropriate bench directory and call the shell './run_bench.sh' directly."); - std::process::exit(1); -} diff --git a/benches/lag-poly-benches/Cargo.lock b/benches/lag-poly-benches/Cargo.lock deleted file mode 100644 index 424bb95d..00000000 --- a/benches/lag-poly-benches/Cargo.lock +++ /dev/null @@ -1,443 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "ahash" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", - "zerocopy", -] - -[[package]] -name = "allocator-api2" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" - -[[package]] -name = "ark-bn254" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d69eab57e8d2663efa5c63135b2af4f396d66424f88954c21104125ab6b3e6bc" -dependencies = [ - "ark-ec", - "ark-ff", - "ark-std", -] - -[[package]] -name = "ark-ec" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" -dependencies = [ - "ahash", - "ark-ff", - "ark-poly", - "ark-serialize", - "ark-std", - "educe", - "fnv", - "hashbrown", - "itertools", - "num-bigint", - "num-integer", - "num-traits", - "zeroize", -] - -[[package]] -name = "ark-ff" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" -dependencies = [ - "ark-ff-asm", - "ark-ff-macros", - "ark-serialize", - "ark-std", - "arrayvec", - "digest", - "educe", - "itertools", - "num-bigint", - "num-traits", - "paste", - "zeroize", -] - -[[package]] -name = "ark-ff-asm" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" -dependencies = [ - "quote", - "syn", -] - -[[package]] -name = "ark-ff-macros" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" -dependencies = [ - "num-bigint", - "num-traits", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "ark-poly" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" -dependencies = [ - "ahash", - "ark-ff", - "ark-serialize", - "ark-std", - "educe", - "fnv", - "hashbrown", -] - -[[package]] -name = "ark-serialize" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" -dependencies = [ - "ark-serialize-derive", - "ark-std", - "arrayvec", - "digest", - "num-bigint", -] - -[[package]] -name = "ark-serialize-derive" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "ark-std" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" -dependencies = [ - "num-traits", - "rand", -] - -[[package]] -name = "arrayvec" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" - -[[package]] -name = "autocfg" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "typenum", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "crypto-common", -] - -[[package]] -name = "educe" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" -dependencies = [ - "enum-ordinalize", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "either" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" - -[[package]] -name = "enum-ordinalize" -version = "4.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" -dependencies = [ - "enum-ordinalize-derive", -] - -[[package]] -name = "enum-ordinalize-derive" -version = "4.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", -] - -[[package]] -name = "hashbrown" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" -dependencies = [ - "allocator-api2", -] - -[[package]] -name = "itertools" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" -dependencies = [ - "either", -] - -[[package]] -name = "lag-poly-benches" -version = "0.1.0" -dependencies = [ - "ark-bn254", - "ark-ff", - "ark-poly", - "ark-std", - "space-efficient-sumcheck", -] - -[[package]] -name = "num-bigint" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" -dependencies = [ - "num-integer", - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.20.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" - -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - -[[package]] -name = "ppv-lite86" -version = "0.2.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" -dependencies = [ - "zerocopy", -] - -[[package]] -name = "proc-macro2" -version = "1.0.93" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" - -[[package]] -name = "space-efficient-sumcheck" -version = "0.0.2" -dependencies = [ - "ark-ff", - "ark-poly", - "ark-std", -] - -[[package]] -name = "syn" -version = "2.0.98" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "unicode-ident" -version = "1.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" - -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - -[[package]] -name = "zerocopy" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" -dependencies = [ - "byteorder", - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zeroize" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" -dependencies = [ - "zeroize_derive", -] - -[[package]] -name = "zeroize_derive" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] diff --git a/benches/lag-poly-benches/Cargo.toml b/benches/lag-poly-benches/Cargo.toml deleted file mode 100644 index 9b942a60..00000000 --- a/benches/lag-poly-benches/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "lag-poly-benches" -version = "0.1.0" -edition = "2021" - -[dependencies] -ark-ff = "0.5.0" -ark-poly = "0.5.0" -ark-std ="0.5.0" -ark-bn254 = "0.5.0" -space-efficient-sumcheck = { path = "../../." } - diff --git a/benches/lag-poly-benches/run_benches.sh b/benches/lag-poly-benches/run_benches.sh deleted file mode 100755 index 3b116cc9..00000000 --- a/benches/lag-poly-benches/run_benches.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh - -# We measure (i) wall time; and (ii) maximum resident set size, using the GNU-time facility. - -num_vars=15 -while [ $num_vars -le 30 ]; do - # NOTE FOR NEXT LINE: mac --> "gtime", linux --> "time" - output=`(gtime -v ./target/release/lag-poly-benches $num_vars) 2>&1` - user_time_seconds=$(echo "$output" | grep "User time (seconds):" | awk '{print $4}') - user_time_ms=$(awk "BEGIN {printf \"%.0f\", $user_time_seconds * 1000}") - ram_kilobytes=$(echo "$output" | grep "Maximum resident set size (kbytes)" | awk '{print $6}') - ram_bytes=$(echo "$ram_kilobytes" | awk '{ printf "%.0f", $1 * 1000 }') - echo "graycode, $num_vars, $user_time_ms, $ram_bytes" - num_vars=$((num_vars + 1)) -done - -# NOTE: helpful Unix commands -# -# 1) You can run this shell in the background while piping the output to a file like so: -# nohup ./run_benches.sh &> output_file.txt & -# -# 2) If you need to kill the running process you can find pid with: -# lsof | grep output_file -# Then: -# kill \ No newline at end of file diff --git a/benches/provers.rs b/benches/provers.rs deleted file mode 100644 index 2a0bed36..00000000 --- a/benches/provers.rs +++ /dev/null @@ -1,69 +0,0 @@ -use ark_std::{hint::black_box, time::Duration}; -use criterion::{ - criterion_group, criterion_main, measurement::WallTime, BatchSize, BenchmarkGroup, Criterion, -}; - -use efficient_sumcheck::{ - multilinear::TimeProver, - multilinear_product::TimeProductProver, - prover::{ProductProverConfig, Prover, ProverConfig}, - tests::{BenchStream, F128}, - ProductSumcheck, Sumcheck, -}; - -fn get_bench_group(c: &mut Criterion) -> BenchmarkGroup<'_, WallTime> { - let mut group = c.benchmark_group("sumcheck_prover"); - group - .sample_size(10) - .warm_up_time(Duration::from_secs(2)) - .measurement_time(Duration::from_secs(10)); - group -} - -fn time_prover_bench(c: &mut Criterion) { - let num_vars = 24usize; - get_bench_group(c).bench_function("time_prover", |bencher| { - bencher.iter_batched( - || { - let stream = BenchStream::::new(num_vars); - TimeProver::>::new( - > as Prover>::ProverConfig::default( - num_vars, stream, - ), - ) - }, - |mut prover: TimeProver>| { - black_box(Sumcheck::::prove::< - BenchStream, - TimeProver>, - >(&mut prover, &mut ark_std::test_rng())); - }, - BatchSize::LargeInput, - ) - }); -} - -fn time_product_prover_bench(c: &mut Criterion) { - let num_vars = 24usize; - get_bench_group(c).bench_function("time_product_prover", |bencher| { - bencher.iter_batched( - || { - let stream = BenchStream::::new(num_vars); - let streams: Vec> = vec![stream.clone(), stream.clone()]; - TimeProductProver::>::new(ProductProverConfig::default( - num_vars, streams, - )) - }, - |mut prover: TimeProductProver>| { - black_box(ProductSumcheck::::prove::< - BenchStream, - TimeProductProver>, - >(&mut prover, &mut ark_std::test_rng())); - }, - BatchSize::LargeInput, - ) - }); -} - -criterion_group!(benches, time_product_prover_bench, time_prover_bench); -criterion_main!(benches); diff --git a/benches/sumcheck-benches/Cargo.lock b/benches/sumcheck-benches/Cargo.lock deleted file mode 100644 index 0eff14c6..00000000 --- a/benches/sumcheck-benches/Cargo.lock +++ /dev/null @@ -1,454 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 4 - -[[package]] -name = "ahash" -version = "0.8.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", - "zerocopy", -] - -[[package]] -name = "allocator-api2" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" - -[[package]] -name = "ark-bn254" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d69eab57e8d2663efa5c63135b2af4f396d66424f88954c21104125ab6b3e6bc" -dependencies = [ - "ark-ec", - "ark-ff", - "ark-std", -] - -[[package]] -name = "ark-ec" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" -dependencies = [ - "ahash", - "ark-ff", - "ark-poly", - "ark-serialize", - "ark-std", - "educe", - "fnv", - "hashbrown", - "itertools", - "num-bigint", - "num-integer", - "num-traits", - "zeroize", -] - -[[package]] -name = "ark-ff" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" -dependencies = [ - "ark-ff-asm", - "ark-ff-macros", - "ark-serialize", - "ark-std", - "arrayvec", - "digest", - "educe", - "itertools", - "num-bigint", - "num-traits", - "paste", - "zeroize", -] - -[[package]] -name = "ark-ff-asm" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" -dependencies = [ - "quote", - "syn", -] - -[[package]] -name = "ark-ff-macros" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" -dependencies = [ - "num-bigint", - "num-traits", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "ark-poly" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" -dependencies = [ - "ahash", - "ark-ff", - "ark-serialize", - "ark-std", - "educe", - "fnv", - "hashbrown", -] - -[[package]] -name = "ark-serialize" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" -dependencies = [ - "ark-serialize-derive", - "ark-std", - "arrayvec", - "digest", - "num-bigint", -] - -[[package]] -name = "ark-serialize-derive" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "ark-std" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" -dependencies = [ - "num-traits", - "rand", -] - -[[package]] -name = "arrayvec" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" - -[[package]] -name = "autocfg" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "typenum", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "crypto-common", -] - -[[package]] -name = "educe" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" -dependencies = [ - "enum-ordinalize", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "efficient-sumcheck" -version = "0.0.2" -dependencies = [ - "ark-ff", - "ark-poly", - "ark-serialize", - "ark-std", - "memmap2", -] - -[[package]] -name = "either" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" - -[[package]] -name = "enum-ordinalize" -version = "4.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" -dependencies = [ - "enum-ordinalize-derive", -] - -[[package]] -name = "enum-ordinalize-derive" -version = "4.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", -] - -[[package]] -name = "hashbrown" -version = "0.15.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" -dependencies = [ - "allocator-api2", -] - -[[package]] -name = "itertools" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" -dependencies = [ - "either", -] - -[[package]] -name = "libc" -version = "0.2.172" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" - -[[package]] -name = "memmap2" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" -dependencies = [ - "libc", -] - -[[package]] -name = "num-bigint" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" -dependencies = [ - "num-integer", - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" - -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - -[[package]] -name = "ppv-lite86" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" -dependencies = [ - "zerocopy", -] - -[[package]] -name = "proc-macro2" -version = "1.0.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" - -[[package]] -name = "sumcheck-benches" -version = "0.1.0" -dependencies = [ - "ark-bn254", - "ark-ff", - "ark-poly", - "ark-serialize", - "ark-std", - "efficient-sumcheck", -] - -[[package]] -name = "syn" -version = "2.0.101" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "typenum" -version = "1.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" - -[[package]] -name = "unicode-ident" -version = "1.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" - -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - -[[package]] -name = "zerocopy" -version = "0.8.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.8.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zeroize" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" -dependencies = [ - "zeroize_derive", -] - -[[package]] -name = "zeroize_derive" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] diff --git a/benches/sumcheck-benches/Cargo.toml b/benches/sumcheck-benches/Cargo.toml deleted file mode 100644 index 21abef8d..00000000 --- a/benches/sumcheck-benches/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "sumcheck-benches" -version = "0.1.0" -edition = "2021" - -[dependencies] -ark-ff = "0.5.0" -ark-poly = "0.5.0" -ark-serialize = "0.5.0" -ark-std ="0.5.0" -ark-bn254 = "0.5.0" -efficient-sumcheck = { path = "../../.", default-features = false } - -[profile.release] -debug = true diff --git a/benches/sumcheck-benches/run_benches.sh b/benches/sumcheck-benches/run_benches.sh deleted file mode 100755 index 339174ac..00000000 --- a/benches/sumcheck-benches/run_benches.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/sh - -# We measure (i) wall time; and (ii) maximum resident set size, using the GNU-time facility. - -algorithms="VSBW Blendy2 Blendy1 Blendy3 Blendy4 CTY ProductBlendy2 ProductVSBW ProductCTY" -fields="Field64 Field128 FieldBn254" - -for algorithm in $algorithms; do - for field in $fields; do - num_vars=16 - while [ $num_vars -le 30 ]; do - case "$algorithm" in - "Blendy1") stage_size="1" ;; - "Blendy2") stage_size="2" ;; - "Blendy3") stage_size="3" ;; - "Blendy4") stage_size="4" ;; - "VSBW") stage_size="1" ;; - "CTY") stage_size="1" ;; - "ProductBlendy2") stage_size="2" ;; - "ProductVSBW") stage_size="1" ;; - "ProductCTY") stage_size="1" ;; - *) ;; - esac - case "$algorithm" in - "Blendy1") algorithm_label="Blendy" ;; - "Blendy2") algorithm_label="Blendy" ;; - "Blendy3") algorithm_label="Blendy" ;; - "Blendy4") algorithm_label="Blendy" ;; - "VSBW") algorithm_label="VSBW" ;; - "CTY") algorithm_label="CTY" ;; - "ProductBlendy2") algorithm_label="ProductBlendy" ;; - "ProductVSBW") algorithm_label="ProductVSBW" ;; - "ProductCTY") algorithm_label="ProductCTY" ;; - *) ;; - esac - # NOTE FOR NEXT LINE: mac --> "gtime", linux --> "time" - output=`(gtime -v ./target/release/sumcheck-benches $algorithm_label $field $num_vars $stage_size) 2>&1` - user_time_seconds=$(echo "$output" | grep "User time (seconds):" | awk '{print $4}') - user_time_ms=$(awk "BEGIN {printf \"%.0f\", $user_time_seconds * 1000}") - ram_kilobytes=$(echo "$output" | grep "Maximum resident set size (kbytes)" | awk '{print $6}') - ram_bytes=$(echo "$ram_kilobytes" | awk '{ printf "%.0f", $1 * 1000 }') - echo "$algorithm, $field, $num_vars, $user_time_ms, $num_vars, $ram_bytes" - num_vars=$((num_vars + 2)) - done - done -done - -# NOTE: helpful Unix commands -# -# 1) You can run this shell in the background while piping the output to a file like so: -# nohup ./run_benches.sh &> output_file.txt & -# -# 2) If you need to kill the running process you can find pid with: -# lsof | grep output_file -# Then: -# kill diff --git a/benches/sumcheck-benches/src/main.rs b/benches/sumcheck-benches/src/main.rs deleted file mode 100644 index 5de34666..00000000 --- a/benches/sumcheck-benches/src/main.rs +++ /dev/null @@ -1,122 +0,0 @@ -use ark_ff::Field; - -use efficient_sumcheck::{ - multilinear::{ - BlendyProver, BlendyProverConfig, ReduceMode, SpaceProver, SpaceProverConfig, TimeProver, - TimeProverConfig, - }, - multilinear_product::{ - BlendyProductProver, BlendyProductProverConfig, SpaceProductProver, - SpaceProductProverConfig, TimeProductProver, TimeProductProverConfig, - }, - prover::{Prover, ProverConfig}, - tests::{BenchStream, F128, F64}, - ProductSumcheck, Sumcheck, -}; - -pub mod validation; -use validation::{validate_and_format_command_line_args, AlgorithmLabel, BenchArgs, FieldLabel}; - -fn run_on_field(bench_args: BenchArgs) { - let mut rng = ark_std::test_rng(); - let s = BenchStream::::new(bench_args.num_variables); - - // switch on algorithm_label - match bench_args.algorithm_label { - AlgorithmLabel::Blendy => { - let config: BlendyProverConfig> = - BlendyProverConfig::>::default(bench_args.num_variables, s); - let transcript = Sumcheck::::prove::, BlendyProver>>( - &mut BlendyProver::>::new(config), - &mut rng, - ); - assert!(transcript.is_accepted); - } - AlgorithmLabel::VSBW => { - let config: TimeProverConfig> = - TimeProverConfig::>::new( - bench_args.num_variables, - s, - ReduceMode::Pairwise, - ); - let transcript = Sumcheck::::prove::, TimeProver>>( - &mut TimeProver::>::new(config), - &mut rng, - ); - assert!(transcript.is_accepted); - } - AlgorithmLabel::CTY => { - let config: SpaceProverConfig> = - SpaceProverConfig::>::default(bench_args.num_variables, s); - let transcript = Sumcheck::::prove::, SpaceProver>>( - &mut SpaceProver::>::new(config), - &mut rng, - ); - assert!(transcript.is_accepted); - } - AlgorithmLabel::ProductVSBW => { - let config: TimeProductProverConfig> = - TimeProductProverConfig::>::new( - bench_args.num_variables, - vec![s.clone(), s], - ); - let transcript = ProductSumcheck::::prove::< - BenchStream, - TimeProductProver>, - >( - &mut TimeProductProver::>::new(config), - &mut rng, - ); - assert!(transcript.is_accepted); - } - AlgorithmLabel::ProductBlendy => { - let config: BlendyProductProverConfig> = - BlendyProductProverConfig::>::new( - bench_args.num_variables, - bench_args.stage_size, - vec![s.clone(), s], - ); - let transcript = ProductSumcheck::::prove::< - BenchStream, - BlendyProductProver>, - >( - &mut BlendyProductProver::>::new(config), - &mut rng, - ); - assert!(transcript.is_accepted); - } - AlgorithmLabel::ProductCTY => { - let config: SpaceProductProverConfig> = - SpaceProductProverConfig::>::new( - bench_args.num_variables, - vec![s.clone(), s], - ); - let transcript = ProductSumcheck::::prove::< - BenchStream, - SpaceProductProver>, - >( - &mut SpaceProductProver::>::new(config), - &mut rng, - ); - assert!(transcript.is_accepted); - } - }; -} - -fn main() { - // Collect command line arguments - let bench_args: BenchArgs = validate_and_format_command_line_args(std::env::args().collect()); - // Run the requested bench - match bench_args.field_label { - FieldLabel::Field64 => { - run_on_field::(bench_args); - } - FieldLabel::Field128 => { - run_on_field::(bench_args); - } - FieldLabel::FieldBn254 => { - // run_on_field::(bench_args); - run_on_field::(bench_args); - } - }; -} diff --git a/benches/sumcheck-benches/src/validation/mod.rs b/benches/sumcheck-benches/src/validation/mod.rs deleted file mode 100644 index ff6f00e2..00000000 --- a/benches/sumcheck-benches/src/validation/mod.rs +++ /dev/null @@ -1,104 +0,0 @@ -#[derive(Debug)] -pub enum FieldLabel { - Field64, - Field128, - FieldBn254, -} - -#[derive(Debug)] -pub enum AlgorithmLabel { - CTY, - VSBW, - Blendy, - ProductBlendy, - ProductVSBW, - ProductCTY, -} - -pub struct BenchArgs { - pub field_label: FieldLabel, - pub algorithm_label: AlgorithmLabel, - pub num_variables: usize, - pub stage_size: usize, -} - -fn check_if_number(input_string: String) -> bool { - match input_string.parse::() { - Ok(_) => true, - Err(_) => false, - } -} - -pub fn validate_and_format_command_line_args(argsv: Vec) -> BenchArgs { - // Check if the correct number of command line arguments is provided - if argsv.len() != 5 { - eprintln!( - "Usage: {} field_label algorithm_label num_variables stage_size", - argsv[0] - ); - std::process::exit(1); - } - // algorithm label - if !(argsv[1] == "CTY" - || argsv[1] == "VSBW" - || argsv[1] == "Blendy" - || argsv[1] == "ProductBlendy" - || argsv[1] == "ProductVSBW" - || argsv[1] == "ProductCTY") - { - eprintln!( - "Usage: {} field_label algorithm_label num_variables stage_size", - argsv[0] - ); - eprintln!("Invalid input: algorithm_label must be one of (CTY, VSBW, Blendy, ProductVSBW, ProductBlendy, ProductCTY)"); - std::process::exit(1); - } - let algorithm_label = match argsv[1].as_str() { - "CTY" => AlgorithmLabel::CTY, - "VSBW" => AlgorithmLabel::VSBW, - "Blendy" => AlgorithmLabel::Blendy, - "ProductVSBW" => AlgorithmLabel::ProductVSBW, - "ProductCTY" => AlgorithmLabel::ProductCTY, - _ => AlgorithmLabel::ProductBlendy, // this is checked in previous line - }; - // field_label - if !(argsv[2] == "Field64" || argsv[2] == "Field128" || argsv[2] == "FieldBn254") { - eprintln!( - "Usage: {} field_label algorithm_label num_variables stage_size", - argsv[0] - ); - eprintln!("Invalid input: field_label must be one of (Field64, Field128, FieldBn254)"); - std::process::exit(1); - } - let field_label = match argsv[2].as_str() { - "Field64" => FieldLabel::Field64, - "Field128" => FieldLabel::Field128, - _ => FieldLabel::FieldBn254, // this is checked in previous line - }; - // num_variables - if !check_if_number(argsv[3].clone()) { - eprintln!( - "Usage: {} field_label algorithm_label num_variables stage_size", - argsv[0] - ); - eprintln!("Invalid input: num_variables must be a number"); - std::process::exit(1); - } - let num_variables = argsv[3].clone().parse::().unwrap(); - // stage_size - if !check_if_number(argsv[4].clone()) { - eprintln!( - "Usage: {} field_label algorithm_label num_variables stage_size", - argsv[0] - ); - eprintln!("Invalid input: stage_size must be a number"); - std::process::exit(1); - } - let stage_size = argsv[4].clone().parse::().unwrap(); - return BenchArgs { - field_label, - algorithm_label, - num_variables, - stage_size, - }; -} diff --git a/benches/sumcheck.rs b/benches/sumcheck.rs new file mode 100644 index 00000000..8a369bd5 --- /dev/null +++ b/benches/sumcheck.rs @@ -0,0 +1,190 @@ +//! Sumcheck benchmarks for the canonical API. +//! +//! Matrix: {multilinear, inner_product} × {F64, F64Ext3} × {2^16, 2^20, 2^24} +//! Plus: fold kernel throughput. +//! +//! Run: cargo bench --bench sumcheck +//! AVX: RUSTFLAGS="-C target-feature=+avx512ifma" cargo bench --bench sumcheck + +use ark_ff::UniformRand; +use ark_std::{hint::black_box, time::Duration}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; + +use effsc::provers::inner_product::InnerProductProver; +use effsc::provers::multilinear::MultilinearProver; +use effsc::runner::sumcheck; +use effsc::tests::{F64Ext3, F64}; +use effsc::transcript::SanityTranscript; + +const SIZES: [usize; 3] = [16, 20, 24]; + +fn multilinear_f64(c: &mut Criterion) { + let mut g = c.benchmark_group("multilinear/F64"); + g.sample_size(10) + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(5)); + for nv in SIZES { + let n = 1usize << nv; + g.throughput(Throughput::Elements(n as u64)); + g.bench_function(BenchmarkId::from_parameter(format!("2^{nv}")), |b| { + b.iter_with_setup( + || { + let mut rng = ark_std::test_rng(); + (0..n).map(|_| F64::rand(&mut rng)).collect::>() + }, + |evals| { + let mut p = MultilinearProver::new(evals); + let mut rng = ark_std::test_rng(); + let mut t = SanityTranscript::new(&mut rng); + black_box(sumcheck(&mut p, nv, &mut t, |_, _| {})); + }, + ); + }); + } + g.finish(); +} + +fn multilinear_ext3(c: &mut Criterion) { + let mut g = c.benchmark_group("multilinear/F64Ext3"); + g.sample_size(10) + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(5)); + for nv in SIZES { + let n = 1usize << nv; + g.throughput(Throughput::Elements(n as u64)); + g.bench_function(BenchmarkId::from_parameter(format!("2^{nv}")), |b| { + b.iter_with_setup( + || { + let mut rng = ark_std::test_rng(); + (0..n).map(|_| F64Ext3::rand(&mut rng)).collect::>() + }, + |evals| { + let mut p = MultilinearProver::new(evals); + let mut rng = ark_std::test_rng(); + let mut t = SanityTranscript::new(&mut rng); + black_box(sumcheck(&mut p, nv, &mut t, |_, _| {})); + }, + ); + }); + } + g.finish(); +} + +fn inner_product_f64(c: &mut Criterion) { + let mut g = c.benchmark_group("inner_product/F64"); + g.sample_size(10) + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(5)); + for nv in SIZES { + let n = 1usize << nv; + g.throughput(Throughput::Elements(n as u64)); + g.bench_function(BenchmarkId::from_parameter(format!("2^{nv}")), |b| { + b.iter_with_setup( + || { + let mut rng = ark_std::test_rng(); + let a: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let b: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + (a, b) + }, + |(a, b)| { + let mut p = InnerProductProver::new(a, b); + let mut rng = ark_std::test_rng(); + let mut t = SanityTranscript::new(&mut rng); + black_box(sumcheck(&mut p, nv, &mut t, |_, _| {})); + }, + ); + }); + } + g.finish(); +} + +fn inner_product_ext3(c: &mut Criterion) { + let mut g = c.benchmark_group("inner_product/F64Ext3"); + g.sample_size(10) + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(5)); + for nv in SIZES { + let n = 1usize << nv; + g.throughput(Throughput::Elements(n as u64)); + g.bench_function(BenchmarkId::from_parameter(format!("2^{nv}")), |b| { + b.iter_with_setup( + || { + let mut rng = ark_std::test_rng(); + let a: Vec = (0..n).map(|_| F64Ext3::rand(&mut rng)).collect(); + let b: Vec = (0..n).map(|_| F64Ext3::rand(&mut rng)).collect(); + (a, b) + }, + |(a, b)| { + let mut p = InnerProductProver::new(a, b); + let mut rng = ark_std::test_rng(); + let mut t = SanityTranscript::new(&mut rng); + black_box(sumcheck(&mut p, nv, &mut t, |_, _| {})); + }, + ); + }); + } + g.finish(); +} + +fn fold_f64(c: &mut Criterion) { + let mut g = c.benchmark_group("fold/F64"); + g.sample_size(10) + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(5)); + for nv in SIZES { + let n = 1usize << nv; + g.throughput(Throughput::Elements(n as u64)); + g.bench_function(BenchmarkId::from_parameter(format!("2^{nv}")), |b| { + b.iter_with_setup( + || { + let mut rng = ark_std::test_rng(); + let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let w = F64::rand(&mut rng); + (evals, w) + }, + |(mut evals, w)| { + effsc::fold(&mut evals, w); + black_box(evals); + }, + ); + }); + } + g.finish(); +} + +fn fold_ext3(c: &mut Criterion) { + let mut g = c.benchmark_group("fold/F64Ext3"); + g.sample_size(10) + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(5)); + for nv in SIZES { + let n = 1usize << nv; + g.throughput(Throughput::Elements(n as u64)); + g.bench_function(BenchmarkId::from_parameter(format!("2^{nv}")), |b| { + b.iter_with_setup( + || { + let mut rng = ark_std::test_rng(); + let evals: Vec = (0..n).map(|_| F64Ext3::rand(&mut rng)).collect(); + let w = F64Ext3::rand(&mut rng); + (evals, w) + }, + |(mut evals, w)| { + effsc::fold(&mut evals, w); + black_box(evals); + }, + ); + }); + } + g.finish(); +} + +criterion_group!( + benches, + multilinear_f64, + multilinear_ext3, + inner_product_f64, + inner_product_ext3, + fold_f64, + fold_ext3, +); +criterion_main!(benches); diff --git a/docs/compatibility.md b/docs/compatibility.md new file mode 100644 index 00000000..f8b8e221 --- /dev/null +++ b/docs/compatibility.md @@ -0,0 +1,88 @@ +# Non-arkworks field support + +By default, the library provides a blanket `SumcheckField` implementation for +all `ark_ff::Field` types. Non-arkworks users can compile with +`--no-default-features` and implement the trait for their own field type. + +## `SumcheckField` trait + +```rust +pub trait SumcheckField: + Copy + Send + Sync + PartialEq + Debug + + Add + Sub + Mul + Neg + AddAssign + SubAssign + MulAssign + + Sum + 'static +{ + const ZERO: Self; + const ONE: Self; + fn from_u64(val: u64) -> Self; + fn inverse(&self) -> Option; +} +``` + +## Example: Plonky3 Goldilocks + +See [`tests/plonky3_roundtrip.rs`](../tests/plonky3_roundtrip.rs) for a complete +working example — a newtype wrapper around `p3_goldilocks::Goldilocks` that +implements `SumcheckField` and runs a full prove + verify roundtrip with no +arkworks dependency. + +## SIMD opt-in via `SimdRepr` + +To enable SIMD acceleration for a non-arkworks Goldilocks type, implement +`SimdRepr`. The `zerocopy` bounds (`IntoBytes + FromBytes + Immutable`) provide +compile-time layout verification — no `unsafe` needed from the implementor. + +```rust +#[derive(Clone, Copy, Debug, PartialEq, + zerocopy::IntoBytes, zerocopy::FromBytes, zerocopy::Immutable)] +#[repr(transparent)] +struct MyGoldilocks(u64); + +impl SumcheckField for MyGoldilocks { + const ZERO: Self = MyGoldilocks(0); + const ONE: Self = MyGoldilocks(1); // Montgomery form of 1 + fn from_u64(val: u64) -> Self { /* ... */ } + fn inverse(&self) -> Option { /* ... */ } + fn _simd_field_config() -> Option { + Some(SimdFieldConfig { modulus: GOLDILOCKS_P, element_bytes: 8 }) + } +} + +impl SimdRepr for MyGoldilocks { + fn modulus() -> u64 { GOLDILOCKS_P } +} +``` + +Extension fields work the same way: + +```rust +#[derive(Clone, Copy, Debug, PartialEq, + zerocopy::IntoBytes, zerocopy::FromBytes, zerocopy::Immutable)] +#[repr(transparent)] +struct MyExt3([u64; 3]); + +impl SumcheckField for MyExt3 { + fn extension_degree() -> u64 { 3 } + fn _simd_field_config() -> Option { + Some(SimdFieldConfig { modulus: GOLDILOCKS_P, element_bytes: 8 }) + } + // ... +} + +impl SimdRepr for MyExt3 { + fn modulus() -> u64 { GOLDILOCKS_P } +} +``` + +## Feature flags + +```toml +[features] +default = ["arkworks", "parallel"] +arkworks = ["ark-ff", "ark-poly", "ark-serialize", "ark-std", "spongefish"] +parallel = ["rayon"] +``` + +- `arkworks` (default): blanket `SumcheckField` impl for `ark_ff::Field` +- `parallel` (default): rayon parallelism for fold and round computation +- `--no-default-features`: pure `SumcheckField` library, no arkworks dependency diff --git a/docs/design.md b/docs/design.md new file mode 100644 index 00000000..bcc6f8b0 --- /dev/null +++ b/docs/design.md @@ -0,0 +1,763 @@ +# Sumcheck API Design + +Authoritative reference: Justin Thaler, *Proofs, Arguments, and Zero-Knowledge*, +Chapter 4 ("Interactive Proofs"), July 2023. +https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf + +## 1. The protocol (Thaler §4.1) + +Given a v-variate polynomial g over a finite field F with degree at most d +in each variable, the sum-check protocol proves: + +``` +H = sum_{b in {0,1}^v} g(b_1, ..., b_v) +``` + +The protocol proceeds in v rounds. Each round j: + +1. **P -> V**: a univariate polynomial g_j(X_j) of degree <= deg_j(g), + specified by its evaluations at {0, 1, ..., deg_j(g)}. +2. **V checks**: g_{j-1}(r_{j-1}) = g_j(0) + g_j(1). + (For round 1, checks H = g_1(0) + g_1(1).) +3. **V -> P**: random r_j in F. + +After round v: V checks g_v(r_v) = g(r_1, ..., r_v) via an oracle query +(or by delegation to another protocol). + +**Proposition 4.1.** Completeness error 0, soundness error <= vd / |F|. + +### What the prover sends + +Table 4.1: total prover-to-verifier communication is +sum_{j=1}^{v} (deg_j(g) + 1) field elements. When deg_j(g) = O(1) for +all j (the common case), this is O(v) field elements. + +### What the verifier checks + +Per round: one consistency equation (g_j(0) + g_j(1) = previous claim) +and a degree bound (deg g_j <= deg_j(g)). After all rounds: one oracle +query to g at a random point. + +Total verifier time: O(v + sum_j deg_j(g)) + T, where T is the cost of +one evaluation of g at a point in F^v. + +## 2. Our three use cases are one protocol + +Thaler's protocol is parameterized by g. Our current three entry points +are three instantiations: + +| Use case | g | deg per var | Ref | +|-------------------------|-----------------------------|-------------|------------| +| multilinear_sumcheck | f_tilde (MLE of evals) | 1 | §4.1 | +| inner_product_sumcheck | f_tilde * g_tilde | 2 | §4.4 | +| coefficient_sumcheck | user-defined | d | §4.6 | + +In all three cases, the *protocol* is identical — only the prover's +round-polynomial computation changes. This motivates a single protocol +runner parameterized by a prover trait, not three separate functions. + +## 3. The prover trait + +```rust +/// Prover side of the sum-check protocol (Thaler §4.1). +/// +/// Implementors define how the round polynomial g_j is computed +/// from the prover's internal state. The protocol runner calls +/// `round()` once per round, then the caller inspects post-state. +pub trait SumcheckProver { + /// Degree of g_j in the current variable X_j. + fn degree(&self) -> usize; + + /// Compute g_j and advance state. + /// + /// Returns evaluations of g_j at {0, 1, ..., degree()}. + /// `challenge` is `None` for round 0 (no prior challenge exists); + /// `Some(r_{j-1})` for rounds j >= 1, used to fold internal state. + fn round(&mut self, challenge: Option) -> Vec; + + /// Apply the final verifier challenge. + /// Called once after the last round, before `final_value()`. + fn finalize(&mut self, last_challenge: F); + + /// After finalize(): the claimed value g(r_1, ..., r_v). + fn final_value(&self) -> F; +} +``` + +The prover is passed as `&mut P` to the protocol runner. After sumcheck +completes, the caller retains ownership and can query prover-specific +post-state (e.g., for GKR: the two claimed W values at b* and c*). + +## 4. Prover strategies and streaming + +### Two axes: space strategy and variable ordering + +The prover has two independent design choices: + +1. **Space strategy** — how much memory to budget: + - *Time*: O(2^v) space, O(2^v) total time. Holds all evaluations. + - *Blendy*: O(2^k) space, O(2^v) total time (k < v). Partitions + variables into stages of size k, recomputes per stage. + - *Space*: O(v) space, O(v · 2^v) total time. Academic only. + +2. **Variable ordering** — which variable to fold each round: + - *MSB* (half-split): fold the topmost variable. Pairs `(v[k], v[k+L/2])`. + Best for in-memory and random-access-streaming workloads. + - *LSB* (pair-split): fold the bottommost variable. Pairs `(v[2k], v[2k+1])`. + Best for sequential-streaming workloads where data arrives incrementally. + +These choices are orthogonal: you can use blendy with MSB ordering (large +witness on SSD) or blendy with LSB ordering (Jolt CPU trace). + +### Streaming taxonomy + +The choice of variable ordering depends on how data is available: + +| Scenario | Data availability | Access | Best ordering | Strategy | +|----------|-------------------|--------|---------------|----------| +| In-memory | Full table in RAM | Random | MSB | Time | +| Random-access stream | Exists on disk, too big for RAM | Seekable | MSB | Blendy | +| Sequential stream | Generated incrementally | Forward-only | LSB | Blendy | + +**Random-access streaming** (e.g., large witness mmap'd from SSD): the data +exists but doesn't fit in RAM. MSB ordering has better cache behavior because +it reads two contiguous half-table regions; the blendy working set fits in +cache while the full table is paged in as needed. + +**Sequential streaming** (e.g., Jolt CPU trace): evaluations are computed +on-the-fly and arrive in index order (0, 1, 2, ...). LSB ordering is optimal +because adjacent pairs `(f[2k], f[2k+1])` are immediately available — +the prover can begin folding before the full table exists. + +In both streaming cases, blendy is used because the full table doesn't fit +in the working set. Blendy is the *space strategy*; MSB vs LSB is the +*traversal order*. They are independent. + +### Blendy stage scheduling + +Standard blendy (CFFZ24) partitions v variables into stages of fixed size k, +recomputing the partial-sum table once per stage. The optimal k depends on +the ratio of cache sizes to element size. + +Jolt's `HalfSplitSchedule` (based on BCFFMMZ25, eprint 2025/1473) takes a +different approach: **cost-model-driven, non-uniform window sizes**. + +For a degree-d sumcheck, the cost of processing a window of w variables +starting at round i is `(d+1)^w / 2^(w+i) * T` where T is the trace +length. Setting cost ~ 1 gives optimal window size: + +``` +w(i) = round(ratio * i) where ratio = ln(2) / ln((d+1)/2) +``` + +This produces growing windows: early rounds use small windows (the +hypercube is large), later rounds use larger windows (the residual sum is +small). For degree 2, the windows grow as 1, 2, 5, 14, ... + +| Degree | Ratio | Example window sequence | +|--------|-------|------------------------| +| 2 | 1.71 | 1, 2, 5, 14, ... | +| 3 | 1.00 | 1, 1, 2, 3, 4, ... | +| 4 | 0.76 | 1, 1, 1, 2, 2, 3, ... | + +The schedule is parameterized by a `StreamingSchedule` trait, not a fixed +constant: + +```rust +pub trait StreamingSchedule { + fn num_rounds(&self) -> usize; + fn switch_over_point(&self) -> usize; + fn is_window_start(&self, round: usize) -> bool; + fn num_unbound_vars(&self, round: usize) -> usize; +} +``` + +This allows tuning per deployment target and polynomial degree. + +### Strategies table + +Per Thaler §4.4.3 and prior work (CTY11, VSBW13, CFFZ24, BCFFMMZ25): + +| Strategy | Space | Total time | Input | Ordering | Ref | +|----------|-----------|--------------|--------------|----------|------------------| +| Time | O(2^v) | O(2^v) | Vec\ | MSB | VSBW13 | +| Blendy | O(2^k) | O(2^v) | Stream\ | MSB or LSB | CFFZ24, BCFFMMZ25 | +| Space | O(v) | O(v * 2^v) | Stream\ | MSB or LSB | CTY11 | + +All implement `SumcheckProver`. The difference is internal: +how `round()` computes the polynomial from the data. + +### Construction + +```rust +// In-memory (MSB, time strategy). +impl MultilinearProver { + pub fn new(evals: Vec) -> Self; +} + +// Streaming (LSB or MSB, blendy strategy). +impl StreamingMultilinearProver { + /// Random-access stream, MSB ordering. Best for mmap'd data. + pub fn new_msb>(stream: S, k: usize) -> Self; + + /// Sequential stream, LSB ordering. Best for incremental data (Jolt). + pub fn new_lsb>(stream: S, k: usize) -> Self; +} +``` + +The `Stream` trait (already in `src/streams/`) provides random access to +evaluations without requiring the full table in memory. + +## 5. Three polynomial shapes + +Each polynomial shape (multilinear, product, general) has its own prover +type for each strategy, but all implement `SumcheckProver`: + +```rust +/// g = f_tilde, degree 1. Prover folds evals via Lemma 4.3. +pub struct MultilinearProver { evals: Vec } + +/// g = f_tilde * g_tilde, degree 2. Prover folds both vectors. +pub struct InnerProductProver { a: Vec, b: Vec } + +/// g = user-defined, degree d. Wraps a RoundPolyEvaluator. +pub struct CoefficientProver> { ... } +``` + +### The time prover's round() (Lemma 4.3 and 4.5) + +For a multilinear f over {0,1}^v with evaluations in array A: + +``` +A[x'] = r_1 * A[1, x'] + (1 - r_1) * A[0, x'] +``` + +This is the **fold** operation (Lemma 4.3, equation 4.13). After folding, +the array has half the entries, and the prover can compute the next round +polynomial from the folded array. + +For a product of k multilinears (Lemma 4.5), the same fold is applied to +each factor independently, and the round polynomial is computed from the +folded factors. + +## 6. The fold primitive + +The fold operation (Thaler Lemma 4.3, equation 4.13) is the computational +core of the time prover: + +```rust +/// Half-split (MSB) fold: new[k] = v[k] + weight * (v[k + L/2] - v[k]). +/// +/// Implicit zero padding for non-power-of-two inputs. +/// SIMD-accelerated for Goldilocks on NEON and AVX-512 IFMA. +pub fn fold(values: &mut Vec, weight: F); +``` + +The fold is exposed as a standalone public function because callers +(e.g., WHIR's `multilinear_fold`) need it independently of the full +sumcheck protocol. + +### Layout: half-split (MSB) + +We fold the *top-most* variable each round: `v[0..L/2]` vs `v[L/2..L]`. +This is the MSB (most-significant-bit-first) convention. It matches +Thaler's equation 4.13 directly: + +``` +p(x_1, ..., x_l) = x_1 * p(1, x_2, ..., x_l) + (1 - x_1) * p(0, x_2, ..., x_l) +``` + +where `p(1, ...)` is the upper half and `p(0, ...)` is the lower half. + +### SIMD opt-in for non-arkworks types (`SimdRepr`) + +Non-arkworks Goldilocks implementations opt into SIMD via the `SimdRepr` +trait, whose memory layout guarantee is enforced at compile time by +`zerocopy`: + +```rust +pub trait SimdRepr: + SumcheckField + zerocopy::IntoBytes + zerocopy::FromBytes + zerocopy::Immutable +{ + fn modulus() -> u64; +} +``` + +The zerocopy bounds verify at derive time that the type supports safe byte +reinterpretation. No `unsafe` needed from the implementor. A wrong +`modulus()` produces wrong arithmetic (logic bug), not UB. + +```rust +#[derive(zerocopy::IntoBytes, zerocopy::FromBytes, zerocopy::Immutable)] +#[repr(transparent)] +struct JoltGoldilocks(u64); + +impl SimdRepr for JoltGoldilocks { + fn modulus() -> u64 { GOLDILOCKS_P } +} +``` + +Arkworks types bypass `SimdRepr` — the blanket `SumcheckField` impl +auto-detects Goldilocks from `BasePrimeField::MODULUS` and the detection +is const-folded by LLVM. + +## 7. The protocol runner + +```rust +/// Run the sum-check protocol for `num_rounds` rounds. +/// +/// Partial execution (`num_rounds < v`) supports composed protocols +/// like GKR (one sumcheck per circuit layer) and WHIR (partial rounds +/// interleaved with commit/open). +/// +/// `hook` is called each round after the prover message is written +/// and before the verifier challenge is read. Pass `|_, _| {}` when +/// no hook is needed. +pub fn sumcheck( + prover: &mut P, + num_rounds: usize, + transcript: &mut T, + hook: H, +) -> SumcheckProof +where + F: Field, + T: Transcript, + H: FnMut(usize, &mut T), + P: SumcheckProver; +``` + +### The proof transcript + +```rust +pub struct SumcheckProof { + /// g_j evaluations per round: round_polys[j] has degree + 1 entries. + pub round_polynomials: Vec>, + /// Verifier challenges r_1, ..., r_v. + pub challenges: Vec, + /// g(r_1, ..., r_v) -- the prover's claimed final evaluation. + pub final_value: F, +} +``` + +This matches the protocol description exactly. The verifier can +reconstruct any round's consistency check from `round_polynomials` +and `challenges`. + +## 8. The verifier + +One function that checks any sumcheck proof, regardless of the +polynomial's degree: + +```rust +/// Verify a sum-check proof against a claimed sum. +/// +/// Checks per round: g_j(0) + g_j(1) = previous claim, and +/// deg(g_j) <= expected_degree. +/// +/// Returns the final claimed value and the challenge vector on success. +/// The caller is responsible for the oracle check: verifying that +/// final_value = g(r_1, ..., r_v). How this is done depends on the +/// application (direct evaluation, delegation to another sumcheck, +/// polynomial commitment query, etc.). +pub fn sumcheck_verify( + claimed_sum: F, + expected_degree: usize, + num_rounds: usize, + transcript: &mut T, + hook: H, +) -> Result<(F, Vec), SumcheckError> +where + F: Field, + T: Transcript, + H: FnMut(usize, &mut T); +``` + +The verifier does NOT perform the final oracle check. Per Remark 4.2, +the verifier can apply sumcheck "even without knowing the polynomial g." +The final check is the caller's responsibility because it depends on the +application: + +- Standalone sumcheck: evaluate g(r) directly. +- GKR: delegate to the next layer's sumcheck. +- WHIR: check via polynomial commitment. +- MatMult: evaluate f_A and f_B at derived points. + +## 9. Cross-field support (BF != EF) + +In practice, evaluations often live in a small base field BF (e.g., +Goldilocks, p = 2^64 - 2^32 + 1) while challenges are sampled from a +larger extension field EF (e.g., Goldilocks^3) for soundness. + +This is a prover concern, not a protocol concern: + +- Round 0: the prover computes g_1 over BF evaluations. +- After receiving r_1 in EF: the prover performs a cross-field fold, + lifting BF data to EF. +- Rounds 1+: everything is in EF. + +The `SumcheckProver` trait is generic over a single field `F` (= EF). +The BF -> EF transition happens inside the prover's `round()` method +when `challenge` transitions from `None` to `Some(r_1)`. The protocol +runner and verifier never see BF. + +A convenience constructor handles the common case: + +```rust +impl> MultilinearProver { + /// Cross-field prover: evaluations in BF, challenges in EF. + /// Round 0 computes in BF, then lifts to EF on first challenge. + pub fn cross_field(evals: Vec) -> Self; +} +``` + +## 10. GKR compatibility (§4.6) + +GKR runs d sumcheck invocations (one per circuit layer). Each layer's +sumcheck is over a different polynomial f_r^(i) (equation 4.18): + +``` +f_r^(i)(b, c) = add_i(r_i, b, c) * (W_{i+1}(b) + W_{i+1}(c)) + + mult_i(r_i, b, c) * (W_{i+1}(b) * W_{i+1}(c)) +``` + +This is a (2k_{i+1})-variate polynomial of degree 2 in each variable. + +### What GKR needs from sumcheck + +1. **Custom polynomial**: GKR defines its own round polynomial via the + wiring predicates (add_i, mult_i). This is a custom `SumcheckProver` + implementation -- the trait handles it. + +2. **Partial execution**: each layer runs a full sumcheck (all rounds), + but the *claim chains* between layers. The `sumcheck()` function + returns, then GKR processes the result and starts a new sumcheck. + +3. **Post-state inspection**: after sumcheck, GKR needs the prover's + claimed values W_{i+1}(b*) and W_{i+1}(c*). Since `prover` is + `&mut P`, the caller retains the prover and calls GKR-specific + methods on the concrete type: + + ```rust + let proof = sumcheck(&mut gkr_prover, num_rounds, &mut t, |_, _| {}); + let (w_b, w_c) = gkr_prover.claimed_w_values(); // GKR-specific method + ``` + +4. **Reduce-to-one sub-protocol** (§4.5.2, Claim 4.6): after each + sumcheck, the verifier needs to reduce two evaluation claims to one. + This is a separate one-round protocol, NOT part of sumcheck: + + ```rust + /// Reduce claims W(b) = v_0 and W(c) = v_1 to a single claim + /// W(r) = v at a random point on the line through b and c. + pub fn reduce_to_one( + b: &[F], c: &[F], + v0: F, v1: F, + transcript: &mut T, + ) -> (Vec, F) + ``` + + This is a composable building block, not baked into sumcheck. + +## 11. WHIR compatibility + +WHIR's integration pattern: + +```rust +for round_config in &self.round_configs { + // Commit to the current folded polynomial + round_config.committer.commit(&a); + + // OOD / in-domain queries, RLC into covector b + update_covector(&mut b, &stir_challenges); + + // Partial sumcheck: fold a and b by folding_factor variables + let proof = sumcheck( + &mut InnerProductProver::new(a, b), + round_config.folding_factor, + &mut transcript, + |_, t| round_config.round_pow.prove(t), + ); + + // Extract folded state for the next round + a = prover.a(); // prover-specific accessor + b = prover.b(); +} +``` + +Key requirements satisfied: +- **Partial rounds**: `num_rounds = folding_factor < v`. +- **Hook**: proof-of-work grinding between write and read. +- **Post-state**: prover retains folded vectors after partial execution. +- **MSB fold**: `fold()` used independently for WHIR's `multilinear_fold`. + +## 12. Jolt integration + +Jolt (a16z/jolt) uses its own `SumcheckInstanceProver` trait, which +splits the round into two calls: + +```rust +// Jolt's trait (simplified) +trait SumcheckInstanceProver { + fn compute_message(&mut self, round: usize, previous_claim: F) -> UniPoly; + fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize); + fn finalize(&mut self); +} +``` + +Our `SumcheckProver::round(challenge)` merges fold and compute into one +call. An adapter in Jolt's codebase bridges the two: + +```rust +struct Adapter> { + inner: P, + pending: Option, +} + +impl SumcheckInstanceProver for Adapter

{ + fn compute_message(&mut self, _round: usize, _claim: F) -> UniPoly { + let c = self.pending.take(); + UniPoly::from_evals(&self.inner.round(c)) + } + fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + self.pending = Some(r_j.into()); + } + fn finalize(&mut self) { + if let Some(c) = self.pending.take() { + self.inner.finalize(c); + } + } +} +``` + +### Key compatibility points + +- **Variable ordering**: Jolt defaults to LSB (`BindingOrder::LowToHigh`). + Use `MultilinearProverLSB` for this path. Spartan outer sumcheck uses + MSB (`HighToLow`) — use `MultilinearProver`. + +- **Challenge type**: Jolt uses a narrow 128-bit `F::Challenge` type for + performance. The adapter converts via `Into` at the boundary. + +- **Return type**: Jolt expects `UniPoly` (coefficients). Our trait + returns evaluations at {0, 1, ..., d}. Convert via interpolation or + have the prover return coefficients directly. + +- **Batching**: Jolt's `BatchedSumcheck::prove` combines multiple instances + with random linear combinations. The orchestrator handles batching — + each instance just implements `SumcheckInstanceProver`. + +- **Streaming schedule**: Jolt's `HalfSplitSchedule` (BCFFMMZ25) uses + cost-model-driven window sizes. Our blendy implementation can adopt the + same `StreamingSchedule` trait for stage sizing. + +The adapter lives in Jolt's repository (or a thin integration crate), not +in this library. Our responsibility is keeping the trait surface clean +enough that the adapter is trivial. + +## 13. Advanced optimizations (Bagad-Dao-Domb-Thaler, ePrint 2025/1117) + +The `SumcheckProver` trait cleanly separates *what the protocol sends* +(round polynomials, final evaluation) from *how the prover computes them*. +All known prover-side optimizations live below this boundary and are +compatible without any trait or protocol changes. + +### Algorithms 1--2: LinearTime and SqrtSpace + +These are our "time" and "space" strategies (§4 above). LinearTime +maintains d arrays of size 2^(v−j), folding each round in O(2^v) total. +SqrtSpace streams over the input for the first v/2 rounds, using O(√2^v) +space. Both produce identical round polynomials. Each is a distinct +`SumcheckProver` implementation with different constructors +(`new` vs `from_stream`). + +### Algorithms 3--4: SmallValue optimization + +When sumcheck operates on a product g = f·eq(r, ·), many multiplications +involve "small" operands (those representable with few bits, e.g., R1CS +entries in {0, 1, −1}). The SmallValue optimization categorizes +multiplications as: + +- **ss** (small × small): result fits in a machine word, no field reduction. +- **sl** (small × large): a few shifts and adds, no full mul. +- **ll** (large × large): full field multiplication. + +Algorithms 3--4 defer ll multiplications by precomputing accumulators +and use Toom-Cook to reduce the number of ss multiplications from +3(d+1) to 2.5(d+1) per hypercube point. + +**Trait compatibility**: this is entirely internal to `round()`. The prover +maintains richer internal state (categorized accumulator tables) and uses +cheaper multiplication routines, but returns the same `Vec` of degree+1 +evaluations. A `SmallValueProver` implements `SumcheckProver` with +unchanged `degree()` and `final_value()`. + +### Algorithms 5--6: EqPoly optimization + +For sumcheck over g(x) = eq(r, x) · h(x), the naive approach materializes +a 2^v-sized table for eq(r, ·). The EqPoly optimization splits +eq(r, x) = eq(r_L, x_L) · eq(r_R, x_R) and only maintains tables of size +2^(v/2), reducing ll multiplications from 2^v to 2^(v/2+1). + +Algorithm 6 combines EqPoly with SmallValue for the Spartan-in-Jolt +use case. + +**Trait compatibility**: this changes how the prover manages its internal +tables — specifically, it defers half the eq computation and processes it +in a blocked fashion. The round polynomial's degree and protocol messages +are unchanged. This would be a constructor variant or flag on the +prover type. + +### Univariate skip (Gruen's optimization) + +When g has degree > 1 in its first variable, the prover can compute +a univariate restriction t_j(X_j) = g(r_1, ..., r_{j-1}, X_j, x_{j+1}*, ...) +and derive the standard round polynomial s_j from it. Per Setty and +Thaler (Section 3.1 of the paper): "the sum-check verifier is unchanged." + +This is relevant primarily in small-characteristic fields where the degree +of individual variables can be high. In our setting (Goldilocks, large +characteristic), the multilinear case (degree 1) doesn't benefit, but the +coefficient sumcheck (arbitrary degree d) could. + +**Trait compatibility**: the prover computes the round polynomial via the +univariate-skip shortcut inside `round()`, but returns the same evaluations. +No change to `degree()`, no change to the verifier, no change to the +wire format. + +### Per-round degree variation + +Some optimizations work best when the polynomial's degree varies by round +(e.g., degree d in round 1, degree 1 in rounds 2--v for univariate skip). +Our current trait returns a single `degree()` value. If a future prover +needs per-round variation, the trait can evolve to: + +```rust +fn degree(&self) -> usize; // max degree (for allocation) +fn round_degree(&self) -> usize; // degree of the current round +``` + +None of the paper's algorithms strictly require this — they maintain +constant degree throughout — but this is a natural extension point if +needed. + +### Summary: all optimizations are prover-internal + +| Optimization | Changes protocol? | Changes wire format? | Fits `SumcheckProver`? | +|---------------------------|-------------------|----------------------|------------------------| +| LinearTime (Alg 1) | No | No | Yes — "time" strategy | +| SqrtSpace (Alg 2) | No | No | Yes — "space" strategy | +| SmallValue (Alg 3--4) | No | No | Yes — smarter `round()` | +| EqPoly (Alg 5--6) | No | No | Yes — internal table mgmt | +| Univariate skip | No | No | Yes — alternate `round()` | + +The trait is the abstraction boundary: everything above it (protocol runner, +transcript, verifier) stays fixed; everything below it (prover strategy, +SIMD kernels, table layout, multiplication tricks) is implementation +freedom. + +## 14. What is NOT in scope + +- **Generic IP trait**: sumcheck is a specific protocol, not an instance + of a generic framework. The `Transcript` trait already captures the + interaction pattern. If GKR or FRI reveal common structure, a + shared trait can emerge later. + +- **Zero-knowledge**: ZK sumcheck adds masking polynomials but uses the + same protocol structure. A future `ZkSumcheckProver` wrapper could + add masking without changing the protocol runner. + +- **Batching**: multiple sumcheck instances can be batched with random + linear combinations. This is a composition technique, not a protocol + change. + +- **Reduce-to-one**: a separate composable sub-protocol (§4.5.2), + not part of sumcheck itself. + +## 15. Migration from current API + +| Current | New | +|-------------------------------------------|------------------------------------------------| +| `multilinear_sumcheck(evals, t, hook)` | `sumcheck(&mut MultilinearProver::new(evals), n, t, hook)` | +| `inner_product_sumcheck(a, b, t, hook)` | `sumcheck(&mut InnerProductProver::new(a, b), n, t, hook)` | +| `multilinear_sumcheck_partial(..., k, h)` | `sumcheck(&mut prover, k, t, hook)` | +| `fold(values, weight)` | `fold(values, weight)` (unchanged) | +| `coefficient_sumcheck(...)` | `sumcheck(&mut CoefficientProver::new(...), n, t, hook)` | +| `Sumcheck` | `SumcheckProof` | +| `ProductSumcheck` | `SumcheckProof` (unified) | +| `multilinear_sumcheck_verify(...)` | `sumcheck_verify(sum, deg, n, t, hook)` | + +The `Sumcheck` and `ProductSumcheck` return types unify into one +`SumcheckProof`. The `final_evaluations: (F, F)` field from +`ProductSumcheck` becomes a prover-specific accessor on +`InnerProductProver` post-state. + +## 16. Benchmarking + +### Three benchmark layers + +**Layer 1: Kernel throughput (elements/second).** Measures the computational +core — fold and round polynomial evaluation — independent of the protocol. + +- `fold` throughput: elements/sec for base field and each extension degree, + with and without SIMD. Compare against theoretical memory bandwidth + (Apple M-series: ~100 GB/s, Sapphire Rapids: ~50 GB/s per channel). + A Goldilocks element is 8 bytes, so the ceiling for base-field fold on + M2 is ~12.5 billion elements/sec. +- `round()` throughput: fold + evaluate combined. Shows the overhead of + evaluation on top of fold. +- Cross-field promotion cost: measure the first EF round (BF→EF lift) + separately. For ext3, memory triples and arithmetic cost jumps ~9×. + +**Layer 2: Protocol-level scaling (time vs num_variables).** Full +`sumcheck()` execution — all v rounds — for each combination of: + +- Strategy: time, blendy(k) for several k values +- Shape: multilinear (d=1), inner-product (d=2) +- Field: base only, base→ext3 +- Size: v = 16, 18, 20, 22, 24 (65K to 16M evaluations) + +Plot time vs 2^v on log-log axes. Time strategy should be linear (slope 1). +Blendy should show a constant-factor gap vs time — the benchmark should +quantify this gap at each k to guide users on the space/time tradeoff. + +The key metric is **time per element per round**, not total time. This +normalizes across sizes and makes regressions immediately visible. + +**Layer 3: Downstream integration.** Run WHIR's sumcheck bench with our +library as the backend. The metric is **regression detection**: pin a +baseline, alert on >5% regression. + +### CI regression tracking + +Continuous benchmarks via `github-action-benchmark`: + +- **criterion** bench harness with throughput annotations (elements/sec) +- Self-hosted runner on EC2 Sapphire Rapids for stable AVX-512 numbers +- Results committed to `gh-pages` branch with per-benchmark trendlines +- Alert threshold: 10% regression = warning, 15% = CI failure + +Benchmark matrix: + +``` +{time} × {multilinear, IP} × {F64, F64Ext3} × {2^16, 2^20, 2^24} +fold × {F64, F64Ext3} × {2^16, 2^20, 2^24} +``` + +~20 benchmark points, ~5 minutes on AVX-512 hardware. + +## 17. Summary + +The design follows Thaler's formalization exactly: + +- **One protocol** (Proposition 4.1), parameterized by the polynomial. +- **One trait** (`SumcheckProver`) for the prover's round computation. +- **Three strategies** (time/space/blendy) as construction choices. +- **Three polynomial shapes** (multilinear/product/general) as prover types. +- **One fold** (Lemma 4.3) as the core computational primitive. +- **One verifier** that checks any proof regardless of degree. +- **Partial execution** for protocol composition (GKR, WHIR). +- **Post-state inspection** via `&mut P` ownership for protocol chaining. + +The protocol runner, verifier, and fold are the public API. The prover +trait is the extension point. SIMD acceleration is transparent inside +fold. Everything else is internal. diff --git a/docs/migration.md b/docs/migration.md new file mode 100644 index 00000000..967ee4d9 --- /dev/null +++ b/docs/migration.md @@ -0,0 +1,91 @@ +# Migrating from arkworks-rs/sumcheck + +If you're using [arkworks-rs/sumcheck](https://github.com/arkworks-rs/sumcheck), +here's how to accomplish the same things with effsc. + +## If you're using `ListOfProductsOfPolynomials` + +In arkworks, you describe a sumcheck claim as a sum of products of multilinear +extensions and let the library evaluate it generically. In effsc, you implement +`RoundPolyEvaluator` — writing the polynomial logic directly. The library +handles folding, parallelism, and SIMD. + +For example, given $H = \sum_{x} [ f(x) \cdot g(x) + h(x) \cdot k(x) ]$: + +```rust +use effsc::coefficient_sumcheck::RoundPolyEvaluator; + +struct MyEvaluator; + +impl RoundPolyEvaluator for MyEvaluator { + fn degree(&self) -> usize { 2 } + + fn accumulate_pair( + &self, + coeffs: &mut [F], + _tw: &[(&[F], &[F])], + pw: &[(F, F)], + ) { + let (f0, f1) = pw[0]; + let (g0, g1) = pw[1]; + let (h0, h1) = pw[2]; + let (k0, k1) = pw[3]; + + coeffs[0] += f0 * g0 + h0 * k0; + let at_1 = f1 * g1 + h1 * k1; + coeffs[1] += at_1 - coeffs[0]; + let f2 = f0 + (f1 - f0) + (f1 - f0); + let g2 = g0 + (g1 - g0) + (g1 - g0); + let h2 = h0 + (h1 - h0) + (h1 - h0); + let k2 = k0 + (k1 - k0) + (k1 - k0); + coeffs[2] += f2 * g2 + h2 * k2; + } +} +``` + +Then wire it through `CoefficientProver`: + +```rust +use effsc::provers::coefficient::CoefficientProver; + +let mut prover = CoefficientProver::new( + &MyEvaluator, + &[], + &mut [f_evals, g_evals, h_evals, k_evals], +); +``` + +The evaluator reads as a direct translation of the protocol math, and because +you write the polynomial logic yourself, you can exploit structure like shared +subexpressions or known-constant factors. + +## If you're using `GKRRoundSumcheck` + +In arkworks, `GKRRoundSumcheck` wraps `MLSumcheck` for the GKR round +polynomial. In effsc, use `GkrProver`: + +```rust +use effsc::provers::gkr::GkrProver; + +// add_evals, mult_evals: gate predicates over {0,1}^{2k}, +// partially evaluated at the previous layer's random point. +// w_evals: witness W_{i+1} over {0,1}^k. +let mut prover = GkrProver::new(add_evals, mult_evals, w_evals); +let proof = sumcheck( + &mut prover, + 2 * k, + &mut transcript, + noop_hook, +); + +// Extract claimed W values for the reduce-to-one sub-protocol. +let (w_b_star, w_c_star) = prover.claimed_w_values(); +``` + +A full GKR verification loops one sumcheck per circuit layer, feeding each +layer's `(w_b_star, w_c_star)` into the reduce-to-one sub-protocol and then +into the next layer's sumcheck. + +The current `GkrProver` has the same O(2^{2k}) complexity per layer as +arkworks-rs/sumcheck. An optimized O(2^k · k) version using incremental +eq-polynomial bookkeeping is possible but not yet implemented. diff --git a/docs/slides.md b/docs/slides.md new file mode 100644 index 00000000..7a1bcfee --- /dev/null +++ b/docs/slides.md @@ -0,0 +1,633 @@ +--- +marp: true +theme: default +paginate: true +size: 16:9 +style: | + section { + font-size: 24px; + } + h1 { + font-size: 40px; + color: #1a1a2e; + } + h2 { + font-size: 32px; + color: #16213e; + } + h3 { + font-size: 26px; + } + table { + font-size: 20px; + } + code { + font-size: 18px; + } + pre { + font-size: 16px; + } +--- + +# Efficient Sumcheck + +A sumcheck library built for Arkworks grounded in textbook formalization. + +--- + +## The Problem + +The sumcheck protocol is **one protocol** parameterized by the polynomial shape. +The old library treated each shape as a separate implementation: + +- 3 protocol runners, 2 verifiers, 2 return types +- 10+ public entry points with inconsistent signatures +- Adding a new polynomial shape meant duplicating the entire stack + +This made integration hard (WHIR needed hooks, GKR needed partial execution, +Jolt needed a different variable ordering) and every new feature touched +every runner. + +**10,800 LOC. The code was correct. The architecture was not.** + +--- + +## The Goal + +Reduce complexity without losing functionality. + +| Metric | Before | After | +|--------|--------|-------| +| Lines of code | 10,800 | 7,448 | +| Source files | ~100 | 43 | +| Public entry points | 10+ | 4 | +| Order strategies | 4 | 1 (MSB) | +| Protocol runners | 3 | 1 | +| Verifiers | 2 | 1 | +| Return types | 2 | 1 | +| Tests | 69 | 63 | +| Clippy warnings | 46 | 0 | + +--- + +## The Authoritative Source + +Justin Thaler, *Proofs, Arguments, and Zero-Knowledge*, Chapter 4. + +**Proposition 4.1.** Given a v-variate polynomial g over F with degree +at most d in each variable, the sum-check protocol proves +`H = sum_{b in {0,1}^v} g(b)` in v rounds. + +- Completeness error: 0 +- Soundness error: <= v * d / |F| + +**Key insight:** the protocol is *one* protocol parameterized by g. +Three "different" sumchecks are three instantiations. + +--- + +## One Protocol, Three Instantiations + +| Use case | g | Degree | Reference | +|----------|---|--------|-----------| +| Multilinear | f_tilde (MLE) | 1 | Thaler S4.1 | +| Inner product | f_tilde * g_tilde | 2 | Thaler S4.4 | +| Coefficient | user-defined | d | Thaler S4.6 | + +The protocol (transcript, consistency checks, challenges) is identical. + +Only the prover's round polynomial computation changes. + +This motivates **one runner + one trait**, not three functions. + +--- + +## The Prover Trait + +```rust +pub trait SumcheckProver { + fn degree(&self) -> usize; + fn round(&mut self, challenge: Option) -> Vec; + fn finalize(&mut self, last_challenge: F); + fn final_value(&self) -> F; +} +``` + +**Lifecycle:** + +``` +round(None) -> g_0 evaluations // round 0 +round(Some(r_0)) -> g_1 evaluations // fold with r_0, compute g_1 +... +round(Some(r_{v-2})) -> g_{v-1} evaluations // fold, compute +finalize(r_{v-1}) // apply last challenge +final_value() -> g(r_0, ..., r_{v-1}) // oracle value +``` + +--- + +## The Protocol Runner + +```rust +pub fn sumcheck>( + prover: &mut impl SumcheckProver, + num_rounds: usize, + transcript: &mut T, + hook: impl FnMut(usize, &mut T), +) -> SumcheckProof +``` + +One function handles: +- Full sumcheck (num_rounds = v) +- Partial sumcheck (num_rounds < v) for GKR and WHIR +- Per-round hooks for proof-of-work grinding +- Any polynomial degree (degree is prover-reported) + +**The runner never inspects prover internals.** + +--- + +## The Verifier + +```rust +pub fn sumcheck_verify>( + claimed_sum: F, + expected_degree: usize, + num_rounds: usize, + transcript: &mut T, + hook: impl FnMut(usize, &mut T) -> Result<(), SumcheckError>, +) -> Result, SumcheckError> +``` + +- Checks g_j(0) + g_j(1) = claim each round +- Evaluates g_j(r_j) via Lagrange interpolation (any degree) +- Returns `SumcheckResult { challenges, final_claim }` + +The verifier doesn't know g ([Thaler Remark 4.2](https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf)), so the oracle check +is the caller's responsibility. The API returns `final_claim` directly — +the caller handles it according to their protocol: + +```rust +assert_eq!(result.final_claim, proof.final_value) // standalone +next_claim = result.final_claim; // WHIR, GKR (next layer) +``` + +--- + +## Unified Proof Type + +```rust +pub struct SumcheckProof { + pub round_polys: Vec>, // g_j at {0,1,...,d} + pub challenges: Vec, // r_1, ..., r_v + pub final_value: F, // g(r_1, ..., r_v) +} +``` + +Replaces both `Sumcheck` and `ProductSumcheck`. + +Prover-specific post-state (e.g., `(f(r), g(r))` for inner product) +lives on the prover via `&mut P` ownership, not in the proof. + +--- + +## Concrete Provers + +| Prover | Degree | Post-state | +|--------|--------|------------| +| `MultilinearProver` | 1 | — | +| `InnerProductProver` | 2 | `final_evaluations() -> (F, F)` | +| `CoefficientProver` | d | — | +| `GkrProver` | 2 | `claimed_w_values() -> (F, F)` | + +Each has MSB and LSB variants (except GkrProver: MSB only). + +Same runner. Same verifier. Same proof type. + +--- + +## The Fold Primitive (Lemma 4.3) + +``` +new[k] = v[k] + weight * (v[k + L/2] - v[k]) +``` + +- Half-split (MSB) layout: fold the topmost variable each round +- Matches Thaler eq. 4.13 directly +- Non-power-of-two: implicit zero padding on the high half +- SIMD-accelerated for Goldilocks (transparent, zero overhead otherwise) +- Exposed publicly for WHIR's `multilinear_fold` + +--- + +## SIMD Acceleration + +Goldilocks field (p = 2^64 - 2^32 + 1): + +| Backend | Width | Platform | +|---------|-------|----------| +| NEON | 2-wide | aarch64 (Apple M-series, Graviton) | +| AVX-512 IFMA | 8-wide | x86_64 (Sapphire Rapids) | + +Two paths to SIMD: + +**Arkworks** (automatic): blanket impl detects Goldilocks from modulus. +LLVM const-folds the branch. Zero overhead on non-Goldilocks. + +**Non-arkworks** (explicit): implement `SimdRepr` with `zerocopy` bounds: + +```rust +pub trait SimdRepr: + SumcheckField + zerocopy::IntoBytes + zerocopy::FromBytes +{ + fn modulus() -> u64; // GOLDILOCKS_P for SIMD +} +``` + +Layout safety is **compiler-verified** via zerocopy derives. No `unsafe`. + +--- + +## Generic Field Trait + +```rust +pub trait SumcheckField: + Copy + Send + Sync + PartialEq + Debug + + Add + Sub + Mul + Neg + + AddAssign + SubAssign + MulAssign + + Sum + 'static +{ + const ZERO: Self; + const ONE: Self; + fn from_u64(val: u64) -> Self; + fn inverse(&self) -> Option; + fn extension_degree() -> u64; + fn _simd_field_config() -> Option; +} +``` + +- **Not coupled to arkworks.** Any field-like type works. +- Blanket impl for `ark_ff::Field` behind `feature = "arkworks"` (default-on). +- Non-arkworks Goldilocks can opt into SIMD by overriding `_simd_field_config()`. + +--- + +## Cross-Field Support (BF -> EF) *(TODO)* + +Evaluations in base field BF (e.g., Goldilocks). +Challenges from extension field EF (e.g., Goldilocks^3) for soundness. + +```rust +pub trait ExtensionOf: SumcheckField + From {} +``` + +The transition is prover-internal: +- Round 0: compute over BF +- `round(Some(r_1))`: lift BF -> EF, continue in EF +- Protocol runner and verifier never see BF + +--- + +## WHIR Integration ([PR](https://github.com/WizardOfMenlo/whir/pull/250)) + +```rust +for round_config in &self.round_configs { + round_config.committer.commit(&a); + update_covector(&mut b, &stir_challenges); + + let proof = sumcheck( + &mut InnerProductProver::new(a, b), + round_config.folding_factor, // partial! + &mut transcript, + |_, t| round_config.round_pow.prove(t), // hook! + ); + + a = prover.a(); // post-state access + b = prover.b(); +} +``` + +Three features exercised: partial execution, per-round hook, post-state. + +--- + +## GKR Integration + +`GkrProver` implements `SumcheckProver` for the GKR round polynomial: + +```text +f_r(b, c) = add_i(r, b, c) · (W(b) + W(c)) + mult_i(r, b, c) · (W(b) · W(c)) +``` + +```rust +let mut prover = GkrProver::new(add_evals, mult_evals, w_evals); +let proof = sumcheck(&mut prover, 2 * k, &mut t, noop_hook); +let (w_b, w_c) = prover.claimed_w_values(); // for reduce-to-one +``` + +Same runner, same verifier, same proof type. GKR is just another prover. + +Reduce-to-one is a separate composable sub-protocol (Thaler §4.5.2). + +--- + +## Two Orthogonal Axes + +Prover design has two independent choices: + +**Space strategy** -- how much memory to budget: +- Time: O(2^v) -- hold all evaluations +- Blendy: O(2^k) -- partition into stages, recompute per stage +- Space: O(v) -- recompute everything (academic only) + +**Variable ordering** -- which variable to fold each round: +- MSB (half-split): pairs `(v[k], v[k+L/2])` -- in-memory and seekable streams +- LSB (pair-split): pairs `(v[2k], v[2k+1])` -- sequential/incremental streams + +These are orthogonal. Blendy + MSB and Blendy + LSB are both valid. + +--- + +## Streaming Taxonomy + +| Scenario | Data | Access | Ordering | Example | +|----------|------|--------|----------|---------| +| In-memory | Full table in RAM | Random | MSB | WHIR | +| Random-access stream | On disk, too big for RAM | Seekable | MSB | Large witness (mmap) | +| Sequential stream | Generated incrementally | Forward-only | LSB | Jolt CPU trace | + +**Random-access** (mmap'd SSD): data exists but doesn't fit in RAM. +MSB reads two contiguous half-table regions -- good cache behavior. + +**Sequential** (Jolt trace): evaluations arrive in index order. +LSB pairs `(f[2k], f[2k+1])` are immediately available -- +folding begins before the full table exists. + +Both streaming cases use blendy. The ordering choice depends on the data source. + +--- + +## Blendy Stage Scheduling (BCFFMMZ25) *(TODO)* + +Jolt's `HalfSplitSchedule` uses **cost-model-driven, non-uniform windows**: + +``` +w(i) = round(ratio * i) where ratio = ln(2) / ln((d+1)/2) +``` + +Windows grow with round number: early rounds (large hypercube) get small +windows; later rounds (small residual) get large windows. + +| Degree | Ratio | Window sequence | +|--------|-------|-----------------| +| 2 | 1.71 | 1, 2, 5, 14, ... | +| 3 | 1.00 | 1, 1, 2, 3, 4, ... | +| 4 | 0.76 | 1, 1, 1, 2, 2, 3, ... | + +**Two-phase structure:** +1. Streaming phase (first half): cost-optimal windows, one trace pass per window +2. Linear phase (second half): materialized mode, every round is its own window + +Parameterized by `StreamingSchedule` trait -- not a fixed constant. + +Based on BCFFMMZ25 (eprint 2025/1473): O(kN) time, O(N^{1/k}) space. + +--- + +## Jolt Compatibility *(description of possible integration)* + +Jolt's `SumcheckInstanceProver` trait: + +```rust +fn compute_message(&mut self, round: usize, claim: F) -> UniPoly; +fn ingest_challenge(&mut self, r_j: F::Challenge, round: usize); +fn finalize(&mut self); +``` + +Our `SumcheckProver` maps cleanly via adapter: + +```rust +fn compute_message(&mut self, round: usize, _claim: F) -> UniPoly { + let challenge = self.pending.take().map(Into::into); + UniPoly::from_evals(&self.inner.round(challenge)) +} +fn ingest_challenge(&mut self, r_j: F::Challenge, _round: usize) { + self.pending = Some(r_j); +} +``` + +LSB multilinear prover + this adapter = drop-in replacement for Jolt. + +--- + +## What We Deleted + +| Module | LOC | Why | +|--------|-----|-----| +| multilinear/provers/ | 1,190 | Old Prover trait scaffolding | +| multilinear_product/provers/ | 1,175 | Same | +| order_strategy/ | 315 | 4 strategies -> MSB only | +| SIMD dispatch (product paths) | ~500 | Superseded by fused MSB kernels | +| messages/, interpolation/ | 327 | Unused / graycode-coupled | +| Old test harness | ~800 | Legacy prover tests | +| prover/core.rs | 24 | Old trait definition | + +**Total: 4,518 lines deleted across 86 files.** + +--- + +## Advanced Optimizations Fit the Trait + +From Bagad-Dao-Domb-Thaler, [*Sumcheck Optimizations*](https://eprint.iacr.org/2025/1117) (ePrint 2025/1117): + +| Optimization | Changes protocol? | Changes wire format? | +|--------------|-------------------|---------------------| +| LinearTime (Alg 1) | No | No | +| SqrtSpace (Alg 2) | No | No | +| SmallValue (Alg 3-4) | No | No | +| EqPoly (Alg 5-6) | No | No | +| Univariate skip | No | No | + +**All optimizations live below the trait boundary.** + +The protocol runner, verifier, and transcript are untouched. +Each optimization is a different `SumcheckProver` implementation. + +--- + +## Feature Gate Architecture + +```toml +[features] +default = ["arkworks", "parallel", "simd"] + +arkworks = ["ark-ff", "ark-poly", "ark-serialize", + "ark-std", "spongefish"] +parallel = ["rayon", "ark-ff?/parallel", ...] +simd = [] +``` + +- `--no-default-features`: pure `SumcheckField` library, no arkworks, no SIMD +- `--features arkworks`: blanket impl for `ark_ff::Field` +- `--features parallel`: rayon parallelism for fold and round computation +- `--features simd`: SIMD backends (NEON, AVX-512 IFMA) for Goldilocks +- SIMD dispatch is const-folded by LLVM -- zero overhead when field doesn't match + +--- + +## Benchmarking Strategy — TODO + +### Layer 1: Kernel Throughput +- fold elements/sec vs memory bandwidth ceiling +- round() overhead on top of fold +- BF->EF promotion cliff + +### Layer 2: Protocol Scaling +- Full sumcheck: time vs 2^v on log-log axes +- Matrix: {time, blendy} x {ml, IP} x {F64, F64Ext3} x {2^16..2^24} +- Key metric: **time per element per round** + +### Layer 3: Downstream Integration +- WHIR sumcheck bench as acceptance gate +- Regression detection, not absolute performance + +--- + +## CI Benchmark Infrastructure — TODO + +```yaml +# .github/workflows/bench.yml +- name: Detect CPU features + run: | + if grep -q avx512ifma /proc/cpuinfo; then + echo "rustflags=-C target-feature=+avx512ifma" + fi + +- name: Run benchmarks + run: cargo bench --bench sumcheck -- --output-format bencher + +- uses: benchmark-action/github-action-benchmark@v1 + with: + alert-threshold: "115%" + auto-push: true # on main +``` + +- Auto-detects AVX-512 IFMA / NEON / scalar +- Results to gh-pages with trendlines +- 15% regression = alert + PR comment + +--- + +## Current Benchmark Matrix + +18 benchmark points across 6 groups: + +``` +multilinear/F64: 2^16, 2^20, 2^24 +multilinear/F64Ext3: 2^16, 2^20, 2^24 +inner_product/F64: 2^16, 2^20, 2^24 +inner_product/F64Ext3: 2^16, 2^20, 2^24 +fold/F64: 2^16, 2^20, 2^24 +fold/F64Ext3: 2^16, 2^20, 2^24 +``` + +Criterion harness with `Throughput::Elements(n)` annotations. +~5 minutes on AVX-512 hardware. + +--- + +## Migration Table (internal) + +| Old effsc | New effsc | +|-----|-----| +| `multilinear_sumcheck(evals, t, hook)` | `sumcheck(&mut MultilinearProver::new(evals), v, t, hook)` | +| `inner_product_sumcheck(a, b, t, hook)` | `sumcheck(&mut InnerProductProver::new(a, b), v, t, hook)` | +| `multilinear_sumcheck_partial(...)` | `sumcheck(&mut prover, k, t, hook)` | +| `fold(values, weight)` | `fold(values, weight)` (unchanged) | +| `Sumcheck` | `SumcheckProof` | +| `ProductSumcheck` | `SumcheckProof` (unified) | + +--- + +## Migration from arkworks-rs/sumcheck + +| arkworks-rs/sumcheck | effsc | +|-----|-----| +| `ListOfProductsOfPolynomials` | Custom `RoundPolyEvaluator` + `CoefficientProver` | +| `GKRRoundSumcheck::prove(...)` | `sumcheck(&mut GkrProver::new(add, mult, w), ...)` | +| `GKRRoundSumcheck::verify(...)` | `sumcheck_verify(sum, 2, rounds, t, hook)` | + +See [`docs/migration.md`](migration.md) for worked examples. + +--- + +## What's NOT in Scope + +Explicit non-goals (from Thaler's framing): + +- **Generic IP trait** -- sumcheck is specific, not an instance of a framework +- **Zero-knowledge** -- future `ZkSumcheckProver` wrapper (masking polynomials) +- **Batching** -- compose via random linear combinations externally +- **Reduce-to-one** -- separate sub-protocol (S4.5.2), not part of sumcheck + +These can be added later without changing the core trait or runner. + +--- + +## Design Principles + +1. **One protocol, one trait, many implementations.** + Three polynomial shapes are three `SumcheckProver` impls, not three runners. + +2. **The trait boundary is the optimization boundary.** + Everything above it (runner, verifier, transcript) is fixed. + Everything below it (fold, SIMD, table layout, multiplication tricks) is freedom. + +3. **Partial execution is first-class.** + `num_rounds < v` enables GKR and WHIR without special-casing. + +4. **Post-state via ownership, not return types.** + `&mut P` survives sumcheck; prover-specific accessors are type-safe. + +5. **The verifier returns the final claim, not a verdict.** + `sumcheck_verify` returns `SumcheckResult { challenges, final_claim }`. + The oracle check is the caller's concern — standalone callers compare, + composed callers (WHIR, GKR) pass `final_claim` to the next layer. + +6. **Features are orthogonal layers.** + `arkworks`, `parallel`, and `simd` can be enabled independently. + The core library works with any `SumcheckField` and zero dependencies. + +--- + +## Next Steps + +| Item | Status | Notes | +|------|--------|-------| +| CoefficientProver | Done | MSB + LSB variants | +| GkrProver | Done | Reference impl, O(2^{2k}) | +| WHIR integration | Done | [PR #250](https://github.com/WizardOfMenlo/whir/pull/250) | +| SECURITY.md | Done | Threat model, unsafe scope, disclosure policy | +| ark-sumcheck migration guide | Done | docs/migration.md | +| GkrProver O(2^k · k) optimization | Future | Incremental eq-polynomial bookkeeping | +| Blendy prover | Deferred | Pending LSB vs MSB investigation (Jolt) | +| Jolt adapter | Future | Drop-in `SumcheckInstanceProver` impl | +| StreamingSchedule trait | Investigating | Cost-model windows (BCFFMMZ25) | +| Additional field support | Future | M31, BabyBear, KoalaBear | + +--- + +## Summary + +**The sum-check protocol is one protocol.** + +We made the code match that fact. + +- 1 runner, 1 verifier, 1 proof type, 1 fold +- 7 concrete provers (multilinear, inner-product, coefficient, GKR × MSB/LSB) +- Generic over any field (not arkworks-specific) +- SIMD transparent for Goldilocks (AVX-512 IFMA, NEON) +- Integrated into [WHIR](https://github.com/WizardOfMenlo/whir/pull/250) and [WARP](https://github.com/compsec-epfl/warp/pull/24) with measured performance improvements +- Superset of arkworks-rs/sumcheck functionality (see docs/migration.md) +- Correctness-fuzzed against a formally verified oracle +- All known optimizations fit below the trait boundary diff --git a/docs/slides.pdf b/docs/slides.pdf new file mode 100644 index 00000000..d1d0e43c Binary files /dev/null and b/docs/slides.pdf differ diff --git a/docs/slides.pptx b/docs/slides.pptx new file mode 100644 index 00000000..5e01c88e Binary files /dev/null and b/docs/slides.pptx differ diff --git a/src/coefficient_sumcheck.rs b/src/coefficient_sumcheck.rs index 056a7a85..d7115391 100644 --- a/src/coefficient_sumcheck.rs +++ b/src/coefficient_sumcheck.rs @@ -1,8 +1,11 @@ use ark_ff::Field; -use ark_poly::{univariate::DensePolynomial, Polynomial}; +use ark_poly::univariate::DensePolynomial; -use crate::multilinear::reductions::{pairwise, tablewise}; -use crate::transcript::Transcript; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +use crate::reductions::{pairwise, tablewise}; +use crate::transcript::ProverTranscript; #[derive(Debug)] pub struct CoefficientSumcheck { @@ -10,37 +13,341 @@ pub struct CoefficientSumcheck { pub verifier_messages: Vec, } +/// Trait for computing the round polynomial from a single pair of rows. +/// +/// The library iterates over pairs (even/odd rows from each table), +/// calls [`accumulate_pair`](RoundPolyEvaluator::accumulate_pair) for each, +/// which adds the pair's contribution directly into a shared coefficient buffer. +/// This avoids per-pair polynomial allocation — the library owns the buffer. +/// +/// # Arguments to `accumulate_pair` +/// +/// - `coeffs`: mutable slice of length [`degree`](RoundPolyEvaluator::degree)`+ 1`. +/// The evaluator **adds** its contribution into these coefficients (do NOT zero them). +/// - `tablewise_pairs`: one `(even_row, odd_row)` slice-pair per tablewise table +/// - `pairwise_pairs`: one `(even_elem, odd_elem)` pair per pairwise table +/// +/// # Example +/// +/// ```text +/// struct MyEvaluator; +/// impl RoundPolyEvaluator for MyEvaluator { +/// fn degree(&self) -> usize { 1 } +/// +/// fn accumulate_pair( +/// &self, +/// coeffs: &mut [F], +/// tw: &[(&[F], &[F])], +/// pw: &[(F, F)], +/// ) { +/// let (even, odd) = pw[0]; +/// coeffs[0] += even; // constant coefficient +/// coeffs[1] += odd - even; // linear coefficient +/// } +/// } +/// ``` +pub trait RoundPolyEvaluator: Sync { + /// The degree of the round polynomial (number of coefficients = degree + 1). + fn degree(&self) -> usize; + + /// Accumulate this pair's contribution into `coeffs[0..=degree]`. + /// + /// `coeffs` is pre-zeroed at the start of each round. The evaluator + /// should **add** (not assign) its contribution. + fn accumulate_pair( + &self, + coeffs: &mut [F], + tablewise_pairs: &[(&[F], &[F])], + pairwise_pairs: &[(F, F)], + ); + + /// Hint: is the per-pair work heavy enough to benefit from rayon parallelism? + /// + /// Return `true` for evaluators that do substantial work per pair (polynomial + /// multiplication, R1CS evaluation, etc.). Return `false` for trivial + /// evaluators (simple sums, single multiply) where rayon overhead dominates. + /// + /// Default: `true` (assume heavy — safe default since rayon's overhead is + /// small relative to the work for most real use cases). + fn parallelize(&self) -> bool { + true + } +} + +// ── Evaluate strategies ───────────────────────────────────────────────────── + +/// SIMD fast path for degree-1 with a single pairwise table. +/// +/// Returns `[sum_even, sum_odd - sum_even]` = coefficients of `h(x) = c0 + c1*x`. +fn simd_evaluate_degree1(pw: &[F]) -> Vec { + // Try SIMD dispatch for Goldilocks + #[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + ))] + { + if let Some(coeffs) = try_simd_evaluate_degree1(pw) { + return coeffs; + } + } + + // Generic fallback + let mut s0 = F::ZERO; + let mut s1 = F::ZERO; + for chunk in pw.chunks_exact(2) { + s0 += chunk[0]; + s1 += chunk[1]; + } + vec![s0, s1 - s0] +} + +/// SIMD implementation of degree-1 evaluate. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +fn try_simd_evaluate_degree1(pw: &[F]) -> Option> { + crate::simd_sumcheck::dispatch::try_simd_evaluate_degree1(pw) +} + +/// Fused SIMD reduce + degree-1 evaluate for next round. +/// +/// Returns `Some([s0, s1 - s0])` if SIMD dispatch succeeded (reduces in-place +/// and computes next round's coefficients). Returns `None` to fall back to +/// separate reduce + evaluate. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +fn try_simd_fused_reduce_evaluate(pw: &mut Vec, challenge: F) -> Option> { + crate::simd_sumcheck::dispatch::try_simd_fused_reduce_evaluate_degree1(pw, challenge) +} + +#[cfg(not(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +)))] +fn try_simd_fused_reduce_evaluate(_pw: &mut Vec, _challenge: F) -> Option> { + None +} + +/// Parallel evaluate using rayon (for heavy evaluators). +#[cfg(feature = "parallel")] +fn parallel_evaluate( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + n_pairs: usize, + n_coeffs: usize, +) -> Vec { + let accumulate_at = |coeffs: &mut [F], pair_idx: usize| { + let mut tw_buf: [(&[F], &[F]); 16] = [(&[], &[]); 16]; + let mut pw_buf: [(F, F); 16] = [(F::ZERO, F::ZERO); 16]; + debug_assert!(n_tw <= 16 && n_pw <= 16); + for (i, table) in tablewise.iter().enumerate() { + tw_buf[i] = (&table[2 * pair_idx], &table[2 * pair_idx + 1]); + } + for (i, table) in pairwise.iter().enumerate() { + pw_buf[i] = (table[2 * pair_idx], table[2 * pair_idx + 1]); + } + evaluator.accumulate_pair(coeffs, &tw_buf[..n_tw], &pw_buf[..n_pw]); + }; + + (0..n_pairs) + .into_par_iter() + .fold_with(vec![F::ZERO; n_coeffs], |mut acc, pair_idx| { + accumulate_at(&mut acc, pair_idx); + acc + }) + .reduce_with(|mut a, b| { + for (ai, bi) in a.iter_mut().zip(&b) { + *ai += *bi; + } + a + }) + .unwrap_or_else(|| vec![F::ZERO; n_coeffs]) +} + +/// Fallback when parallel feature is disabled. +#[cfg(not(feature = "parallel"))] +fn parallel_evaluate( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + n_pairs: usize, + n_coeffs: usize, +) -> Vec { + let mut coeffs = vec![F::ZERO; n_coeffs]; + sequential_evaluate_into( + evaluator, + tablewise, + pairwise, + n_tw, + n_pw, + n_pairs, + &mut coeffs, + ); + coeffs +} + +/// Sequential evaluate (for trivial evaluators where rayon overhead dominates). +/// +/// Fills `coeffs_out` with accumulated coefficients (zeroes it first). +fn sequential_evaluate_into( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + n_pairs: usize, + coeffs_out: &mut [F], +) { + for c in coeffs_out.iter_mut() { + *c = F::ZERO; + } + let mut tw_buf: [(&[F], &[F]); 16] = [(&[], &[]); 16]; + let mut pw_buf: [(F, F); 16] = [(F::ZERO, F::ZERO); 16]; + debug_assert!(n_tw <= 16 && n_pw <= 16); + + for pair_idx in 0..n_pairs { + for (i, table) in tablewise.iter().enumerate() { + tw_buf[i] = (&table[2 * pair_idx], &table[2 * pair_idx + 1]); + } + for (i, table) in pairwise.iter().enumerate() { + pw_buf[i] = (table[2 * pair_idx], table[2 * pair_idx + 1]); + } + evaluator.accumulate_pair(coeffs_out, &tw_buf[..n_tw], &pw_buf[..n_pw]); + } +} + /// Sumcheck prover for arbitrary-degree round polynomials in coefficient form. /// -/// Each round: `compute_round_poly` produces the round polynomial → coefficients -/// are sent to the transcript → challenge is received → all tables are reduced. +/// The user provides a [`RoundPolyEvaluator`] that computes the round polynomial +/// contribution for a single pair. The library handles: +/// - Parallel iteration over pairs (via rayon when `parallel` is enabled) +/// - Summation of per-pair polynomials +/// - Transcript interaction (d-coefficient optimization: leading coefficient omitted) +/// - SIMD-accelerated pairwise reduce (auto-dispatched for Goldilocks) +/// - Tablewise reduce pub fn coefficient_sumcheck( - mut compute_round_poly: impl FnMut(&[Vec>], &[Vec]) -> DensePolynomial, + evaluator: &impl RoundPolyEvaluator, tablewise: &mut [Vec>], pairwise: &mut [Vec], n_rounds: usize, - transcript: &mut impl Transcript, + transcript: &mut impl ProverTranscript, ) -> CoefficientSumcheck { let mut prover_messages = Vec::with_capacity(n_rounds); let mut verifier_messages = Vec::with_capacity(n_rounds); - for _ in 0..n_rounds { - let round_poly = compute_round_poly(tablewise, pairwise); - - for coeff in &round_poly.coeffs { - transcript.write(*coeff); + let n_tw = tablewise.len(); + let n_pw = pairwise.len(); + let deg = evaluator.degree(); + let n_coeffs = deg + 1; + + let use_parallel = evaluator.parallelize(); + let is_degree1_simd_path = deg == 1 && n_pw == 1 && n_tw == 0; + + let mut pending_degree1_eval: Option> = None; + + // Pre-allocate coefficient buffer — reused across rounds for sequential path. + let mut coeffs_buf = vec![F::ZERO; n_coeffs]; + + for round in 0..n_rounds { + let n_pairs = if n_tw > 0 { + tablewise[0].len() / 2 + } else if n_pw > 0 { + pairwise[0].len() / 2 + } else { + 0 + }; + + // ── Evaluate: build round polynomial coefficients ── + // + // Three strategies in order of preference: + // 1. SIMD fast path: degree-1, single pairwise table, no tablewise → + // use evaluate_parallel or fused reduce+evaluate + // 2. Parallel: heavy evaluator → rayon fold_with across pairs + // 3. Sequential: trivial evaluator → simple loop, no rayon overhead + let coeffs = if let Some(cached) = pending_degree1_eval.take() { + cached + } else if is_degree1_simd_path { + simd_evaluate_degree1::(&pairwise[0]) + } else if use_parallel { + parallel_evaluate( + evaluator, tablewise, pairwise, n_tw, n_pw, n_pairs, n_coeffs, + ) + } else { + // Fill pre-allocated buffer (no allocation), then clone the + // small coefficient vec (d+1 elements, typically 2-3). + sequential_evaluate_into( + evaluator, + tablewise, + pairwise, + n_tw, + n_pw, + n_pairs, + &mut coeffs_buf, + ); + coeffs_buf.clone() + }; + + let round_poly = DensePolynomial { coeffs }; + + // Send only the first d coefficients (omit the leading one). + let d = round_poly.coeffs.len().saturating_sub(1); + for coeff in &round_poly.coeffs[..d] { + transcript.send(*coeff); } prover_messages.push(round_poly); - let c = transcript.read(); + let c = transcript.challenge(); verifier_messages.push(c); + // ── Reduce ── for table in tablewise.iter_mut() { tablewise::reduce_evaluations(table, c); } - for table in pairwise.iter_mut() { - pairwise::reduce_evaluations(table, c); + + if is_degree1_simd_path && round < n_rounds - 1 { + // Fused reduce+evaluate: SIMD reduce in-place and compute + // next round's (s0, s1) in one pass when possible. + if let Some(next_coeffs) = try_simd_fused_reduce_evaluate(&mut pairwise[0], c) { + pending_degree1_eval = Some(next_coeffs); + } else { + // Fallback: separate reduce + pairwise::reduce_evaluations(&mut pairwise[0], c); + } + } else { + for table in pairwise.iter_mut() { + #[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + ))] + if crate::simd_sumcheck::dispatch::try_simd_reduce(table, c) { + continue; + } + pairwise::reduce_evaluations(table, c); + } } } @@ -50,48 +357,75 @@ pub fn coefficient_sumcheck( } } -/// Sumcheck verifier for arbitrary-degree round polynomials in coefficient form. -/// -/// Each round: absorb coefficients → check `h(0) + h(1) == claim` -/// → squeeze challenge → update `claim = h(challenge)`. -pub fn sumcheck_verify( - claim: &mut F, - prover_messages: &[DensePolynomial], - transcript: &mut impl Transcript, -) -> Option> { - let mut challenges = Vec::with_capacity(prover_messages.len()); - - for h in prover_messages { - for coeff in &h.coeffs { - transcript.write(*coeff); - } - - if h.evaluate(&F::zero()) + h.evaluate(&F::one()) != *claim { - return None; - } - - let c = transcript.read(); - *claim = h.evaluate(&c); - challenges.push(c); - } - - Some(challenges) -} - #[cfg(test)] mod tests { use super::*; use ark_ff::UniformRand; - use ark_poly::DenseUVPolynomial; + use ark_poly::Polynomial; use ark_std::test_rng; - use crate::multilinear::reductions::pairwise; use crate::tests::F64; use crate::transcript::SanityTranscript; + // ── Reusable evaluators for tests ─────────────────────────────────── + + /// Degree-1 evaluator: h(x) = even + (odd - even) * x per pair. + struct Degree1Evaluator; + impl RoundPolyEvaluator for Degree1Evaluator { + fn degree(&self) -> usize { + 1 + } + fn accumulate_pair(&self, coeffs: &mut [F64], _tw: &[(&[F64], &[F64])], pw: &[(F64, F64)]) { + let (even, odd) = pw[0]; + coeffs[0] += even; + coeffs[1] += odd - even; + } + } + + /// Degree-2 evaluator: interpolate through (0, s0), (1, s1), (2, s0+s1). + struct Degree2Evaluator; + impl RoundPolyEvaluator for Degree2Evaluator { + fn degree(&self) -> usize { + 2 + } + fn accumulate_pair(&self, coeffs: &mut [F64], _tw: &[(&[F64], &[F64])], pw: &[(F64, F64)]) { + let (s0, s1) = pw[0]; + let s2 = s0 + s1; + coeffs[0] += s0; + coeffs[1] += (-F64::from(3u64) * s0 + F64::from(4u64) * s1 - s2) / F64::from(2u64); + coeffs[2] += (s0 - F64::from(2u64) * s1 + s2) / F64::from(2u64); + } + } + + /// Mixed evaluator: tablewise column 0 + pairwise even (degree 0). + struct MixedEvaluator; + impl RoundPolyEvaluator for MixedEvaluator { + fn degree(&self) -> usize { + 0 + } + fn accumulate_pair(&self, coeffs: &mut [F64], tw: &[(&[F64], &[F64])], pw: &[(F64, F64)]) { + coeffs[0] += tw[0].0[0] + pw[0].0; + } + } + + /// Inner product evaluator: per-pair product from two pairwise tables. + struct InnerProductEvaluator; + impl RoundPolyEvaluator for InnerProductEvaluator { + fn degree(&self) -> usize { + 1 + } + fn accumulate_pair(&self, coeffs: &mut [F64], _tw: &[(&[F64], &[F64])], pw: &[(F64, F64)]) { + let (a_even, a_odd) = pw[0]; + let (b_even, b_odd) = pw[1]; + coeffs[0] += a_even * b_even; + coeffs[1] += a_odd * b_odd - a_even * b_even; + } + } + + // ── Tests ─────────────────────────────────────────────────────────── + #[test] fn test_sumcheck_relation_holds_each_round() { - // verify h(0) + h(1) == claimed sum at each round let mut rng = test_rng(); let n = 1 << 4; let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); @@ -102,11 +436,7 @@ mod tests { let mut transcript = SanityTranscript::new(&mut rng); let result = coefficient_sumcheck( - |_tablewise, pairwise| { - let s0: F64 = pairwise[0].iter().step_by(2).copied().sum(); - let s1: F64 = pairwise[0].iter().skip(1).step_by(2).copied().sum(); - DensePolynomial::from_coefficients_vec(vec![s0, s1 - s0]) - }, + &Degree1Evaluator, &mut tablewise, &mut pairwise, 4, @@ -132,52 +462,6 @@ mod tests { } } - #[test] - fn test_parity_with_multilinear_sumcheck() { - // separate rng for evals so transcript rngs start at the same state - use crate::multilinear_sumcheck; - - let mut eval_rng = test_rng(); - let n = 1 << 4; - let evals: Vec = (0..n).map(|_| F64::rand(&mut eval_rng)).collect(); - let evals_clone = evals.clone(); - - // run multilinear_sumcheck - let mut rng1 = test_rng(); - let mut ml_evals = evals; - let mut ml_transcript = SanityTranscript::new(&mut rng1); - let ml_result = multilinear_sumcheck::(&mut ml_evals, &mut ml_transcript); - - // run coefficient_sumcheck with degree-1 compute_h - let mut rng2 = test_rng(); - let mut pairwise = vec![evals_clone]; - let mut tablewise: Vec>> = vec![]; - let mut coeff_transcript = SanityTranscript::new(&mut rng2); - let coeff_result = coefficient_sumcheck( - |_tablewise, pairwise| { - let (s0, s1) = pairwise::evaluate(&pairwise[0]); - DensePolynomial::from_coefficients_vec(vec![s0, s1 - s0]) - }, - &mut tablewise, - &mut pairwise, - 4, - &mut coeff_transcript, - ); - - // challenges must match - assert_eq!(ml_result.verifier_messages, coeff_result.verifier_messages); - - // round polynomials must be equivalent: (s0, s1) ↔ [s0, s1-s0] - for (ml_msg, coeff_msg) in ml_result - .prover_messages - .iter() - .zip(coeff_result.prover_messages.iter()) - { - assert_eq!(coeff_msg.evaluate(&F64::from(0u64)), ml_msg.0); - assert_eq!(coeff_msg.evaluate(&F64::from(1u64)), ml_msg.1); - } - } - #[test] fn test_spongefish_transcript() { use crate::transcript::SpongefishTranscript; @@ -187,7 +471,8 @@ mod tests { let num_rounds = 3; let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); - let domsep = spongefish::domain_separator!("test-coefficient-sumcheck"; module_path!()) + let domsep = spongefish::domain_separator!("test-coefficient-sumcheck") + .without_session() .instance(b"test"); let prover_state = domsep.std_prover(); @@ -197,11 +482,7 @@ mod tests { let mut tablewise: Vec>> = vec![]; let result = coefficient_sumcheck( - |_tablewise, pairwise| { - let s0: F64 = pairwise[0].iter().step_by(2).copied().sum(); - let s1: F64 = pairwise[0].iter().skip(1).step_by(2).copied().sum(); - DensePolynomial::from_coefficients_vec(vec![s0, s1 - s0]) - }, + &Degree1Evaluator, &mut tablewise, &mut pairwise, num_rounds, @@ -227,12 +508,7 @@ mod tests { let mut transcript = SanityTranscript::new(&mut rng); let result = coefficient_sumcheck( - |tablewise, pairwise| { - // combine both: sum of tablewise column 0 + pairwise even elements - let ts: F64 = tablewise[0].iter().map(|row| row[0]).sum(); - let ps: F64 = pairwise[0].iter().step_by(2).copied().sum(); - DensePolynomial::from_coefficients_vec(vec![ts + ps]) - }, + &MixedEvaluator, &mut tablewise, &mut pairwise, 3, @@ -240,15 +516,12 @@ mod tests { ); assert_eq!(result.prover_messages.len(), 3); - // both should be reduced to single entries assert_eq!(tablewise[0].len(), 1); assert_eq!(pairwise[0].len(), 1); } #[test] fn test_higher_degree_round_polys() { - // degree-2 round poly: h(0) = s0, h(1) = s1, h(2) = s0 + s1 - // verify the sumcheck relation holds at each round let mut rng = test_rng(); let n = 1 << 3; let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); @@ -259,31 +532,19 @@ mod tests { let mut transcript = SanityTranscript::new(&mut rng); let result = coefficient_sumcheck( - |_tablewise, pairwise| { - let s0: F64 = pairwise[0].iter().step_by(2).copied().sum(); - let s1: F64 = pairwise[0].iter().skip(1).step_by(2).copied().sum(); - // degree-2: interpolate through (0, s0), (1, s1), (2, s0+s1) - // h(0)+h(1) = s0+s1 still holds, so sumcheck relation is satisfied - let s2 = s0 + s1; - let c0 = s0; - let c1 = (-F64::from(3u64) * s0 + F64::from(4u64) * s1 - s2) / F64::from(2u64); - let c2 = (s0 - F64::from(2u64) * s1 + s2) / F64::from(2u64); - DensePolynomial::from_coefficients_vec(vec![c0, c1, c2]) - }, + &Degree2Evaluator, &mut tablewise, &mut pairwise, 3, &mut transcript, ); - // verify round 0: h(0) + h(1) == claimed sum let h0 = &result.prover_messages[0]; assert_eq!( h0.evaluate(&F64::from(0u64)) + h0.evaluate(&F64::from(1u64)), claimed_sum ); - // all round polys should be degree 2 for h in &result.prover_messages { assert_eq!(h.coeffs.len(), 3); } @@ -300,11 +561,7 @@ mod tests { let mut transcript = SanityTranscript::new(&mut rng); let result = coefficient_sumcheck( - |_tablewise, pairwise| { - let s0 = pairwise[0][0]; - let s1 = pairwise[0][1]; - DensePolynomial::from_coefficients_vec(vec![s0, s1 - s0]) - }, + &Degree1Evaluator, &mut tablewise, &mut pairwise, 1, @@ -324,7 +581,6 @@ mod tests { #[test] fn test_multiple_pairwise_tables() { - // two independent pairwise tables, both reduced let mut rng = test_rng(); let n = 1 << 3; let evals_a: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); @@ -335,23 +591,7 @@ mod tests { let mut transcript = SanityTranscript::new(&mut rng); let result = coefficient_sumcheck( - |_tablewise, pairwise| { - // inner product contribution from both tables - let s0: F64 = pairwise[0] - .iter() - .zip(pairwise[1].iter()) - .step_by(2) - .map(|(a, b)| *a * b) - .sum(); - let s1: F64 = pairwise[0] - .iter() - .zip(pairwise[1].iter()) - .skip(1) - .step_by(2) - .map(|(a, b)| *a * b) - .sum(); - DensePolynomial::from_coefficients_vec(vec![s0, s1 - s0]) - }, + &InnerProductEvaluator, &mut tablewise, &mut pairwise, 3, diff --git a/src/field.rs b/src/field.rs new file mode 100644 index 00000000..b464a155 --- /dev/null +++ b/src/field.rs @@ -0,0 +1,385 @@ +//! Generic field trait for sumcheck. +//! +//! [`SumcheckField`] captures the minimum arithmetic interface needed by the +//! sumcheck protocol. Any type with field-like operations (add, sub, mul, +//! negate, invert) and two distinguished constants (zero, one) can implement +//! this trait and use the full sumcheck library. +//! +//! When the `arkworks` feature is enabled (default), a blanket implementation +//! is provided for all types implementing [`ark_ff::Field`], so existing +//! arkworks users change nothing. + +use core::fmt::Debug; +use core::iter::Sum; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// Minimum field interface for the sumcheck protocol. +/// +/// Implementors must provide: +/// - Standard arithmetic via [`Add`], [`Sub`], [`Mul`], [`Neg`] and their +/// assign variants. +/// - Additive and multiplicative identities ([`ZERO`](Self::ZERO), +/// [`ONE`](Self::ONE)). +/// - Conversion from small integers ([`from_u64`](Self::from_u64)). +/// - Multiplicative inverse ([`inverse`](Self::inverse)). +/// +/// The SIMD acceleration layer for Goldilocks (p = 2^64 − 2^32 + 1) is +/// opt-in via [`_simd_field_config`](Self::_simd_field_config). Non-Goldilocks +/// fields leave the default (returns `None`) and the library transparently +/// falls back to scalar code. +pub trait SumcheckField: + Sized + + Copy + + Send + + Sync + + PartialEq + + Debug + + Add + + Sub + + Mul + + Neg + + AddAssign + + SubAssign + + MulAssign + + Sum + + 'static +{ + /// Additive identity. + const ZERO: Self; + + /// Multiplicative identity. + const ONE: Self; + + /// Convert a small integer to a field element. + fn from_u64(val: u64) -> Self; + + /// Multiplicative inverse, or `None` for zero. + fn inverse(&self) -> Option; + + /// Returns `true` if this element is zero. + #[inline] + fn is_zero(&self) -> bool { + *self == Self::ZERO + } + + /// Double this element (additive). + #[inline] + fn double(&self) -> Self { + *self + *self + } + + /// Extension degree over the base prime field. + /// + /// For a prime field, return 1. For a quadratic extension, return 2, etc. + /// Used by the SIMD layer to select the correct kernel width. + fn extension_degree() -> u64 { + 1 + } + + /// SIMD configuration (internal dispatch hook). + /// + /// For **arkworks types**: overridden by the blanket impl to auto-detect + /// Goldilocks from the modulus. Const-folded by LLVM. + /// + /// For **non-arkworks types**: override this to return the config from + /// your [`SimdRepr`] implementation. Or simply implement `SimdRepr` and + /// override this method to call [`simd_config_from_repr`]: + /// + /// ```ignore + /// fn _simd_field_config() -> Option { + /// Some(simd_config_from_repr::()) + /// } + /// ``` + #[doc(hidden)] + #[inline(always)] + fn _simd_field_config() -> Option { + None + } + + /// Reinterpret a field element as its raw Montgomery-form `u64`. + /// + /// Only meaningful for Goldilocks-based fields where each base-field + /// component is a single `u64` in Montgomery form. The default panics; + /// overridden by the arkworks blanket impl and by `SimdRepr` types. + #[doc(hidden)] + #[inline(always)] + fn _to_raw_u64(self) -> u64 { + unimplemented!("_to_raw_u64 is only available for SIMD-compatible fields") + } + + /// Reconstruct a field element from a raw Montgomery-form `u64`. + #[doc(hidden)] + #[inline(always)] + fn _from_raw_u64(raw: u64) -> Self { + let _ = raw; + unimplemented!("_from_raw_u64 is only available for SIMD-compatible fields") + } + + /// Reinterpret a slice of field elements as a flat `u64` slice. + /// + /// For base fields (degree 1), the output has the same length as input. + /// For extension fields (degree d), the output has `input.len() * d` elements. + #[doc(hidden)] + #[inline(always)] + fn _as_u64_slice(slice: &[Self]) -> &[u64] { + let _ = slice; + unimplemented!("_as_u64_slice is only available for SIMD-compatible fields") + } + + /// Reinterpret a mutable slice of field elements as a mutable flat `u64` slice. + #[doc(hidden)] + #[inline(always)] + fn _as_u64_slice_mut(slice: &mut [Self]) -> &mut [u64] { + let _ = slice; + unimplemented!("_as_u64_slice_mut is only available for SIMD-compatible fields") + } + + /// Reconstruct a field element from its raw Montgomery-form `u64` components. + /// + /// For base fields, `comps` has length 1. For degree-d extensions, length d. + #[doc(hidden)] + #[inline(always)] + fn _from_u64_components(comps: &[u64]) -> Self { + let _ = comps; + unimplemented!("_from_u64_components is only available for SIMD-compatible fields") + } +} + +// ─── SIMD memory layout contract ──────────────────────────────────────────── + +/// Goldilocks modulus: p = 2^64 - 2^32 + 1. +pub const GOLDILOCKS_P: u64 = 0xFFFF_FFFF_0000_0001; + +/// Opt-in trait for SIMD acceleration. +/// +/// Implementing this trait declares that the field type's in-memory +/// representation is compatible with the SIMD kernels: each element is +/// `extension_degree()` consecutive `u64` values in Montgomery form. +/// +/// # Layout safety +/// +/// The memory layout guarantee is enforced by the **zerocopy** bounds +/// (`IntoBytes + FromBytes + Immutable`). These are verified at compile +/// time by zerocopy's derive macros — no `unsafe` needed from the +/// implementor. If your type's layout doesn't support safe byte +/// reinterpretation, the derive will fail to compile. +/// +/// The only thing the implementor declares is the **modulus value**. +/// A wrong modulus produces wrong arithmetic results (logic bug), not +/// undefined behavior. +/// +/// # Example: non-arkworks Goldilocks +/// +/// ```ignore +/// #[derive(Clone, Copy, Debug, PartialEq, +/// zerocopy::IntoBytes, zerocopy::FromBytes, zerocopy::Immutable)] +/// #[repr(transparent)] +/// struct MyGoldilocks(u64); +/// +/// impl SumcheckField for MyGoldilocks { +/// // ... arithmetic ... +/// fn _simd_field_config() -> Option { +/// Some(SimdFieldConfig { modulus: GOLDILOCKS_P, element_bytes: 8 }) +/// } +/// } +/// +/// impl SimdRepr for MyGoldilocks { +/// fn modulus() -> u64 { GOLDILOCKS_P } +/// } +/// ``` +/// +/// # Example: Goldilocks cubic extension +/// +/// ```ignore +/// #[derive(Clone, Copy, Debug, PartialEq, +/// zerocopy::IntoBytes, zerocopy::FromBytes, zerocopy::Immutable)] +/// #[repr(transparent)] +/// struct MyExt3([u64; 3]); +/// +/// impl SumcheckField for MyExt3 { +/// fn extension_degree() -> u64 { 3 } +/// fn _simd_field_config() -> Option { +/// Some(SimdFieldConfig { modulus: GOLDILOCKS_P, element_bytes: 8 }) +/// } +/// // ... +/// } +/// +/// impl SimdRepr for MyExt3 { +/// fn modulus() -> u64 { GOLDILOCKS_P } +/// } +/// ``` +pub trait SimdRepr: + SumcheckField + zerocopy::IntoBytes + zerocopy::FromBytes + zerocopy::Immutable +{ + /// The base prime field modulus as a single `u64` limb. + /// + /// For the Goldilocks SIMD kernels to fire, this must equal + /// [`GOLDILOCKS_P`] (`0xFFFF_FFFF_0000_0001`). + fn modulus() -> u64; +} + +/// SIMD field configuration. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SimdFieldConfig { + /// The base prime field modulus as a single `u64` limb. + pub modulus: u64, + /// Size of one base prime field element in bytes. + pub element_bytes: usize, +} + +/// Query SIMD configuration for a field type. +/// +/// Returns `Some(config)` if the field's memory layout is SIMD-compatible. +/// +/// Two paths to SIMD: +/// +/// 1. **Arkworks types** (automatic): the `arkworks` feature detects +/// Goldilocks from the modulus at compile time. Zero overhead. +/// +/// 2. **Non-arkworks types** (explicit): implement [`SimdRepr`] (requires +/// `zerocopy::IntoBytes + FromBytes + Immutable` — compiler-verified +/// layout) and override `_simd_field_config()`. +#[inline(always)] +pub fn simd_config() -> Option { + F::_simd_field_config() +} + +/// Marker trait for cross-field (base field → extension field) sumcheck. +/// +/// The prover starts with evaluations in `BF` and folds into `EF` once +/// the first challenge arrives. +pub trait ExtensionOf: SumcheckField + From {} + +// ─── Arkworks blanket implementation ──────────────────────────────────────── + +#[cfg(feature = "arkworks")] +mod ark_impl { + use super::*; + + impl SumcheckField for F + where + F: ark_ff::Field, + { + const ZERO: Self = ::ZERO; + const ONE: Self = ::ONE; + + #[inline] + fn from_u64(val: u64) -> Self { + Self::from(val) + } + + #[inline] + fn inverse(&self) -> Option { + ark_ff::Field::inverse(self) + } + + #[inline] + fn is_zero(&self) -> bool { + *self == Self::ZERO + } + + #[inline] + fn double(&self) -> Self { + ark_ff::AdditiveGroup::double(self) + } + + #[inline] + fn extension_degree() -> u64 { + ::extension_degree() + } + + #[inline(always)] + fn _simd_field_config() -> Option { + use ark_ff::PrimeField; + + // Check if the base prime field is a single u64 (64-bit modulus). + if F::BasePrimeField::MODULUS_BIT_SIZE != 64 { + return None; + } + let d = ::extension_degree() as usize; + if core::mem::size_of::() != d * 8 { + return None; + } + let modulus = F::BasePrimeField::MODULUS; + let limbs: &[u64] = modulus.as_ref(); + if limbs[1..].iter().any(|&x| x != 0) { + return None; + } + Some(SimdFieldConfig { + modulus: limbs[0], + element_bytes: 8, + }) + } + + #[doc(hidden)] + #[inline(always)] + fn _to_raw_u64(self) -> u64 { + debug_assert_eq!(core::mem::size_of::(), 8); + // SAFETY: The caller has verified via `_simd_field_config()` that + // Self is a Goldilocks-compatible field with `size_of::() == 8` + // and `element_bytes == 8`. Arkworks `SmallFp` and `Fp, 1>` + // both store exactly one u64 in Montgomery form as their sole field, + // making this transmute a no-op reinterpretation. + unsafe { core::mem::transmute_copy(&self) } + } + + #[doc(hidden)] + #[inline(always)] + fn _from_raw_u64(raw: u64) -> Self { + debug_assert_eq!(core::mem::size_of::(), 8); + // SAFETY: Same layout guarantee as `_to_raw_u64` — the raw u64 + // is a valid Montgomery-form value for this Goldilocks field. + unsafe { core::mem::transmute_copy(&raw) } + } + + #[doc(hidden)] + #[inline(always)] + fn _as_u64_slice(slice: &[Self]) -> &[u64] { + let d = ::extension_degree() as usize; + debug_assert_eq!(core::mem::size_of::(), d * 8); + let n_u64 = slice.len() * d; + // SAFETY: Each element is `d` consecutive u64 values in Montgomery + // form (verified by `_simd_field_config()`). The slice is contiguous + // in memory, so the reinterpretation as `&[u64]` of length + // `slice.len() * d` is valid. Lifetime and alignment are preserved + // because u64 alignment (8) divides the field element alignment. + unsafe { core::slice::from_raw_parts(slice.as_ptr() as *const u64, n_u64) } + } + + #[doc(hidden)] + #[inline(always)] + fn _as_u64_slice_mut(slice: &mut [Self]) -> &mut [u64] { + let d = ::extension_degree() as usize; + debug_assert_eq!(core::mem::size_of::(), d * 8); + let n_u64 = slice.len() * d; + // SAFETY: Same layout guarantee as `_as_u64_slice`. Mutable access + // is exclusive because we hold `&mut [Self]`. + unsafe { core::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut u64, n_u64) } + } + + #[doc(hidden)] + #[inline(always)] + fn _from_u64_components(comps: &[u64]) -> Self { + let d = ::extension_degree() as usize; + assert_eq!(comps.len(), d); + assert_eq!(core::mem::size_of::(), d * 8); + // SAFETY: `comps` contains exactly `d` valid Montgomery-form u64 + // values and `size_of::() == d * 8` (verified by + // `_simd_field_config()`). The copy reconstructs the in-memory + // layout of the field element. + unsafe { + let mut out = core::mem::MaybeUninit::::uninit(); + core::ptr::copy_nonoverlapping(comps.as_ptr(), out.as_mut_ptr() as *mut u64, d); + out.assume_init() + } + } + } + + // Blanket ExtensionOf for arkworks extension fields. + impl ExtensionOf for EF + where + BF: ark_ff::Field, + EF: ark_ff::Field + From, + { + } +} diff --git a/src/fold.rs b/src/fold.rs new file mode 100644 index 00000000..ef9051af --- /dev/null +++ b/src/fold.rs @@ -0,0 +1,14 @@ +//! The fold primitive (Thaler Lemma 4.3, equation 4.13). +//! +//! Half-split (MSB) fold: `new[k] = v[k] + weight · (v[k + L/2] − v[k])`. +//! +//! SIMD-accelerated for Goldilocks on NEON (aarch64) and AVX-512 IFMA +//! (x86_64). Falls back to scalar code for other fields. The detection +//! is compile-time constant-folded — zero overhead on the scalar path. +//! +//! This is exposed as a standalone public function because callers +//! (e.g., WHIR's `multilinear_fold`) need it independently of the full +//! sumcheck protocol. + +#[cfg(feature = "arkworks")] +pub use crate::multilinear_sumcheck::fold; diff --git a/src/folding.rs b/src/folding.rs index 0d4806dc..024c2c24 100644 --- a/src/folding.rs +++ b/src/folding.rs @@ -1,42 +1,77 @@ pub mod protogalaxy { - use ark_ff::Field; - use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial}; - #[cfg(feature = "parallel")] - use rayon::prelude::*; + use ark_ff::{Field, Zero}; + use ark_poly::{univariate::DensePolynomial, Polynomial}; + + use crate::poly_ops; /// Fold `n` polynomials using `log_n` linear coefficient pairs `(a, b)`. /// - /// At each level: `p[0] + (a + b·X)·(p[1] - p[0])`. + /// At each level: `result[i] = p[2i] + (a + b·X)·(p[2i+1] - p[2i])`. + /// + /// Uses [`poly_ops`] for zero-allocation arithmetic on flat coefficient buffers. + /// Each polynomial at level `k` has degree ≤ initial_degree + `k`, + /// stored in fixed-width slots. pub fn fold( coeffs: impl Iterator, - mut polys: Vec>, + polys: Vec>, ) -> DensePolynomial { - for (a, b) in coeffs { - #[cfg(feature = "parallel")] - { - polys = polys - .par_chunks(2) - .map(|p| { - &p[0] - + DensePolynomial::from_coefficients_vec(vec![a, b]) - .naive_mul(&(&p[1] - &p[0])) - }) - .collect(); - } - #[cfg(not(feature = "parallel"))] - { - polys = polys - .chunks(2) - .map(|p| { - &p[0] - + DensePolynomial::from_coefficients_vec(vec![a, b]) - .naive_mul(&(&p[1] - &p[0])) - }) - .collect(); + let coeffs_vec: Vec<(F, F)> = coeffs.collect(); + let n_levels = coeffs_vec.len(); + + if polys.is_empty() { + return DensePolynomial::zero(); + } + if polys.len() == 1 { + return polys.into_iter().next().unwrap(); + } + + let init_max_deg = polys.iter().map(|p| p.degree()).max().unwrap_or(0); + let final_max_deg = init_max_deg + n_levels; + let slot = final_max_deg + 1; + + // Pack into flat buffer with fixed-width slots. + let mut n_polys = polys.len(); + let mut buf = vec![F::ZERO; n_polys * slot]; + for (i, p) in polys.into_iter().enumerate() { + poly_ops::copy_into(&mut buf[i * slot..], &p.coeffs); + } + + let mut cur_deg = init_max_deg; + let mut diff = vec![F::ZERO; slot]; + + for &(a, b) in &coeffs_vec { + let half = n_polys / 2; + + for i in 0..half { + let p0_off = (2 * i) * slot; + let p1_off = (2 * i + 1) * slot; + let out_off = i * slot; + let deg = cur_deg + 1; // new degree after this level + + // diff[0..=cur_deg] = p1 - p0 + poly_ops::sub_into( + &mut diff[..=cur_deg], + &buf[p1_off..p1_off + cur_deg + 1], + &buf[p0_off..p0_off + cur_deg + 1], + ); + + // result = p0 + a·diff + b·X·diff + // Process high-to-low to allow in-place when out_off ≤ p0_off. + buf[out_off + deg] = b * diff[cur_deg]; + for j in (1..=cur_deg).rev() { + buf[out_off + j] = buf[p0_off + j] + a * diff[j] + b * diff[j - 1]; + } + buf[out_off] = buf[p0_off] + a * diff[0]; + + poly_ops::zero(&mut buf[out_off + deg + 1..out_off + slot]); } + + cur_deg += 1; + n_polys = half; } - assert_eq!(polys.len(), 1); - polys.pop().unwrap() + + debug_assert_eq!(n_polys, 1); + poly_ops::to_dense_poly(&buf[..=cur_deg.min(final_max_deg)]) } } @@ -77,4 +112,52 @@ mod tests { assert_eq!(result.coeffs.len(), 1); assert_eq!(result.coeffs[0], F64::from(4u64)); } + + #[test] + fn test_fold_matches_naive() { + // Compare optimized fold against a naive reference for random inputs. + use ark_ff::UniformRand; + use ark_std::test_rng; + + let mut rng = test_rng(); + + // 8 random degree-2 polynomials, 3 fold levels + let polys: Vec> = (0..8) + .map(|_| { + DensePolynomial::from_coefficients_vec(vec![ + F64::rand(&mut rng), + F64::rand(&mut rng), + F64::rand(&mut rng), + ]) + }) + .collect(); + + let coeffs: Vec<(F64, F64)> = (0..3) + .map(|_| (F64::rand(&mut rng), F64::rand(&mut rng))) + .collect(); + + // Naive fold (original algorithm) + let naive_result = { + let mut ps = polys.clone(); + for &(a, b) in &coeffs { + ps = ps + .chunks(2) + .map(|p| { + &p[0] + + DensePolynomial::from_coefficients_vec(vec![a, b]) + .naive_mul(&(&p[1] - &p[0])) + }) + .collect(); + } + ps.pop().unwrap() + }; + + // Optimized fold + let opt_result = fold(coeffs.into_iter(), polys); + + assert_eq!(naive_result.coeffs.len(), opt_result.coeffs.len()); + for (n, o) in naive_result.coeffs.iter().zip(opt_result.coeffs.iter()) { + assert_eq!(*n, *o, "coefficient mismatch"); + } + } } diff --git a/src/hypercube/eq_evals.rs b/src/hypercube/eq_evals.rs index 82efd34f..86cc6db4 100644 --- a/src/hypercube/eq_evals.rs +++ b/src/hypercube/eq_evals.rs @@ -1,20 +1,64 @@ use ark_ff::Field; -use crate::hypercube::Hypercube; -use crate::order_strategy::AscendingOrder; - -/// Compute eq(point, ·) over {0,1}^num_variables. +/// Compute eq(τ, ·) over {0,1}^v using the incremental build-up algorithm. /// /// `eq(x, y) = Π_j (x_j · y_j + (1 - x_j)(1 - y_j))`. +/// +/// Returns a table of length `2^v` where entry `i` is `eq(τ, y_i)` with +/// `y_i` the Boolean vector whose bit `j` is `(i >> j) & 1` (LSB-indexed). +/// +/// Complexity: O(2^v) — one multiply per entry. The algorithm doubles the +/// table one variable at a time: for each `τ_j`, existing entries are +/// split into `entry * (1 - τ_j)` (bit 0) and `entry * τ_j` (bit 1). pub fn compute_hypercube_eq_evals(num_variables: usize, point: &[F]) -> Vec { - Hypercube::::new(num_variables) - .map(|(index, _)| { - (0..num_variables).fold(F::one(), |acc, j| { - let bit = F::from((index >> j & 1) as u64); - acc * (point[j] * bit + (F::one() - point[j]) * (F::one() - bit)) - }) - }) - .collect() + let size = 1 << num_variables; + let mut table = Vec::with_capacity(size); + table.push(F::one()); + + for &tau_j in point[..num_variables].iter().rev() { + let len = table.len(); + let one_minus = F::one() - tau_j; + // Process in reverse so we can expand in place. + table.resize(2 * len, F::zero()); + for i in (0..len).rev() { + table[2 * i + 1] = table[i] * tau_j; + table[2 * i] = table[i] * one_minus; + } + } + + table +} + +/// Evaluate `eq(τ, y)` at a single Boolean point `y ∈ {0,1}^v`. +/// +/// `point` is the integer whose bit `j` is `y_j` (LSB-indexed). +/// This is O(v) — use it when you need one entry instead of the full +/// table from [`compute_hypercube_eq_evals`]. +/// +/// `eq_poly(τ, i) == compute_hypercube_eq_evals(τ.len(), τ)[i]`. +pub fn eq_poly(tau: &[F], point: usize) -> F { + let num_variables = tau.len(); + (0..num_variables).fold(F::one(), |acc, j| { + if (point >> j) & 1 == 1 { + acc * tau[j] + } else { + acc * (F::one() - tau[j]) + } + }) +} + +/// Evaluate `eq(x, y)` where both `x` and `y` are field element vectors. +/// +/// `eq(x, y) = Π_j (x_j · y_j + (1 − x_j)(1 − y_j))`. +/// +/// Unlike [`eq_poly`] which takes a Boolean point as an integer, this +/// handles non-binary evaluation points — needed for oracle checks in +/// composed protocols (WARP, GKR reduce-to-one). +pub fn eq_poly_non_binary(x: &[F], y: &[F]) -> F { + assert_eq!(x.len(), y.len()); + x.iter().zip(y).fold(F::one(), |acc, (x_i, y_i)| { + acc * (*x_i * *y_i + (F::one() - x_i) * (F::one() - y_i)) + }) } #[cfg(test)] @@ -24,11 +68,6 @@ mod tests { #[test] fn test_eq_evals_2_variables() { - // point = [2, 3], ascending bit order - // index 0 = (0,0): (-1)·(-2) = 2 - // index 1 = (1,0): (2)·(-2) = -4 - // index 2 = (0,1): (-1)·(3) = -3 - // index 3 = (1,1): (2)·(3) = 6 let point = vec![F64::from(2u64), F64::from(3u64)]; let evals = compute_hypercube_eq_evals(2, &point); @@ -41,7 +80,6 @@ mod tests { #[test] fn test_eq_evals_sum_is_one_on_binary_point() { - // eq(b, ·) is 1 at b and 0 elsewhere let point = vec![F64::from(1u64), F64::from(0u64), F64::from(1u64)]; let evals = compute_hypercube_eq_evals(3, &point); @@ -55,4 +93,35 @@ mod tests { } } } + + #[test] + fn test_eq_poly_matches_table() { + let tau = vec![F64::from(2u64), F64::from(3u64)]; + let table = compute_hypercube_eq_evals(2, &tau); + for i in 0..4 { + assert_eq!(eq_poly(&tau, i), table[i], "mismatch at point {i}"); + } + } + + #[test] + fn test_eq_poly_non_binary_matches_table_on_binary() { + // When y is binary, eq_poly_non_binary should match eq_poly. + let tau = vec![F64::from(2u64), F64::from(3u64)]; + for i in 0..4usize { + let y: Vec = (0..2).map(|j| F64::from(((i >> j) & 1) as u64)).collect(); + assert_eq!( + eq_poly_non_binary(&tau, &y), + eq_poly(&tau, i), + "mismatch at point {i}" + ); + } + } + + #[test] + fn test_eq_poly_non_binary_symmetric() { + // eq(x, y) == eq(y, x) + let x = vec![F64::from(5u64), F64::from(7u64), F64::from(3u64)]; + let y = vec![F64::from(11u64), F64::from(2u64), F64::from(9u64)]; + assert_eq!(eq_poly_non_binary(&x, &y), eq_poly_non_binary(&y, &x)); + } } diff --git a/src/hypercube/hypercube.rs b/src/hypercube/hypercube.rs deleted file mode 100644 index 95dd6603..00000000 --- a/src/hypercube/hypercube.rs +++ /dev/null @@ -1,181 +0,0 @@ -use crate::{hypercube::HypercubeMember, order_strategy::OrderStrategy}; - -// mod hypercube; -// mod hypercube_member; - -// pub use hypercube::Hypercube; -// pub use hypercube_member::HypercubeMember; - -// On each call to next() this gives a HypercubeMember for the value -#[derive(Debug)] -pub struct Hypercube { - order: O, -} - -impl Hypercube { - pub fn new(num_vars: usize) -> Self { - let order = O::new(num_vars); - Self { order } - } - pub fn stop_value(num_vars: usize) -> usize { - 1 << num_vars // this is exclusive, meaning should stop *before* this value - } -} - -impl Iterator for Hypercube { - type Item = (usize, HypercubeMember); - fn next(&mut self) -> Option { - match self.order.next_index() { - Some(current_index) => Some(( - current_index, - HypercubeMember::new(self.order.num_vars(), current_index), - )), - None => None, - } - } -} - -#[cfg(test)] -mod tests { - use crate::{ - hypercube::{Hypercube, HypercubeMember}, - order_strategy::{AscendingOrder, GraycodeOrder}, - }; - - fn is_eq(given: HypercubeMember, expected: Vec) { - // check each value in the vec - for (i, (a, b)) in given.zip(expected.clone()).enumerate() { - assert_eq!( - a, b, - "bit at index {} incorrect, should be {:?}", - i, expected - ); - } - } - - #[test] - fn lexicographic_hypercube_members() { - // for n=0, should return empty vec first call, none second call - let mut hypercube_size_0 = Hypercube::::new(0); - is_eq(hypercube_size_0.next().unwrap().1, vec![]); - // for n=1, should return vec[false] first call, vec[true] second call and None third call - let mut hypercube_size_1: Hypercube = Hypercube::new(1); - is_eq(hypercube_size_1.next().unwrap().1, vec![false]); - is_eq(hypercube_size_1.next().unwrap().1, vec![true]); - assert_eq!(hypercube_size_1.next(), None); - // so on for n=2 - let mut hypercube_size_2: Hypercube = Hypercube::new(2); - is_eq(hypercube_size_2.next().unwrap().1, vec![false, false]); - is_eq(hypercube_size_2.next().unwrap().1, vec![false, true]); - is_eq(hypercube_size_2.next().unwrap().1, vec![true, false]); - is_eq(hypercube_size_2.next().unwrap().1, vec![true, true]); - assert_eq!(hypercube_size_2.next(), None); - // so on for n=3 - let mut hypercube_size_3: Hypercube = Hypercube::new(3); - is_eq( - hypercube_size_3.next().unwrap().1, - vec![false, false, false], - ); - is_eq(hypercube_size_3.next().unwrap().1, vec![false, false, true]); - is_eq(hypercube_size_3.next().unwrap().1, vec![false, true, false]); - is_eq(hypercube_size_3.next().unwrap().1, vec![false, true, true]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, false, false]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, false, true]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, true, false]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, true, true]); - assert_eq!(hypercube_size_3.next(), None); - } - - #[test] - fn lexicographic_indices() { - // for n=0, should return empty vec first call, none second call - let mut hypercube_size_0 = Hypercube::::new(0); - assert_eq!(hypercube_size_0.next().unwrap().0, 0); - // for n=1, should return vec[false] first call, vec[true] second call and None third call - let mut hypercube_size_1: Hypercube = Hypercube::new(1); - assert_eq!(hypercube_size_1.next().unwrap().0, 0); - assert_eq!(hypercube_size_1.next().unwrap().0, 1); - assert_eq!(hypercube_size_1.next(), None); - // so on for n=2 - let mut hypercube_size_2: Hypercube = Hypercube::new(2); - assert_eq!(hypercube_size_2.next().unwrap().0, 0); - assert_eq!(hypercube_size_2.next().unwrap().0, 1); - assert_eq!(hypercube_size_2.next().unwrap().0, 2); - assert_eq!(hypercube_size_2.next().unwrap().0, 3); - assert_eq!(hypercube_size_2.next(), None); - // so on for n=3 - let mut hypercube_size_3: Hypercube = Hypercube::new(3); - assert_eq!(hypercube_size_3.next().unwrap().0, 0); - assert_eq!(hypercube_size_3.next().unwrap().0, 1); - assert_eq!(hypercube_size_3.next().unwrap().0, 2); - assert_eq!(hypercube_size_3.next().unwrap().0, 3); - assert_eq!(hypercube_size_3.next().unwrap().0, 4); - assert_eq!(hypercube_size_3.next().unwrap().0, 5); - assert_eq!(hypercube_size_3.next().unwrap().0, 6); - assert_eq!(hypercube_size_3.next().unwrap().0, 7); - assert_eq!(hypercube_size_3.next(), None); - } - - #[test] - fn graycode_hypercube_members() { - // for n=0, should return empty vec first call, none second call - let mut hypercube_size_0 = Hypercube::::new(0); - is_eq(hypercube_size_0.next().unwrap().1, vec![]); - // for n=1, should return vec[false] first call, vec[true] second call and None third call - let mut hypercube_size_1: Hypercube = Hypercube::new(1); - is_eq(hypercube_size_1.next().unwrap().1, vec![false]); - is_eq(hypercube_size_1.next().unwrap().1, vec![true]); - assert_eq!(hypercube_size_1.next(), None); - // so on for n=2 - let mut hypercube_size_2: Hypercube = Hypercube::new(2); - is_eq(hypercube_size_2.next().unwrap().1, vec![false, false]); - is_eq(hypercube_size_2.next().unwrap().1, vec![false, true]); - is_eq(hypercube_size_2.next().unwrap().1, vec![true, true]); - is_eq(hypercube_size_2.next().unwrap().1, vec![true, false]); - assert_eq!(hypercube_size_2.next(), None); - // so on for n=3 - let mut hypercube_size_3: Hypercube = Hypercube::new(3); - is_eq( - hypercube_size_3.next().unwrap().1, - vec![false, false, false], - ); - is_eq(hypercube_size_3.next().unwrap().1, vec![false, false, true]); - is_eq(hypercube_size_3.next().unwrap().1, vec![false, true, true]); - is_eq(hypercube_size_3.next().unwrap().1, vec![false, true, false]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, true, false]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, true, true]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, false, true]); - is_eq(hypercube_size_3.next().unwrap().1, vec![true, false, false]); - assert_eq!(hypercube_size_3.next(), None); - } - - #[test] - fn graycode_indices() { - // for n=0, should return empty vec first call, none second call - let mut hypercube_size_0 = Hypercube::::new(0); - assert_eq!(hypercube_size_0.next().unwrap().0, 0); - // for n=1, should return vec[false] first call, vec[true] second call and None third call - let mut hypercube_size_1: Hypercube = Hypercube::new(1); - assert_eq!(hypercube_size_1.next().unwrap().0, 0); - assert_eq!(hypercube_size_1.next().unwrap().0, 1); - assert_eq!(hypercube_size_1.next(), None); - // so on for n=2 - let mut hypercube_size_2: Hypercube = Hypercube::new(2); - assert_eq!(hypercube_size_2.next().unwrap().0, 0); - assert_eq!(hypercube_size_2.next().unwrap().0, 1); - assert_eq!(hypercube_size_2.next().unwrap().0, 3); - assert_eq!(hypercube_size_2.next().unwrap().0, 2); - assert_eq!(hypercube_size_2.next(), None); - // so on for n=3 - let mut hypercube_size_3: Hypercube = Hypercube::new(3); - assert_eq!(hypercube_size_3.next().unwrap().0, 0); - assert_eq!(hypercube_size_3.next().unwrap().0, 1); - assert_eq!(hypercube_size_3.next().unwrap().0, 3); - assert_eq!(hypercube_size_3.next().unwrap().0, 2); - assert_eq!(hypercube_size_3.next().unwrap().0, 6); - assert_eq!(hypercube_size_3.next().unwrap().0, 7); - assert_eq!(hypercube_size_3.next().unwrap().0, 5); - assert_eq!(hypercube_size_3.next().unwrap().0, 4); - assert_eq!(hypercube_size_3.next(), None); - } -} diff --git a/src/hypercube/hypercube_member.rs b/src/hypercube/hypercube_member.rs deleted file mode 100644 index c5bfc229..00000000 --- a/src/hypercube/hypercube_member.rs +++ /dev/null @@ -1,116 +0,0 @@ -#[derive(Clone, Debug, PartialEq)] -pub struct HypercubeMember { - bit_index: usize, - num_vars: usize, - value: usize, -} - -impl HypercubeMember { - pub fn new(num_vars: usize, value: usize) -> Self { - assert!(num_vars <= std::mem::size_of::() * 8); - Self { - bit_index: num_vars, - num_vars, - value, - } - } - pub fn new_from_vec_bool(value: Vec) -> Self { - HypercubeMember::new(value.len(), HypercubeMember::usize_from_vec_bool(value)) - } - pub fn len(&self) -> usize { - self.num_vars - } - pub fn is_empty(&self) -> bool { - if self.bit_index == 0 { - return true; - } - false - } - pub fn usize_from_vec_bool(vec: Vec) -> usize { - vec.into_iter() - .rev() - .enumerate() - .fold(0, |acc, (i, bit)| acc | ((bit as usize) << i)) - } - pub fn elements_at_indices(b: Vec, indices: Vec) -> Vec { - // checks - if indices.is_empty() { - return vec![]; - } - assert!(b.len() >= indices.len()); - assert!(b.len() > *indices.last().unwrap()); - // get the indices - let mut b_prime: Vec = Vec::with_capacity(indices.len()); - for index in &indices { - b_prime.push(b[*index]); - } - b_prime - } - pub fn to_vec_bool(&self) -> Vec { - let mut b: Vec = Vec::with_capacity(self.num_vars); - for bit_index in (0..self.num_vars).rev() { - b.push(self.value & (1 << bit_index) != 0); - } - b - } - pub fn value(&self) -> usize { - self.value - } -} - -impl Iterator for HypercubeMember { - type Item = bool; - fn next(&mut self) -> Option { - // Check if n == 0 - if self.bit_index == 0 { - return None; - } - // Return if value is bit high at bit_index - self.bit_index -= 1; - let bit_mask = 1 << self.bit_index; - Some(self.value & bit_mask != 0) - } -} - -#[cfg(test)] -mod tests { - use crate::hypercube::HypercubeMember; - #[test] - fn elements_at_indices() { - let test_1 = vec![true, false, false, false, false]; - let indices_1 = vec![2, 3]; - let result_1 = HypercubeMember::elements_at_indices(test_1, indices_1); - assert_eq!(result_1, vec![false, false]); - let test_2 = vec![false, true, true, false, false, false, false, true]; - let indices_2 = vec![0, 1, 2, 4, 6]; - let result_2 = HypercubeMember::elements_at_indices(test_2, indices_2); - assert_eq!(result_2, vec![false, true, true, false, false]); - } - #[test] - fn vec_bool_to_usize() { - let test_1 = vec![true, false, false]; - let exp_1 = 4; - assert_eq!(HypercubeMember::usize_from_vec_bool(test_1), exp_1); - let test_2 = vec![false, true, true]; - let exp_2 = 3; - assert_eq!(HypercubeMember::usize_from_vec_bool(test_2), exp_2); - } - #[test] - fn to_vec_bool() { - let exp_1 = vec![true, false, false, false, false]; - let test_1 = HypercubeMember::new_from_vec_bool(exp_1.clone()); - assert_eq!(exp_1, test_1.to_vec_bool()); - let test_2 = HypercubeMember::new(5, 16); - assert_eq!(exp_1, test_2.to_vec_bool()); - - let exp_2 = vec![false, false, true, false, true]; - let test_3 = HypercubeMember::new_from_vec_bool(exp_2.clone()); - assert_eq!(exp_2, test_3.to_vec_bool()); - let test_4 = HypercubeMember::new(5, 5); - assert_eq!(exp_2, test_4.to_vec_bool()); - - let exp_3 = vec![false, false, true]; - let test_3 = HypercubeMember::new(3, 1); - assert_eq!(test_3.to_vec_bool(), exp_3); - } -} diff --git a/src/hypercube/iter.rs b/src/hypercube/iter.rs new file mode 100644 index 00000000..8403b7f4 --- /dev/null +++ b/src/hypercube/iter.rs @@ -0,0 +1,211 @@ +//! Hypercube iterators over `{0,1}^v` in two orders. +//! +//! Each order yields [`HypercubePoint`] structs containing the index and +//! number of variables. Bit access is via [`HypercubePoint::bit(j)`]. +//! +//! # Orders +//! +//! - [`Ascending`]: 0, 1, 2, ..., 2^v − 1 (LSB layout). Pairs `(2k, 2k+1)` +//! differ in the least-significant bit. Use for streaming/blendy provers. +//! - [`BitReverse`]: bit-reversal permutation (MSB layout). Pairs `(k, k+L/2)` +//! differ in the most-significant bit. Use for in-memory time provers. +//! +//! Both iterators are zero-allocation and `ExactSizeIterator`. + +/// A point on the Boolean hypercube `{0,1}^v`. +extern crate alloc; +use alloc::vec::Vec; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct HypercubePoint { + /// The index in the current iteration order. + pub index: usize, + /// Number of variables. + pub num_vars: usize, +} + +impl HypercubePoint { + /// Bit `j` of this point (0 or 1), where bit 0 is the least significant. + #[inline] + pub fn bit(&self, j: usize) -> bool { + debug_assert!(j < self.num_vars); + (self.index >> j) & 1 == 1 + } + + /// All bits as a `Vec`, LSB first. + pub fn bits(&self) -> Vec { + (0..self.num_vars).map(|j| self.bit(j)).collect() + } +} + +// ─── Ascending ───────────────────────────────────────────────────────────── + +/// Iterate over `{0,1}^v` in ascending order: 0, 1, 2, ..., 2^v − 1. +/// +/// This is the LSB (pair-split) layout. Adjacent pairs `(2k, 2k+1)` differ +/// in the least-significant bit. Use for streaming and blendy provers where +/// sequential memory access and cache-friendly prefetching matter. +pub struct Ascending { + num_vars: usize, + current: usize, + end: usize, +} + +impl Ascending { + pub fn new(num_vars: usize) -> Self { + Self { + num_vars, + current: 0, + end: 1 << num_vars, + } + } +} + +impl Iterator for Ascending { + type Item = HypercubePoint; + + #[inline] + fn next(&mut self) -> Option { + if self.current >= self.end { + return None; + } + let point = HypercubePoint { + index: self.current, + num_vars: self.num_vars, + }; + self.current += 1; + Some(point) + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = self.end - self.current; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for Ascending {} + +// ─── BitReverse ──────────────────────────────────────────────────────────── + +/// Iterate over `{0,1}^v` in bit-reversal (MSB) order. +/// +/// Index `i` maps to `reverse_bits(i) >> (usize::BITS - v)`. +/// One hardware instruction on aarch64 (`RBIT`), fast on x86 too. +/// +/// This is the MSB (half-split) layout. Pairs `(k, k + 2^(v-1))` differ +/// in the most-significant bit. Use for in-memory time provers and WHIR. +pub struct BitReverse { + num_vars: usize, + shift: u32, + current: usize, + end: usize, +} + +impl BitReverse { + pub fn new(num_vars: usize) -> Self { + Self { + num_vars, + shift: if num_vars == 0 { + 0 + } else { + usize::BITS - num_vars as u32 + }, + current: 0, + end: 1 << num_vars, + } + } +} + +impl Iterator for BitReverse { + type Item = HypercubePoint; + + #[inline] + fn next(&mut self) -> Option { + if self.current >= self.end { + return None; + } + let reversed = if self.num_vars == 0 { + 0 + } else { + self.current.reverse_bits() >> self.shift + }; + let point = HypercubePoint { + index: reversed, + num_vars: self.num_vars, + }; + self.current += 1; + Some(point) + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = self.end - self.current; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for BitReverse {} + +// ─── Tests ───────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ascending_3vars() { + let points: Vec = Ascending::new(3).map(|p| p.index).collect(); + assert_eq!(points, vec![0, 1, 2, 3, 4, 5, 6, 7]); + } + + #[test] + fn ascending_exact_size() { + let iter = Ascending::new(4); + assert_eq!(iter.len(), 16); + } + + #[test] + fn bit_reverse_3vars() { + let points: Vec = BitReverse::new(3).map(|p| p.index).collect(); + // 3-bit reversal: 000→000, 001→100, 010→010, 011→110, ... + assert_eq!(points, vec![0, 4, 2, 6, 1, 5, 3, 7]); + } + + #[test] + fn bit_reverse_is_permutation() { + let mut points: Vec = BitReverse::new(4).map(|p| p.index).collect(); + points.sort(); + assert_eq!(points, (0..16).collect::>()); + } + + #[test] + fn bit_reverse_0vars() { + let points: Vec = BitReverse::new(0).map(|p| p.index).collect(); + assert_eq!(points, vec![0]); + } + + #[test] + fn bit_reverse_involution() { + // Applying bit-reversal twice should give back the identity. + for num_vars in 1..=6 { + let once: Vec = BitReverse::new(num_vars).map(|p| p.index).collect(); + let mut twice = vec![0usize; once.len()]; + for (i, &r) in once.iter().enumerate() { + twice[r] = i; + } + // twice[bit_reverse(i)] = i, so twice should be the bit-reverse map again + let expected: Vec = BitReverse::new(num_vars).map(|p| p.index).collect(); + assert_eq!(twice, expected, "num_vars={num_vars}"); + } + } + + #[test] + fn point_bit_access() { + let p = HypercubePoint { + index: 0b101, + num_vars: 3, + }; + assert!(p.bit(0)); // LSB + assert!(!p.bit(1)); + assert!(p.bit(2)); // MSB + assert_eq!(p.bits(), vec![true, false, true]); + } +} diff --git a/src/hypercube/mod.rs b/src/hypercube/mod.rs index a855ac08..76a4bc85 100644 --- a/src/hypercube/mod.rs +++ b/src/hypercube/mod.rs @@ -1,8 +1,7 @@ +#[cfg(feature = "arkworks")] mod eq_evals; -#[allow(clippy::module_inception)] -mod hypercube; -mod hypercube_member; +mod iter; -pub use eq_evals::compute_hypercube_eq_evals; -pub use hypercube::Hypercube; -pub use hypercube_member::HypercubeMember; +#[cfg(feature = "arkworks")] +pub use eq_evals::{compute_hypercube_eq_evals, eq_poly, eq_poly_non_binary}; +pub use iter::{Ascending, BitReverse, HypercubePoint}; diff --git a/src/inner_product_sumcheck.rs b/src/inner_product_sumcheck.rs index a96f3756..bc9b99d4 100644 --- a/src/inner_product_sumcheck.rs +++ b/src/inner_product_sumcheck.rs @@ -1,217 +1,383 @@ -//! Inner product sumcheck protocol. +//! Quadratic inner-product sumcheck: `∑_x f(x)·g(x)`. //! -//! Given two evaluation vectors `f` and `g` representing multilinear polynomials on -//! the boolean hypercube `{0,1}^n`, the [`inner_product_sumcheck`] function executes -//! `n` rounds of the product sumcheck protocol computing `∑_x f(x)·g(x)`, and returns -//! the resulting [`ProductSumcheck`] transcript. +//! Half-split (MSB) layout with a fused fold+compute kernel. +//! Round `i` folds the top-most remaining variable — the split is over +//! `a[0..L/2]` vs `a[L/2..L]`, *not* the adjacent pairs `(a[2k], a[2k+1])` +//! of a pair-split (LSB) layout. Callers whose upstream indexing assumed +//! pair-split semantics must reorder their inputs with a bit-reversal. //! -//! The function is parameterized by two field types: -//! - `BF` (base field): the field the evaluations live in -//! - `EF` (extension field): the field challenges are sampled from +//! Wire format per round: `(c0, c2)` in *difference form*, where +//! - `c0 = q(0) = Σ a_lo·b_lo` +//! - `c2 = [x²] q(x) = Σ (a_hi − a_lo)·(b_hi − b_lo)` //! -//! When no extension field is needed, set `EF = BF`. +//! The verifier derives `c1 = claim − 2·c0 − c2` from the sumcheck +//! constraint `q(0) + q(1) = claim`. //! -//! # Example -//! -//! ```text -//! use efficient_sumcheck::{inner_product_sumcheck, ProductSumcheck}; -//! use efficient_sumcheck::transcript::SanityTranscript; -//! -//! // No extension field (BF = EF): -//! let mut f = vec![F::from(1), F::from(2), F::from(3), F::from(4)]; -//! let mut g = vec![F::from(5), F::from(6), F::from(7), F::from(8)]; -//! let mut transcript = SanityTranscript::new(&mut rng); -//! let result: ProductSumcheck = inner_product_sumcheck(&mut f, &mut g, &mut transcript); -//! ``` - -use ark_std::collections::HashMap; -use nohash_hasher::BuildNoHashHasher; +//! The fused kernel rolls the round-`i` fold into the round-`(i+1)` compute, +//! cutting memory traffic from 12 reads + 4 writes per quadruple to +//! 8 reads + 4 writes — roughly a 25% reduction on the cold path, with +//! additional cache-locality gains from reading all four strides +//! simultaneously. use ark_ff::Field; +#[cfg(feature = "parallel")] +use rayon::join; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +use crate::transcript::ProverTranscript; + +/// Legacy return type for `inner_product_sumcheck`. +#[derive(Debug, PartialEq)] +pub struct ProductSumcheck { + pub prover_messages: Vec<(F, F)>, + pub verifier_messages: Vec, + pub final_evaluations: (F, F), +} + +// ─── Workload threshold ───────────────────────────────────────────────────── + +/// Target single-thread workload size for `T`. Close to L1 cache. +const fn workload_size() -> usize { + #[cfg(all(target_arch = "aarch64", target_os = "macos"))] + const CACHE_SIZE: usize = 1 << 17; + #[cfg(all( + target_arch = "aarch64", + any(target_os = "ios", target_os = "android", target_os = "linux") + ))] + const CACHE_SIZE: usize = 1 << 16; + #[cfg(target_arch = "x86_64")] + const CACHE_SIZE: usize = 1 << 15; + #[cfg(not(any( + all(target_arch = "aarch64", target_os = "macos"), + all( + target_arch = "aarch64", + any(target_os = "ios", target_os = "android", target_os = "linux") + ), + target_arch = "x86_64" + )))] + const CACHE_SIZE: usize = 1 << 15; + + CACHE_SIZE / core::mem::size_of::() +} -use crate::{ - multilinear::{reductions::pairwise, ReduceMode}, - multilinear_product::{TimeProductProver, TimeProductProverConfig}, - prover::Prover, - streams::MemoryStream, -}; +// ─── Scalar helpers ───────────────────────────────────────────────────────── -use crate::transcript::Transcript; +fn dot(a: &[F], b: &[F]) -> F { + debug_assert_eq!(a.len(), b.len()); + #[cfg(feature = "parallel")] + if a.len() > workload_size::() { + return a.par_iter().zip(b).map(|(x, y)| *x * *y).sum(); + } + a.iter().zip(b).map(|(x, y)| *x * *y).sum() +} -pub use crate::multilinear_product::ProductSumcheck; +fn scalar_mul(v: &mut [F], w: F) { + for x in v.iter_mut() { + *x *= w; + } +} -pub type FastMap = HashMap>; +// ─── Core algebra ─────────────────────────────────────────────────────────── -pub fn batched_constraint_poly( - dense_polys: &Vec>, - sparse_polys: &FastMap, -) -> Vec { - fn sum_columns(matrix: &Vec>) -> Vec { - if matrix.is_empty() { - return vec![]; +/// `(c0, c2)` of the round polynomial `q(x) = c0 + c1·x + c2·x²`. +/// +/// Vectors `a` and `b` are implicitly zero-extended to the next power of two. +pub fn compute_sumcheck_polynomial(a: &[F], b: &[F]) -> (F, F) { + fn recurse(a0: &[F], a1: &[F], b0: &[F], b1: &[F]) -> (F, F) { + debug_assert_eq!(a0.len(), b0.len()); + debug_assert_eq!(a1.len(), b1.len()); + debug_assert!(a0.len() == a1.len()); + + #[cfg(feature = "parallel")] + if a0.len() * 4 > workload_size::() { + let mid = a0.len() / 2; + let (a0l, a0r) = a0.split_at(mid); + let (b0l, b0r) = b0.split_at(mid); + let (a1l, a1r) = a1.split_at(mid); + let (b1l, b1r) = b1.split_at(mid); + let (left, right) = join( + || recurse(a0l, a1l, b0l, b1l), + || recurse(a0r, a1r, b0r, b1r), + ); + return (left.0 + right.0, left.1 + right.1); } - let mut result = vec![F::ZERO; matrix[0].len()]; - for row in matrix { - for (i, &val) in row.iter().enumerate() { - result[i] += val; - } + let mut acc0 = F::ZERO; + let mut acc2 = F::ZERO; + for ((&a0, &a1), (&b0, &b1)) in a0.iter().zip(a1).zip(b0.iter().zip(b1)) { + acc0 += a0 * b0; + acc2 += (a1 - a0) * (b1 - b0); } - result - } - let mut res = sum_columns(dense_polys); - for (k, v) in sparse_polys.iter() { - res[*k] += v; + (acc0, acc2) } - res -} -// [CBBZ23] hyperplonk optimization -/// Accumulate eq polynomial evaluations at binary query points into a sparse map. -/// Skips indices `0..=s`. -pub fn accumulate_sparse_evaluations( - zetas: Vec<&[F]>, - eq_evals: Vec, - s: usize, - r: usize, -) -> FastMap { - let mut result = FastMap::default(); - for i in 1 + s..r { - let index = zetas[i] - .iter() - .enumerate() - .filter_map(|(j, bit)| bit.is_one().then_some(1 << j)) - .sum::(); - *result.entry(index).or_insert(F::zero()) += &eq_evals[i]; + let non_padded = a.len().min(b.len()); + let a = &a[..non_padded]; + let b = &b[..non_padded]; + if a.is_empty() { + return (F::ZERO, F::ZERO); } - result + if a.len() == 1 { + return (a[0] * b[0], F::ZERO); + } + + let half = a.len().next_power_of_two() >> 1; + let (a0, a1) = a.split_at(half); + let (b0, b1) = b.split_at(half); + debug_assert!(a0.len() >= a1.len()); + let (a0, a0_tail) = a0.split_at(a1.len()); + let (b0, b0_tail) = b0.split_at(a1.len()); + let (acc0, acc2) = recurse(a0, a1, b0, b1); + + // Tail (a1, b1 = implicit zero padding): both contributions collapse to a0·b0. + let acc = dot(a0_tail, b0_tail); + (acc0 + acc, acc2 + acc) } -/// Run the inner product sumcheck protocol over two evaluation vectors, -/// using a generic [`Transcript`] for Fiat-Shamir (or sanity/random challenges). -/// -/// `BF` is the base field of the evaluations, `EF` is the extension field for challenges. -/// When `BF = EF`, this is the standard single-field inner product sumcheck. -/// When `BF ≠ EF`, round 0 evaluates in `BF` and lifts to `EF`, then subsequent -/// rounds work entirely in `EF`. +/// In-place half-split fold: `new[k] = v[k] + (v[k+L/2] − v[k]) · weight`. /// -/// Each round: -/// 1. Computes the round polynomial evaluations `(s(0), s(1), s(2))` via the product prover. -/// 2. Writes them to the transcript (3 field elements). -/// 3. Reads the verifier's challenge from the transcript (1 field element). -/// 4. Reduces both evaluation vectors by folding with the challenge. -pub fn inner_product_sumcheck>( - f: &mut [BF], - g: &mut [BF], - transcript: &mut impl Transcript, -) -> ProductSumcheck { - // checks - assert_eq!(f.len(), g.len()); - assert!(f.len().count_ones() == 1); - - let num_rounds = f.len().trailing_zeros() as usize; - let mut prover_messages: Vec<(EF, EF, EF)> = vec![]; - let mut verifier_messages: Vec = vec![]; - - // ── Round 0: evaluate in BF, lift to EF, cross-field reduce ── - if num_rounds > 0 { - let mut prover = TimeProductProver::new(TimeProductProverConfig::new( - f.len().trailing_zeros() as usize, - vec![MemoryStream::new(f.to_vec()), MemoryStream::new(g.to_vec())], - ReduceMode::Pairwise, - )); - - let msg_bf = prover.next_message(None).unwrap(); - let msg = (EF::from(msg_bf.0), EF::from(msg_bf.1), EF::from(msg_bf.2)); - - prover_messages.push(msg); - transcript.write(msg.0); - transcript.write(msg.1); - transcript.write(msg.2); - - let chg = transcript.read(); - verifier_messages.push(chg); - - // Cross-field reduce: BF evaluations + EF challenge → Vec - let mut ef_f = pairwise::cross_field_reduce(f, chg); - let mut ef_g = pairwise::cross_field_reduce(g, chg); - - // Remaining rounds work in EF - for _ in 1..num_rounds { - let mut prover = TimeProductProver::new(TimeProductProverConfig::new( - ef_f.len().trailing_zeros() as usize, - vec![ - MemoryStream::new(ef_f.to_vec()), - MemoryStream::new(ef_g.to_vec()), - ], - ReduceMode::Pairwise, - )); - - let msg = prover.next_message(None).unwrap(); - - prover_messages.push(msg); - transcript.write(msg.0); - transcript.write(msg.1); - transcript.write(msg.2); - - let chg = transcript.read(); - verifier_messages.push(chg); - - pairwise::reduce_evaluations(&mut ef_f, chg); - pairwise::reduce_evaluations(&mut ef_g, chg); +/// `values` is implicitly zero-padded to the next power of two. On return, +/// the length is a power of two (or zero). +pub fn fold(values: &mut Vec, weight: F) { + fn recurse_both(low: &mut [F], high: &[F], weight: F) { + #[cfg(feature = "parallel")] + if low.len() > workload_size::() { + let split = low.len() / 2; + let (ll, lr) = low.split_at_mut(split); + let (hl, hr) = high.split_at(split); + join( + || recurse_both(ll, hl, weight), + || recurse_both(lr, hr, weight), + ); + return; + } + for (low, high) in low.iter_mut().zip(high) { + *low += (*high - *low) * weight; } } - ProductSumcheck { - verifier_messages, - prover_messages, + if values.len() <= 1 { + return; } -} - -#[cfg(test)] -mod tests { - use super::*; - use ark_ff::UniformRand; - use ark_std::test_rng; - use crate::tests::F64; + let half = values.len().next_power_of_two() >> 1; + let (low, high) = values.split_at_mut(half); + debug_assert!(low.len() >= high.len()); + let (low, tail) = low.split_at_mut(high.len()); + recurse_both(low, high, weight); - const NUM_VARS: usize = 4; // vectors of length 2^4 = 16 + // Tail with implicit zero high: *low *= 1 − weight. + scalar_mul(tail, F::ONE - weight); - #[test] - fn test_inner_product_sumcheck_sanity() { - use crate::transcript::SanityTranscript; + values.truncate(half); + values.shrink_to_fit(); +} - let mut rng = test_rng(); +/// Two-pass fold-then-compute; reference version kept for testing. +pub fn fold_and_compute_polynomial(a: &mut Vec, b: &mut Vec, weight: F) -> (F, F) { + fold(a, weight); + fold(b, weight); + compute_sumcheck_polynomial(a, b) +} - let n = 1 << NUM_VARS; - let mut f: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); - let mut g: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); +/// Fused single-pass variant. +/// +/// Folds `a` and `b` by `weight` *and* computes the next-round polynomial +/// `(c0, c2)` in one sweep. The fold writes `[0, L/2)`; the subsequent +/// compute splits the length-`L/2` folded vector at `L/4`. So every +/// quadruple `(x[k], x[k+L/4], x[k+L/2], x[k+3L/4])` is touched exactly +/// once — 8 reads + 4 writes (fused) vs. 12 reads + 4 writes (unfused). +/// +/// Falls back to the unfused path for small or non-pow2 inputs so the +/// implicit-zero tail accounting stays identical. +pub fn fused_fold_and_compute_polynomial( + a: &mut Vec, + b: &mut Vec, + weight: F, +) -> (F, F) { + let l = a.len(); + debug_assert_eq!(l, b.len()); + if !l.is_power_of_two() || l < 4 { + return fold_and_compute_polynomial(a, b, weight); + } - let mut transcript = SanityTranscript::new(&mut rng); - let result = inner_product_sumcheck::(&mut f, &mut g, &mut transcript); + #[allow(clippy::too_many_arguments)] + fn kernel( + a0: &mut [F], + a1: &mut [F], + a2: &[F], + a3: &[F], + b0: &mut [F], + b1: &mut [F], + b2: &[F], + b3: &[F], + weight: F, + ) -> (F, F) { + debug_assert_eq!(a0.len(), a1.len()); + debug_assert_eq!(a0.len(), a2.len()); + debug_assert_eq!(a0.len(), a3.len()); + debug_assert_eq!(a0.len(), b0.len()); + debug_assert_eq!(a0.len(), b1.len()); + debug_assert_eq!(a0.len(), b2.len()); + debug_assert_eq!(a0.len(), b3.len()); + + #[cfg(feature = "parallel")] + if a0.len() * 4 > workload_size::() { + let mid = a0.len() / 2; + let (a0l, a0r) = a0.split_at_mut(mid); + let (a1l, a1r) = a1.split_at_mut(mid); + let (a2l, a2r) = a2.split_at(mid); + let (a3l, a3r) = a3.split_at(mid); + let (b0l, b0r) = b0.split_at_mut(mid); + let (b1l, b1r) = b1.split_at_mut(mid); + let (b2l, b2r) = b2.split_at(mid); + let (b3l, b3r) = b3.split_at(mid); + let (left, right) = join( + || kernel(a0l, a1l, a2l, a3l, b0l, b1l, b2l, b3l, weight), + || kernel(a0r, a1r, a2r, a3r, b0r, b1r, b2r, b3r, weight), + ); + return (left.0 + right.0, left.1 + right.1); + } - assert_eq!(result.prover_messages.len(), NUM_VARS); - assert_eq!(result.verifier_messages.len(), NUM_VARS); + let mut c0 = F::ZERO; + let mut c2 = F::ZERO; + for i in 0..a0.len() { + let x0 = a0[i]; + let x1 = a1[i]; + let x2 = a2[i]; + let x3 = a3[i]; + let y0 = b0[i]; + let y1 = b1[i]; + let y2 = b2[i]; + let y3 = b3[i]; + + let na_lo = x0 + (x2 - x0) * weight; + let na_hi = x1 + (x3 - x1) * weight; + let nb_lo = y0 + (y2 - y0) * weight; + let nb_hi = y1 + (y3 - y1) * weight; + + a0[i] = na_lo; + a1[i] = na_hi; + b0[i] = nb_lo; + b1[i] = nb_hi; + + c0 += na_lo * nb_lo; + c2 += (na_hi - na_lo) * (nb_hi - nb_lo); + } + (c0, c2) } - #[test] - fn test_inner_product_sumcheck_spongefish() { - use crate::transcript::SpongefishTranscript; + let quarter = l / 4; + let half = l / 2; - let mut rng = test_rng(); + let (a_first, a_second) = a.split_at_mut(half); + let (a0, a1) = a_first.split_at_mut(quarter); + let (a2, a3) = a_second.split_at(quarter); + let (b_first, b_second) = b.split_at_mut(half); + let (b0, b1) = b_first.split_at_mut(quarter); + let (b2, b3) = b_second.split_at(quarter); - let n = 1 << NUM_VARS; - let mut f: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); - let mut g: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let result = kernel(a0, a1, a2, a3, b0, b1, b2, b3, weight); - let domsep = spongefish::domain_separator!("test-inner-product-sumcheck"; module_path!()) - .instance(b"test"); + a.truncate(half); + b.truncate(half); + // Skip shrink_to_fit — realloc per round is pricier than the capacity + // we carry; the capacity frees once the Vec drops. + result +} - let prover_state = domsep.std_prover(); - let mut transcript = SpongefishTranscript::new(prover_state); - let result = inner_product_sumcheck::(&mut f, &mut g, &mut transcript); +// ─── Prover ───────────────────────────────────────────────────────────────── - assert_eq!(result.prover_messages.len(), NUM_VARS); - assert_eq!(result.verifier_messages.len(), NUM_VARS); +/// Runs `num_rounds` rounds on `(a, b)`, folding both in place. +/// +/// Transcript per round: writes `c0` then `c2` (difference form), invokes +/// `hook(round, transcript)`, then reads the verifier challenge. +/// +/// On return, if `num_rounds == log2(next_pow2(len))` then `a` and `b` have +/// length 1 and `final_evaluations = (a[0], b[0])`; otherwise +/// `(F::ZERO, F::ZERO)`. +pub fn inner_product_sumcheck_partial( + a: &mut Vec, + b: &mut Vec, + transcript: &mut T, + num_rounds: usize, + mut hook: H, +) -> ProductSumcheck +where + F: Field, + T: ProverTranscript, + H: FnMut(usize, &mut T), +{ + assert_eq!(a.len(), b.len()); + assert!( + num_rounds == 0 || a.len().next_power_of_two() >= 1 << num_rounds, + "num_rounds ({num_rounds}) exceeds log2 of next-pow2 of len ({})", + a.len(), + ); + + let mut prover_messages: Vec<(F, F)> = Vec::with_capacity(num_rounds); + let mut verifier_messages: Vec = Vec::with_capacity(num_rounds); + let mut folding_randomness: Option = None; + + for round in 0..num_rounds { + // Staggered: round-(i-1) fold is fused into round-i compute. + let (c0, c2) = if let Some(w) = folding_randomness { + fused_fold_and_compute_polynomial(a, b, w) + } else { + compute_sumcheck_polynomial(a, b) + }; + + prover_messages.push((c0, c2)); + transcript.send(c0); + transcript.send(c2); + + hook(round, transcript); + + let r = transcript.challenge(); + verifier_messages.push(r); + folding_randomness = Some(r); } + + if let Some(w) = folding_randomness { + fold(a, w); + fold(b, w); + } + + let final_evaluations = if a.len() == 1 { + (a[0], b[0]) + } else { + (F::ZERO, F::ZERO) + }; + + ProductSumcheck { + prover_messages, + verifier_messages, + final_evaluations, + } +} + +/// Full sumcheck (`log2(next_pow2(len))` rounds) with a per-round hook. +pub fn inner_product_sumcheck( + a: &mut Vec, + b: &mut Vec, + transcript: &mut T, + hook: H, +) -> ProductSumcheck +where + F: Field, + T: ProverTranscript, + H: FnMut(usize, &mut T), +{ + let num_rounds = if a.is_empty() { + 0 + } else { + a.len().next_power_of_two().trailing_zeros() as usize + }; + inner_product_sumcheck_partial(a, b, transcript, num_rounds, hook) } + +// ─── Verifier ─────────────────────────────────────────────────────────────── + +// Tests live in `tests/inner_product_sumcheck.rs` (integration target) — +// the lib-test target is blocked by unrelated modules with stale +// `domain_separator!` syntax. diff --git a/src/interpolation/lagrange_polynomial.rs b/src/interpolation/lagrange_polynomial.rs deleted file mode 100644 index 69acbd06..00000000 --- a/src/interpolation/lagrange_polynomial.rs +++ /dev/null @@ -1,214 +0,0 @@ -use crate::{ - hypercube::{Hypercube, HypercubeMember}, - messages::VerifierMessages, - order_strategy::{GraycodeOrder, MSBOrder, OrderStrategy}, -}; -use ark_ff::Field; - -#[derive(Debug)] -pub struct LagrangePolynomial<'a, F: Field, O: OrderStrategy> { - order: O, - last_position: usize, - position: usize, - value: F, - verifier_messages: &'a VerifierMessages, - stop_position: usize, -} - -impl<'a, F: Field, O: OrderStrategy> LagrangePolynomial<'a, F, O> { - pub fn new(verifier_messages: &'a VerifierMessages) -> Self { - let num_vars = verifier_messages.messages.len(); - let order = O::new(num_vars); - Self { - order, - last_position: 0, - position: 0, - value: verifier_messages.product_of_message_hats, - verifier_messages, - stop_position: Hypercube::::stop_value(num_vars), - } - } - pub fn lag_poly(x: Vec, x_hat: Vec, b: HypercubeMember) -> F { - // Iterate over the zipped triple x, x_hat, and boolean hypercube vectors - x.iter().zip(x_hat.iter()).zip(b).fold( - // Initial the accumulation to F::ONE - F::ONE, - // Closure for the folding operation, taking accumulator and ((x_i, x_hat_i), b_i) - |acc, ((x_i, x_hat_i), b_i)| { - // Multiply the accumulator by either x_i or x_hat_i based on the boolean value b_i - acc * match b_i { - true => *x_i, - false => *x_hat_i, - } - }, - ) - } - pub fn evaluate_from_three_points(verifier_message: F, prover_message: (F, F, F)) -> F { - // Hardcoded x-values: - let zero = F::zero(); - let one = F::one(); - let half = F::from(2_u32).inverse().unwrap(); - - // Compute denominators for the Lagrange basis polynomials - let inv_denom_0 = ((zero - one) * (zero - half)).inverse().unwrap(); - let inv_denom_1 = ((one - zero) * (one - half)).inverse().unwrap(); - let inv_denom_2 = ((half - zero) * (half - one)).inverse().unwrap(); - - // Compute the Lagrange basis polynomials evaluated at x - let basis_p_0 = (verifier_message - one) * (verifier_message - half) * inv_denom_0; - let basis_p_1 = (verifier_message - zero) * (verifier_message - half) * inv_denom_1; - let basis_p_2 = (verifier_message - zero) * (verifier_message - one) * inv_denom_2; - - // Return the evaluation of the unique quadratic polynomial - prover_message.0 * basis_p_0 + prover_message.1 * basis_p_1 + prover_message.2 * basis_p_2 - } -} - -impl<'a, F: Field> Iterator for LagrangePolynomial<'a, F, GraycodeOrder> { - type Item = F; - fn next(&mut self) -> Option { - // Step 1: check if finished iterating - if self.position >= self.stop_position { - return None; - } - - // Step 2: check if this iteration yields zero, in which case we skip processing - let bit_agreement = !(self.verifier_messages.messages_zeros_and_ones_usize ^ self.position); - if bit_agreement & self.verifier_messages.zero_ones_mask - != self.verifier_messages.zero_ones_mask - { - // NOTICE! we do not update last_position in this case - self.position = GraycodeOrder::next_gray_code(self.position); - return Some(F::ZERO); - } - - // Step 3: check if position is 0, which is a special case - // Notice! step 2 could apply when position == 0 - if self.position == 0 { - self.position = GraycodeOrder::next_gray_code(self.position); - return Some(self.value); - } - - // Step 4: update the value, skip if more than one bit difference - let bit_diff = self.last_position ^ self.position; - if bit_diff.count_ones() == 1 { - let index_of_flipped_bit = bit_diff.trailing_zeros() as usize; - let is_flipped_to_true = self.position & bit_diff != 0; - let len = self.verifier_messages.messages.len(); - self.value *= match is_flipped_to_true { - true => { - self.verifier_messages.message_and_message_hat_inverses - [len - index_of_flipped_bit - 1] - } - false => { - self.verifier_messages.message_hat_and_message_inverses - [len - index_of_flipped_bit - 1] - } - }; - } - - // Step 5: increment positions - self.last_position = self.position; - self.position = GraycodeOrder::next_gray_code(self.position); - - // Step 6: return - Some(self.value) - } -} - -impl<'a, F: Field> Iterator for LagrangePolynomial<'a, F, MSBOrder> { - type Item = F; - fn next(&mut self) -> Option { - // Step 1: check if finished iterating - if self.position >= self.stop_position { - return None; - } - - // Step 2: check if this iteration yields zero, in which case we skip processing - let bit_agreement = !(self.verifier_messages.messages_zeros_and_ones_usize ^ self.position); - if bit_agreement & self.verifier_messages.zero_ones_mask - != self.verifier_messages.zero_ones_mask - { - // NOTICE! we do not update last_position in this case - self.position = - MSBOrder::next_value_in_msb_order(self.position, self.order.num_vars() as u32); - return Some(F::ZERO); - } - // Step 3: check if position is 0, which is a special case - // Notice! step 2 could apply when position == 0 - if self.position == 0 { - self.position = - MSBOrder::next_value_in_msb_order(self.position, self.order.num_vars() as u32); - return Some(self.value); - } - // Step 3: update the value - let len = self.verifier_messages.messages.len(); - for i in (0..len).rev() { - if self.position >> i == 0 { - self.value *= self.verifier_messages.message_hat_and_message_inverses[len - i - 1]; - } else { - self.value *= self.verifier_messages.message_and_message_hat_inverses[len - i - 1]; - break; - } - } - - // Step 5: increment positions - self.last_position = self.position; - self.position = - MSBOrder::next_value_in_msb_order(self.position, self.order.num_vars() as u32); - - // Step 6: return - Some(self.value) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - hypercube::HypercubeMember, interpolation::LagrangePolynomial, messages::VerifierMessages, - order_strategy::GraycodeOrder, tests::F19, - }; - - #[test] - fn next() { - // remember this is gray code ordering! - let messages: Vec = vec![F19::from(13), F19::from(0), F19::from(7)]; - let message_hats: Vec = messages - .clone() - .iter() - .map(|message| F19::from(1) - message) - .collect(); - let vm = VerifierMessages::new(&vec![F19::from(13), F19::from(0), F19::from(7)]); - let mut lag_poly: LagrangePolynomial = LagrangePolynomial::new(&vm); - for gray_code_index in [0, 1, 3, 2, 6, 7, 5, 4] { - let exp = LagrangePolynomial::::lag_poly( - messages.clone(), - message_hats.clone(), - HypercubeMember::new(3, gray_code_index), - ); - assert_eq!(lag_poly.next().unwrap(), exp); - } - assert_eq!(lag_poly.next(), None); - } - #[test] - fn boolean_next() { - // remember this is gray code ordering! - let messages: Vec = vec![F19::from(0), F19::from(1), F19::from(1)]; - let message_hats: Vec = messages - .clone() - .iter() - .map(|message| F19::from(1) - message) - .collect(); - let vm = VerifierMessages::new(&vec![F19::from(0), F19::from(1), F19::from(1)]); - let mut lag_poly: LagrangePolynomial = LagrangePolynomial::new(&vm); - for gray_code_index in [0, 1, 3, 2, 6, 7, 5, 4] { - let exp = LagrangePolynomial::::lag_poly( - messages.clone(), - message_hats.clone(), - HypercubeMember::new(3, gray_code_index), - ); - assert_eq!(lag_poly.next().unwrap(), exp); - } - assert_eq!(lag_poly.next(), None); - } -} diff --git a/src/interpolation/mod.rs b/src/interpolation/mod.rs deleted file mode 100644 index bb0657e8..00000000 --- a/src/interpolation/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod lagrange_polynomial; -pub use lagrange_polynomial::LagrangePolynomial; diff --git a/src/lib.rs b/src/lib.rs index 0ee4112a..0c73fdc2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,56 +1,80 @@ -//! # efficient-sumcheck +//! # effsc //! -//! Space-efficient implementations of the sumcheck protocol with Fiat-Shamir support. -//! -//! ## Quick Start -//! -//! For most use cases, you need just two functions and a transcript: -//! -//! ```text -//! use efficient_sumcheck::{multilinear_sumcheck, inner_product_sumcheck}; -//! use efficient_sumcheck::transcript::{Transcript, SpongefishTranscript, SanityTranscript}; -//! ``` -//! -//! - [`multilinear_sumcheck()`] — standard multilinear sumcheck: `∑_x p(x)` -//! - [`inner_product_sumcheck()`] — inner product sumcheck: `∑_x f(x)·g(x)` -//! -//! Both accept any [`Transcript`] implementation — either -//! [`SpongefishTranscript`](transcript::SpongefishTranscript) for real Fiat-Shamir, or -//! [`SanityTranscript`](transcript::SanityTranscript) for testing with random challenges. -//! -//! ## Advanced Usage -//! -//! For custom prover implementations, streaming evaluation access, -//! or specialized reduction strategies, the internal modules expose the full -//! prover machinery: [`multilinear`], [`multilinear_product`], [`prover`], [`streams`]. +//! Sumcheck protocol (Thaler Proposition 4.1) with SIMD acceleration. + +#![cfg_attr(not(feature = "arkworks"), no_std)] + +extern crate alloc; + +// ─── Generic field trait ───────────────────────────────────────────────────── + +pub mod field; +pub mod proof; + +// ─── New canonical API (Thaler §4.1) ──────────────────────────────────────── + +pub mod fold; +pub mod polynomial; +pub mod provers; +pub mod runner; +pub mod sumcheck_prover; +pub mod verifier; -// ─── Primary API ───────────────────────────────────────────────────────────── +/// No-op per-round hook for the prover. Pass to `sumcheck()` when no hook is needed. +/// +/// ```ignore +/// let proof = sumcheck(&mut prover, n, &mut t, noop_hook); +/// ``` +pub fn noop_hook(_round: usize, _transcript: &mut T) {} + +/// No-op per-round hook for the verifier. Pass to `sumcheck_verify()` when no hook is needed. +/// +/// ```ignore +/// let result = sumcheck_verify(sum, deg, n, &mut t, noop_hook_verify)?; +/// ``` +pub fn noop_hook_verify( + _round: usize, + _transcript: &mut T, +) -> Result<(), crate::proof::SumcheckError> { + Ok(()) +} + +// ─── Transcript ───────────────────────────────────────────────────────────── -/// Transcript trait and backends (Spongefish, Sanity). pub mod transcript; +// ─── Arkworks-dependent modules ───────────────────────────────────────────── + +#[cfg(feature = "arkworks")] mod inner_product_sumcheck; +#[cfg(feature = "arkworks")] mod multilinear_sumcheck; +#[cfg(feature = "arkworks")] pub use inner_product_sumcheck::{ - accumulate_sparse_evaluations, batched_constraint_poly, inner_product_sumcheck, ProductSumcheck, + inner_product_sumcheck, inner_product_sumcheck_partial, ProductSumcheck, +}; +#[cfg(feature = "arkworks")] +pub use multilinear_sumcheck::{ + compute_sumcheck_polynomial, fold, fused_fold_and_compute_polynomial, multilinear_sumcheck, + multilinear_sumcheck_partial, Sumcheck, }; -pub use multilinear_sumcheck::{multilinear_sumcheck, Sumcheck}; - -// ─── Internal / Advanced ───────────────────────────────────────────────────── - -pub mod multilinear; -pub mod multilinear_product; -pub mod prover; -pub mod streams; - -pub mod hypercube; -pub mod interpolation; -pub mod messages; -pub mod order_strategy; +#[cfg(feature = "arkworks")] pub mod coefficient_sumcheck; +#[cfg(feature = "arkworks")] pub mod folding; - +pub mod hypercube; +#[cfg(feature = "arkworks")] +pub mod poly_ops; +#[cfg(feature = "arkworks")] +pub(crate) mod reductions; +#[cfg(all(feature = "arkworks", feature = "simd"))] +pub(crate) mod simd_fields; +#[cfg(all(feature = "arkworks", feature = "simd"))] +pub(crate) mod simd_sumcheck; +#[cfg(feature = "arkworks")] +pub mod streams; +#[cfg(feature = "arkworks")] #[doc(hidden)] pub mod tests; diff --git a/src/messages/mod.rs b/src/messages/mod.rs deleted file mode 100644 index 15f15a55..00000000 --- a/src/messages/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod verifier_messages; -pub use verifier_messages::VerifierMessages; diff --git a/src/messages/verifier_messages.rs b/src/messages/verifier_messages.rs deleted file mode 100644 index d10fa738..00000000 --- a/src/messages/verifier_messages.rs +++ /dev/null @@ -1,130 +0,0 @@ -use ark_ff::Field; - -#[derive(Clone, Debug)] -pub struct VerifierMessages { - pub messages: Vec, - pub message_hats: Vec, - pub message_and_message_hat_inverses: Vec, - pub message_hat_and_message_inverses: Vec, - pub messages_zeros_and_ones_usize: usize, - pub zero_ones_mask: usize, - pub product_of_message_hats: F, -} - -impl VerifierMessages { - pub fn new(messages: &Vec) -> Self { - let mut verifier_messages = Self { - messages: vec![], - message_hats: vec![], - product_of_message_hats: F::ONE, - message_and_message_hat_inverses: vec![], - message_hat_and_message_inverses: vec![], - messages_zeros_and_ones_usize: 0, - zero_ones_mask: 0, - }; - for message in messages { - verifier_messages.receive_message(*message); - } - verifier_messages - } - pub fn new_from_self(vm: &Self, start: usize, end: usize) -> Self { - // TODO (z-tech): this can be redone more efficiently - Self::new(&vm.messages[start..end].to_vec()) - } - pub fn receive_message(&mut self, message: F) { - // Step 1: compute some things - let message_hat = F::ONE - message; - let message_inverse = match message.inverse() { - Some(inverse) => inverse, - None => F::ONE, - }; - let message_hat_inverse = match message_hat.inverse() { - Some(inverse) => inverse, - None => F::ONE, - }; - // Step 2: store some things - self.messages.push(message); - self.message_hats.push(message_hat); - self.message_and_message_hat_inverses - .push(message * message_hat_inverse); - self.message_hat_and_message_inverses - .push(message_hat * message_inverse); - - if message == F::ZERO || message_hat == F::ZERO { - self.zero_ones_mask = (self.zero_ones_mask << 1) | 1; - self.messages_zeros_and_ones_usize = if message == F::ONE { - self.messages_zeros_and_ones_usize << 1 | 1 - } else { - self.messages_zeros_and_ones_usize << 1 - }; - } else { - self.zero_ones_mask <<= 1; - self.messages_zeros_and_ones_usize <<= 1; - self.product_of_message_hats *= message_hat; - } - } -} - -#[cfg(test)] -mod tests { - use crate::{messages::VerifierMessages, tests::F19}; - use ark_ff::{One, Zero}; - - #[test] - fn receive_message() { - let mut m0 = VerifierMessages::new(&vec![]); - - // ## receive 13 - m0.receive_message(F19::from(13)); - assert_eq!(m0.messages, vec![F19::from(13)]); - assert_eq!(m0.message_hats, vec![F19::one() - F19::from(13)]); - - // ## receive 0 - m0.receive_message(F19::zero()); - assert_eq!(m0.messages, vec![F19::from(13), F19::zero()]); - assert_eq!( - m0.message_hats, - vec![F19::one() - F19::from(13), F19::one()] - ); - - // ## receive 7 - m0.receive_message(F19::from(7)); - assert_eq!(m0.messages, vec![F19::from(13), F19::zero(), F19::from(7)]); - assert_eq!( - m0.message_hats, - vec![ - F19::one() - F19::from(13), - F19::one(), - F19::one() - F19::from(7) - ] - ); - - // ## receive 1 - m0.receive_message(F19::one()); - assert_eq!( - m0.messages, - vec![F19::from(13), F19::zero(), F19::from(7), F19::one()] - ); - assert_eq!( - m0.message_hats, - vec![ - F19::one() - F19::from(13), - F19::one(), - F19::one() - F19::from(7), - F19::zero() - ] - ); - - let mut m1 = VerifierMessages::new(&vec![]); - - // ## receive zero - m1.receive_message(F19::from(0)); - assert_eq!(m1.messages, vec![F19::from(0)]); - assert_eq!(m1.message_hats, vec![F19::one()]); - - // receive 1 - m1.receive_message(F19::from(1)); - assert_eq!(m1.messages, vec![F19::from(0), F19::one()]); - assert_eq!(m1.message_hats, vec![F19::one(), F19::zero()]); - } -} diff --git a/src/multilinear/mod.rs b/src/multilinear/mod.rs deleted file mode 100644 index 10946249..00000000 --- a/src/multilinear/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod provers; -mod sumcheck; - -pub use provers::{ - blendy::{BlendyProver, BlendyProverConfig}, - space::{SpaceProver, SpaceProverConfig}, - time::{reductions, ReduceMode, TimeProver, TimeProverConfig}, -}; -pub use sumcheck::Sumcheck; diff --git a/src/multilinear/provers/blendy/config.rs b/src/multilinear/provers/blendy/config.rs deleted file mode 100644 index 4294cde1..00000000 --- a/src/multilinear/provers/blendy/config.rs +++ /dev/null @@ -1,41 +0,0 @@ -use ark_ff::Field; -use ark_std::marker::PhantomData; - -use crate::{prover::ProverConfig, streams::Stream}; - -pub struct BlendyProverConfig -where - F: Field, - S: Stream, -{ - pub num_stages: usize, - pub num_variables: usize, - pub stream: S, - _f: PhantomData, -} - -impl BlendyProverConfig -where - F: Field, - S: Stream, -{ - pub fn new(num_stages: usize, num_variables: usize, stream: S) -> Self { - Self { - num_stages, - num_variables, - stream, - _f: PhantomData::, - } - } -} - -impl> ProverConfig for BlendyProverConfig { - fn default(num_variables: usize, stream: S) -> Self { - Self { - num_stages: 2, // DEFAULT - num_variables, - stream, - _f: PhantomData::, - } - } -} diff --git a/src/multilinear/provers/blendy/core.rs b/src/multilinear/provers/blendy/core.rs deleted file mode 100644 index 6692ac9b..00000000 --- a/src/multilinear/provers/blendy/core.rs +++ /dev/null @@ -1,236 +0,0 @@ -use ark_ff::Field; -use ark_std::cfg_iter_mut; -use ark_std::{cfg_into_iter, vec::Vec}; - -use crate::{ - hypercube::Hypercube, interpolation::LagrangePolynomial, messages::VerifierMessages, - order_strategy::GraycodeOrder, streams::Stream, -}; -#[cfg(feature = "parallel")] -use rayon::iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, -}; - -pub struct BlendyProver -where - F: Field, - S: Stream, -{ - pub current_round: usize, - pub evaluation_stream: S, - pub lag_polys: Vec, - pub lag_polys_update: Vec, - pub num_stages: usize, - pub num_variables: usize, - pub stage_size: usize, - pub sums: Vec, - pub verifier_messages: VerifierMessages, -} - -impl BlendyProver -where - F: Field, - S: Stream, -{ - fn shift_and_one_fill(num: usize, shift_amount: usize) -> usize { - (num << shift_amount) | ((1 << shift_amount) - 1) - } - - pub fn compute_round(&self, partial_sums: &[F]) -> (F, F) { - // Initialize accumulators for sum_0 and sum_1 - let mut sum_0 = F::ZERO; - let mut sum_1 = F::ZERO; - - // Calculate j_prime as j-(s-1)l - let stage_start_index: usize = self.current_stage() * self.stage_size; - let j_prime = self.current_round - stage_start_index; - - // Iterate through b2_start indices using Hypercube::new(j_prime + 1) - for (b2_start_index, _) in Hypercube::::new(j_prime + 1) { - // Calculate b2_start_index_0 and b2_start_index_1 for indexing partial_sums - let shift_amount = if self.num_variables - stage_start_index < self.stage_size { - // this is the oddly sized last stage when k doesn't divide num_vars - self.num_variables - (self.current_stage() * self.stage_size) - j_prime - 1 - } else { - self.stage_size - j_prime - 1 - }; - let b2_start_index_0: usize = b2_start_index << shift_amount; - let b2_start_index_1: usize = Self::shift_and_one_fill(b2_start_index, shift_amount); - - // Calculate left_value and right_value based on partial_sums - let left_value: F = match b2_start_index_0 { - 0 => F::ZERO, - _ => partial_sums[b2_start_index_0 - 1], - }; - let right_value = partial_sums[b2_start_index_1]; - let sum = right_value - left_value; - - // Match based on the last bit of b2_start - match b2_start_index & 1 == 1 { - false => sum_0 += self.lag_polys[b2_start_index] * sum, - true => sum_1 += self.lag_polys[b2_start_index] * sum, - } - } - - // Return the accumulated sums - (sum_0, sum_1) - } - - fn current_stage(&self) -> usize { - self.current_round / self.stage_size - } - - pub fn is_initial_round(&self) -> bool { - self.current_round == 0 - } - - pub fn is_start_of_stage(&self) -> bool { - self.current_round.is_multiple_of(self.stage_size) - } - - fn is_single_staged(&self) -> bool { - self.num_stages == 1 - } - - pub fn sum_update(&mut self) { - if self.is_single_staged() { - return; - }; - // 0. Declare ranges for convenience - let b1_num_vars: usize = self.current_stage() * self.stage_size; - let b2_num_vars: usize = if self.num_variables - b1_num_vars < self.stage_size { - // this is the oddly sized last stage when k doesn't divide num_vars - self.num_variables - b1_num_vars - } else { - self.stage_size - }; - let b3_num_vars: usize = self.num_variables - b1_num_vars - b2_num_vars; - - // 1. Initialize SUM[b2] := 0 for each b2 ∈ {0,1}^l - // we reuse self.sums we just have to zero out on the first access SEE BELOW - - // 2. Initialize st := LagInit((s - l)l, r) - let mut sequential_lag_poly: LagrangePolynomial = - LagrangePolynomial::new(&self.verifier_messages); - - // 3. For each b1 ∈ {0,1}^(s-1)l - let len_sums: usize = self.sums.len(); - for (b1_index, _) in Hypercube::::new(b1_num_vars) { - // (a) Compute (LagPoly, st) := LagNext(st) - let lag_poly = sequential_lag_poly.next().unwrap(); - - // (b) For each b2 ∈ {0,1}^l, for each b2 ∈ {0,1}^(k-s)l - for (b2_index, _) in Hypercube::::new(b2_num_vars) { - for (b3_index, _) in Hypercube::::new(b3_num_vars) { - // Calculate the index for the current combination of b1, b2, and b3 - let index = b1_index << (b2_num_vars + b3_num_vars) - | b2_index << b3_num_vars - | b3_index; - - // Update SUM[b2] - self.sums[b2_index] = - match b1_index == 0 && b3_index == 0 && b2_index < len_sums { - // SEE HERE zero out the array on first access per update - true => lag_poly * self.evaluation_stream.evaluation(index), - false => { - self.sums[b2_index] - + lag_poly * self.evaluation_stream.evaluation(index) - } - }; - } - } - } - } - pub fn update_lag_polys(&mut self) { - // Calculate j_prime as j-(s-1)l - let j_prime = self.current_round - (self.current_stage() * self.stage_size); - - // Iterate through b2_start indices using Hypercube::new(j_prime + 1) - for (b2_start_index, _) in Hypercube::::new(j_prime + 1) { - // calculate lag_poly from precomputed - let lag_poly = match j_prime { - 0 => F::ONE, - _ => { - let precomputed: F = *self.lag_polys.get(b2_start_index >> 1).unwrap(); - match b2_start_index & 2 == 2 { - true => precomputed * *self.verifier_messages.messages.last().unwrap(), - false => precomputed * *self.verifier_messages.message_hats.last().unwrap(), - } - } - }; - self.lag_polys_update[b2_start_index] = lag_poly; - } - std::mem::swap(&mut self.lag_polys, &mut self.lag_polys_update); - } - - pub fn update_prefix_sums(&mut self) { - let n = self.sums.len(); - if n == 0 { - return; - } - - // Step 0: Unified input vector - let input: Vec = if self.is_single_staged() { - (0..n) - .map(|i| self.evaluation_stream.evaluation(i)) - .collect() - } else { - self.sums.clone() - }; - - // Step 1: Determine chunk size and boundaries - #[cfg(feature = "parallel")] - let num_threads = rayon::current_num_threads().max(1); - #[cfg(not(feature = "parallel"))] - let num_threads = 1; - let chunk_size: usize = (n / num_threads).max(1); - - // Compute chunk start indices - let chunk_starts: Vec = (0..n).step_by(chunk_size).collect(); - - // Step 2: Parallel local prefix sums - let mut partial_sums: Vec> = cfg_into_iter!(chunk_starts.clone()) - .map(|start: usize| { - let end: usize = (start + chunk_size).min(n); - let mut local_sum = F::ZERO; - let mut out = Vec::with_capacity(end - start); - for &x in &input[start..end] { - local_sum += x; - out.push(local_sum); - } - out - }) - .collect(); - - // Step 3: Collect per-chunk totals - let mut chunk_totals: Vec = partial_sums - .iter() - .map(|chunk| *chunk.last().unwrap()) - .collect(); - - // Step 4: Exclusive prefix sum of chunk totals (serial) - for i in 1..chunk_totals.len() { - let prev = chunk_totals[i - 1]; - chunk_totals[i] += prev; - } - - // Step 5: Offset adjust each chunk in parallel - cfg_iter_mut!(partial_sums) - .enumerate() - .for_each(|(i, chunk)| { - let offset = if i == 0 { F::ZERO } else { chunk_totals[i - 1] }; - if offset != F::ZERO { - for x in chunk.iter_mut() { - *x += offset; - } - } - }); - - // Step 6: Flatten into final result - self.sums = partial_sums.into_iter().flatten().collect(); - } - - pub fn total_rounds(&self) -> usize { - self.num_variables - } -} diff --git a/src/multilinear/provers/blendy/mod.rs b/src/multilinear/provers/blendy/mod.rs deleted file mode 100644 index f199125c..00000000 --- a/src/multilinear/provers/blendy/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod config; -mod core; -mod prover; - -pub use config::BlendyProverConfig; -pub use core::BlendyProver; diff --git a/src/multilinear/provers/blendy/prover.rs b/src/multilinear/provers/blendy/prover.rs deleted file mode 100644 index 59655ff1..00000000 --- a/src/multilinear/provers/blendy/prover.rs +++ /dev/null @@ -1,78 +0,0 @@ -use ark_ff::Field; - -use crate::{ - hypercube::Hypercube, - messages::VerifierMessages, - multilinear::{BlendyProver, BlendyProverConfig}, - order_strategy::GraycodeOrder, - prover::Prover, - streams::Stream, -}; - -impl Prover for BlendyProver -where - F: Field, - S: Stream, -{ - type ProverConfig = BlendyProverConfig; - type ProverMessage = Option<(F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - let stage_size: usize = prover_config.num_variables / prover_config.num_stages; - Self { - current_round: 0, - evaluation_stream: prover_config.stream, - num_stages: prover_config.num_stages, - num_variables: prover_config.num_variables, - verifier_messages: VerifierMessages::new(&vec![]), - sums: vec![F::ZERO; Hypercube::::stop_value(stage_size)], - lag_polys: vec![F::ONE; Hypercube::::stop_value(stage_size)], - lag_polys_update: vec![F::ONE; Hypercube::::stop_value(stage_size)], - stage_size, - } - } - - fn next_message(&mut self, verifier_message: Option) -> Option<(F, F)> { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - if !self.is_initial_round() { - self.verifier_messages - .receive_message(verifier_message.unwrap()); - } - - // at start of stage do some stuff - if self.is_start_of_stage() { - self.sum_update(); - self.update_prefix_sums(); - } - - // update lag_polys based on previous round - self.update_lag_polys(); - - let sums: (F, F) = self.compute_round(&self.sums); - - // Increment the round counter - self.current_round += 1; - - // Return the computed polynomial sums - Some(sums) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - multilinear::BlendyProver, - streams::MemoryStream, - tests::{multilinear::sanity_test, F19}, - }; - - #[test] - fn sumcheck() { - sanity_test::, BlendyProver>>(); - } -} diff --git a/src/multilinear/provers/mod.rs b/src/multilinear/provers/mod.rs deleted file mode 100644 index 7e2173bd..00000000 --- a/src/multilinear/provers/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod blendy; -pub mod space; -pub mod time; diff --git a/src/multilinear/provers/space/config.rs b/src/multilinear/provers/space/config.rs deleted file mode 100644 index 1537f165..00000000 --- a/src/multilinear/provers/space/config.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::marker::PhantomData; - -use ark_ff::Field; - -use crate::{ - prover::{BatchProverConfig, ProverConfig}, - streams::Stream, -}; - -pub struct SpaceProverConfig -where - F: Field, - S: Stream, -{ - pub num_variables: usize, - pub streams: Vec, - _f: PhantomData, -} - -impl SpaceProverConfig -where - F: Field, - S: Stream, -{ - pub fn new(num_variables: usize, stream: S) -> Self { - Self { - num_variables, - streams: vec![stream], - _f: PhantomData::, - } - } -} - -impl> ProverConfig for SpaceProverConfig { - fn default(num_variables: usize, stream: S) -> Self { - Self { - num_variables, - streams: vec![stream], - _f: PhantomData::, - } - } -} - -impl> BatchProverConfig for SpaceProverConfig { - fn default(num_variables: usize, streams: Vec) -> Self { - Self { - num_variables, - streams, - _f: PhantomData::, - } - } -} diff --git a/src/multilinear/provers/space/core.rs b/src/multilinear/provers/space/core.rs deleted file mode 100644 index ef8d16c3..00000000 --- a/src/multilinear/provers/space/core.rs +++ /dev/null @@ -1,71 +0,0 @@ -use ark_ff::Field; - -use crate::{ - hypercube::Hypercube, interpolation::LagrangePolynomial, order_strategy::GraycodeOrder, - streams::Stream, -}; - -pub struct SpaceProver> { - pub current_round: usize, - pub evaluation_streams: Vec, - pub num_variables: usize, - pub verifier_messages: Vec, - pub verifier_message_hats: Vec, -} - -impl> SpaceProver { - pub fn cty_evaluate(&self) -> (F, F) { - // Initialize accumulators for sum_0 and sum_1 - let mut sum_0: F = F::ZERO; - let mut sum_1: F = F::ZERO; - - // Create a bitmask for the number of free variables - let bitmask: usize = 1 << (self.num_free_variables() - 1); - - // Iterate in two loops - let num_vars_outer_loop = self.current_round; - let num_vars_inner_loop = self.num_variables - num_vars_outer_loop; - - // Outer loop over a subset of variables - for (index_outer, outer) in Hypercube::::new(num_vars_outer_loop) { - // Calculate the weight using Lagrange polynomial - let lag_poly: F = LagrangePolynomial::::lag_poly( - self.verifier_messages.clone(), - self.verifier_message_hats.clone(), - outer, - ); - - if lag_poly == F::ZERO { - // in this case the inner loop does nothing - continue; - } - - // Inner loop over all possible evaluations for the remaining variables - for (index_inner, _inner) in Hypercube::::new(num_vars_inner_loop) { - // Calculate the evaluation index - let evaluation_index = index_outer << num_vars_inner_loop | index_inner; - - // Check if the bit at the position specified by the bitmask is set - let is_set: bool = (evaluation_index & bitmask) != 0; - - for stream in &self.evaluation_streams { - // Use match to accumulate the appropriate value based on whether the bit is set or not - let inner_sum = stream.evaluation(evaluation_index) * lag_poly; - match is_set { - false => sum_0 += inner_sum, - true => sum_1 += inner_sum, - } - } - } - } - - // Return the accumulated sums - (sum_0, sum_1) - } - pub fn num_free_variables(&self) -> usize { - self.num_variables - self.current_round - } - pub fn total_rounds(&self) -> usize { - self.num_variables - } -} diff --git a/src/multilinear/provers/space/mod.rs b/src/multilinear/provers/space/mod.rs deleted file mode 100644 index d2569104..00000000 --- a/src/multilinear/provers/space/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod config; -mod core; -mod prover; - -pub use config::SpaceProverConfig; -pub use core::SpaceProver; diff --git a/src/multilinear/provers/space/prover.rs b/src/multilinear/provers/space/prover.rs deleted file mode 100644 index 0f2ac3bb..00000000 --- a/src/multilinear/provers/space/prover.rs +++ /dev/null @@ -1,59 +0,0 @@ -use ark_ff::Field; - -use crate::{ - multilinear::{SpaceProver, SpaceProverConfig}, - prover::Prover, - streams::Stream, -}; - -impl> Prover for SpaceProver { - type ProverConfig = SpaceProverConfig; - type ProverMessage = Option<(F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - Self { - evaluation_streams: prover_config.streams, - verifier_messages: Vec::::with_capacity(prover_config.num_variables), - verifier_message_hats: Vec::::with_capacity(prover_config.num_variables), - current_round: 0, - num_variables: prover_config.num_variables, - } - } - - fn next_message(&mut self, verifier_message: Option) -> Option<(F, F)> { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - // If it's not the first round, add the verifier message to verifier_messages - if self.current_round != 0 { - self.verifier_messages.push(verifier_message.unwrap()); - self.verifier_message_hats - .push(F::ONE - verifier_message.unwrap()); - } - - // evaluate using cty - let sums: (F, F) = self.cty_evaluate(); - - // don't forget to increment the round - self.current_round += 1; - - Some(sums) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - multilinear::SpaceProver, - streams::MemoryStream, - tests::{multilinear::sanity_test, F19}, - }; - - #[test] - fn sumcheck() { - sanity_test::, SpaceProver>>(); - } -} diff --git a/src/multilinear/provers/time/config.rs b/src/multilinear/provers/time/config.rs deleted file mode 100644 index bc468dfc..00000000 --- a/src/multilinear/provers/time/config.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::marker::PhantomData; - -use ark_ff::Field; - -use crate::{ - multilinear::provers::time::reductions::ReduceMode, - prover::{BatchProverConfig, ProverConfig}, - streams::Stream, -}; - -pub struct TimeProverConfig -where - F: Field, - S: Stream, -{ - pub num_variables: usize, - pub streams: Vec, - pub reduce_mode: ReduceMode, - _f: PhantomData, -} - -impl TimeProverConfig -where - F: Field, - S: Stream, -{ - pub fn new(num_variables: usize, stream: S, reduce_mode: ReduceMode) -> Self { - Self { - num_variables, - streams: vec![stream], - reduce_mode, - _f: PhantomData::, - } - } -} - -impl> ProverConfig for TimeProverConfig { - fn default(num_variables: usize, stream: S) -> Self { - Self { - num_variables, - streams: vec![stream], - reduce_mode: ReduceMode::Pairwise, - _f: PhantomData::, - } - } -} - -impl> BatchProverConfig for TimeProverConfig { - fn default(num_variables: usize, streams: Vec) -> Self { - Self { - num_variables, - streams, - reduce_mode: ReduceMode::Pairwise, - _f: PhantomData::, - } - } -} diff --git a/src/multilinear/provers/time/core.rs b/src/multilinear/provers/time/core.rs deleted file mode 100644 index 88298685..00000000 --- a/src/multilinear/provers/time/core.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::multilinear::provers::time::reductions::ReduceMode; -use ark_ff::Field; -use ark_std::vec::Vec; - -use crate::streams::Stream; - -pub struct TimeProver> { - // pub claim: F, - pub current_round: usize, - pub evaluations: Option>, - pub evaluation_streams: Vec, // TODO (z-tech): this can be released after the first call to vsbw_reduce_evaluations - pub num_variables: usize, - pub reduce_mode: ReduceMode, -} - -impl> TimeProver { - pub fn total_rounds(&self) -> usize { - self.num_variables - } -} diff --git a/src/multilinear/provers/time/mod.rs b/src/multilinear/provers/time/mod.rs deleted file mode 100644 index 343802df..00000000 --- a/src/multilinear/provers/time/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod config; -mod core; -mod prover; -pub mod reductions; - -pub use config::TimeProverConfig; -pub use core::TimeProver; -pub use reductions::ReduceMode; diff --git a/src/multilinear/provers/time/prover.rs b/src/multilinear/provers/time/prover.rs deleted file mode 100644 index e947b452..00000000 --- a/src/multilinear/provers/time/prover.rs +++ /dev/null @@ -1,122 +0,0 @@ -use ark_ff::Field; - -use crate::multilinear::provers::time::reductions::ReduceMode; -use crate::{ - multilinear::{ - provers::time::reductions::{pairwise, variablewise}, - TimeProver, TimeProverConfig, - }, - prover::Prover, - streams::Stream, -}; - -impl> TimeProver { - fn next_message_pairwise(&mut self, verifier_message: Option) -> Option<(F, F)> { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - if self.current_round != 0 { - if self.current_round > 1 { - pairwise::reduce_evaluations( - self.evaluations.as_mut().unwrap(), - verifier_message.unwrap(), - ); - } else { - self.evaluations = Some(vec![]); - pairwise::reduce_evaluations_from_stream( - &self.evaluation_streams[0], - self.evaluations.as_mut().unwrap(), - verifier_message.unwrap(), - ); - } - } - - // evaluate using vsbw - let sums = match &self.evaluations { - None => pairwise::evaluate_from_stream(&self.evaluation_streams[0]), - Some(evaluations) => pairwise::evaluate(evaluations), - }; - - // Increment the round counter - self.current_round += 1; - - // Return the computed polynomial - Some(sums) - } - fn next_message_variablewise(&mut self, verifier_message: Option) -> Option<(F, F)> { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - if self.current_round != 0 { - if self.current_round > 1 { - variablewise::reduce_evaluations( - self.evaluations.as_mut().unwrap(), - verifier_message.unwrap(), - F::ONE - verifier_message.unwrap(), - ); - } else { - self.evaluations = Some(vec![]); - variablewise::reduce_evaluations_from_stream( - &self.evaluation_streams[0], - self.evaluations.as_mut().unwrap(), - verifier_message.unwrap(), - F::ONE - verifier_message.unwrap(), - ); - } - } - - // evaluate using vsbw - let sums = match &self.evaluations { - None => variablewise::evaluate_from_stream(&self.evaluation_streams[0]), - Some(evaluations) => variablewise::evaluate(evaluations), - }; - - // Increment the round counter - self.current_round += 1; - - // Return the computed polynomial - Some(sums) - } -} - -impl> Prover for TimeProver { - type ProverConfig = TimeProverConfig; - type ProverMessage = Option<(F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - Self { - // claim: prover_config.claim, - current_round: 0, - evaluations: None, - evaluation_streams: prover_config.streams, - num_variables: prover_config.num_variables, - reduce_mode: prover_config.reduce_mode, - } - } - - fn next_message(&mut self, verifier_message: Option) -> Option<(F, F)> { - match self.reduce_mode { - ReduceMode::Pairwise => self.next_message_pairwise(verifier_message), - ReduceMode::Variablewise => self.next_message_variablewise(verifier_message), - } - } -} - -#[cfg(test)] -mod tests { - use crate::{ - multilinear::TimeProver, - streams::MemoryStream, - tests::{multilinear::pairwise_sanity_test, F19}, - }; - - #[test] - fn sumcheck() { - pairwise_sanity_test::, TimeProver>>(); - } -} diff --git a/src/multilinear/provers/time/reductions/mod.rs b/src/multilinear/provers/time/reductions/mod.rs deleted file mode 100644 index 6895dd16..00000000 --- a/src/multilinear/provers/time/reductions/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod pairwise; -pub mod tablewise; -pub mod variablewise; - -#[derive(Copy, Clone, Debug)] -pub enum ReduceMode { - Pairwise, - Variablewise, -} diff --git a/src/multilinear/provers/time/reductions/pairwise.rs b/src/multilinear/provers/time/reductions/pairwise.rs deleted file mode 100644 index e28ce0e1..00000000 --- a/src/multilinear/provers/time/reductions/pairwise.rs +++ /dev/null @@ -1,75 +0,0 @@ -use ark_ff::Field; -use ark_std::vec::Vec; -use ark_std::{cfg_chunks, cfg_into_iter}; -#[cfg(feature = "parallel")] -use rayon::{ - iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}, - prelude::ParallelSlice, -}; - -use crate::streams::Stream; - -pub fn evaluate(src: &[F]) -> (F, F) { - let even_sum = cfg_into_iter!(0..src.len()) - .step_by(2) - .map(|i| src[i]) - .sum(); - let odd_sum = cfg_into_iter!(1..src.len()) - .step_by(2) - .map(|i| src[i]) - .sum(); - (even_sum, odd_sum) -} - -pub fn evaluate_from_stream>(src: &S) -> (F, F) { - let len = 1usize << src.num_variables(); - let even_sum = cfg_into_iter!(0..len) - .step_by(2) - .map(|i| src.evaluation(i)) - .sum(); - let odd_sum = cfg_into_iter!(1..len) - .step_by(2) - .map(|i| src.evaluation(i)) - .sum(); - (even_sum, odd_sum) -} - -pub fn reduce_evaluations(src: &mut Vec, verifier_message: F) { - // compute from src - let out: Vec = cfg_chunks!(src, 2) - .map(|chunk| chunk[0] + verifier_message * (chunk[1] - chunk[0])) - .collect(); - // write back into src - src[..out.len()].copy_from_slice(&out); - src.truncate(out.len()); -} - -pub fn reduce_evaluations_from_stream>( - src: &S, - dst: &mut Vec, - verifier_message: F, -) { - // compute from stream - let len = 1usize << src.num_variables(); - let out: Vec = cfg_into_iter!(0..len / 2) - .map(|i| { - let a = src.evaluation(2 * i); - let b = src.evaluation((2 * i) + 1); - a + verifier_message * (b - a) - }) - .collect(); - *dst = out; -} - -/// Cross-field reduce: fold `BF` evaluations with an `EF` challenge, producing `Vec`. -/// -/// For each adjacent pair `(a, b)` in `src`: `EF::from(a) + challenge * (EF::from(b) - EF::from(a))`. -pub fn cross_field_reduce>(src: &[BF], challenge: EF) -> Vec { - cfg_chunks!(src, 2) - .map(|chunk| { - let a = EF::from(chunk[0]); - let b = EF::from(chunk[1]); - a + challenge * (b - a) - }) - .collect() -} diff --git a/src/multilinear/provers/time/reductions/variablewise.rs b/src/multilinear/provers/time/reductions/variablewise.rs deleted file mode 100644 index 58217038..00000000 --- a/src/multilinear/provers/time/reductions/variablewise.rs +++ /dev/null @@ -1,143 +0,0 @@ -use ark_ff::Field; -#[cfg(feature = "parallel")] -use ark_std::cfg_into_iter; -use ark_std::vec::Vec; -#[cfg(feature = "parallel")] -use rayon::iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, -}; - -use crate::streams::Stream; - -pub fn evaluate(src: &[F]) -> (F, F) { - let second_half_bit: usize = src.len() / 2; - - #[cfg(feature = "parallel")] - let (sum_0, sum_1) = cfg_into_iter!(0..src.len()) - .map(|i| { - let v = src[i]; - if (i & second_half_bit) == 0 { - (v, F::zero()) - } else { - (F::zero(), v) - } - }) - .reduce( - || (F::zero(), F::zero()), - |(a0, a1), (b0, b1)| (a0 + b0, a1 + b1), - ); - - #[cfg(not(feature = "parallel"))] - let (sum_0, sum_1) = { - let mut sum_0 = F::ZERO; - let mut sum_1 = F::ZERO; - for i in 0..src.len() { - let v = src[i]; - match i & second_half_bit != 0 { - false => sum_0 += v, - true => sum_1 += v, - } - } - (sum_0, sum_1) - }; - - (sum_0, sum_1) -} - -pub fn evaluate_from_stream>(src: &S) -> (F, F) { - let len = 1usize << src.num_variables(); - let second_half_bit: usize = len / 2; - - #[cfg(feature = "parallel")] - let (sum_0, sum_1) = cfg_into_iter!(0..len) - .map(|i| { - let v = src.evaluation(i); - if (i & second_half_bit) == 0 { - (v, F::zero()) - } else { - (F::zero(), v) - } - }) - .reduce( - || (F::zero(), F::zero()), - |(a0, a1), (b0, b1)| (a0 + b0, a1 + b1), - ); - - #[cfg(not(feature = "parallel"))] - let (sum_0, sum_1) = { - let mut sum_0 = F::ZERO; - let mut sum_1 = F::ZERO; - for i in 0..len { - let v = src.evaluation(i); - match i & second_half_bit != 0 { - false => sum_0 += v, - true => sum_1 += v, - } - } - (sum_0, sum_1) - }; - - (sum_0, sum_1) -} - -pub fn reduce_evaluations( - src: &mut Vec, - verifier_message: F, - verifier_message_hat: F, -) { - let second_half_bit: usize = src.len() / 2; - let mut out = vec![F::ZERO; src.len() / 2]; - - #[cfg(feature = "parallel")] - { - out.par_iter_mut() - .enumerate() - .for_each(|(first_half_index, slot): (usize, &mut F)| { - let second_half_index = first_half_index | second_half_bit; - let v0 = src[first_half_index]; - let v1 = src[second_half_index]; - *slot = v0 * verifier_message_hat + v1 * verifier_message; - }); - } - - #[cfg(not(feature = "parallel"))] - for first_half_index in 0..src.len() / 2 { - let second_half_index = first_half_index | second_half_bit; - let v0 = src[first_half_index]; - let v1 = src[second_half_index]; - out[first_half_index] = v0 * verifier_message_hat + v1 * verifier_message; - } - - *src = out; -} - -pub fn reduce_evaluations_from_stream>( - src: &S, - dst: &mut Vec, - verifier_message: F, - verifier_message_hat: F, -) { - let len = 1usize << src.num_variables(); - let second_half_bit: usize = len / 2; - let mut out = vec![F::ZERO; len / 2]; - - #[cfg(feature = "parallel")] - out.par_iter_mut() - .enumerate() - .for_each(|(first_half_index, slot): (usize, &mut F)| { - let second_half_index = first_half_index | second_half_bit; - let v0 = src.evaluation(first_half_index); - let v1 = src.evaluation(second_half_index); - *slot = v0 * verifier_message_hat + v1 * verifier_message; - }); - - #[cfg(not(feature = "parallel"))] - for first_half_index in 0..len / 2 { - let second_half_index = first_half_index | second_half_bit; - let v0 = src.evaluation(first_half_index); - let v1 = src.evaluation(second_half_index); - out[first_half_index] = v0 * verifier_message_hat + v1 * verifier_message; - } - - *dst = out; -} diff --git a/src/multilinear/sumcheck.rs b/src/multilinear/sumcheck.rs deleted file mode 100644 index 3282f6f5..00000000 --- a/src/multilinear/sumcheck.rs +++ /dev/null @@ -1,142 +0,0 @@ -use ark_ff::Field; -use ark_std::{rand::Rng, vec::Vec}; - -use crate::{prover::Prover, streams::Stream}; - -#[derive(Debug)] -pub struct Sumcheck { - pub prover_messages: Vec<(F, F)>, - pub verifier_messages: Vec, -} - -impl Sumcheck { - pub fn prove(prover: &mut P, rng: &mut impl Rng) -> Self - where - S: Stream, - P: Prover, ProverMessage = Option<(F, F)>>, - { - // Initialize vectors to store prover and verifier messages - let mut prover_messages: Vec<(F, F)> = vec![]; - let mut verifier_messages: Vec = vec![]; - - // Run the protocol - let mut verifier_message: Option = None; - while let Some(message) = prover.next_message(verifier_message) { - let round_sum = message.0 + message.1; - let is_round_accepted = match verifier_message { - // If first round, compare to claimed_sum - None => true, // TODO (z-tech): we should give an option to supply claim round_sum == prover.claim(), - // Else compute f(prev_verifier_msg) = prev_sum_0 - (prev_sum_0 - prev_sum_1) * prev_verifier_msg == round_sum, store verifier message - Some(prev_verifier_message) => { - verifier_messages.push(prev_verifier_message); - let prev_prover_message = prover_messages.last().unwrap(); - round_sum - == prev_prover_message.0 - - (prev_prover_message.0 - prev_prover_message.1) - * prev_verifier_message - } - }; - - // Handle how to proceed - prover_messages.push(message); - if !is_round_accepted { - break; - } - - verifier_message = Some(F::rand(rng)); - } - - // Return a Sumcheck struct with the collected messages and acceptance status - Sumcheck { - prover_messages, - verifier_messages, - } - } -} - -#[cfg(test)] -mod tests { - use super::Sumcheck; - use crate::streams::Stream; - use crate::{ - multilinear::{BlendyProver, BlendyProverConfig, ReduceMode, TimeProver}, - prover::{Prover, ProverConfig}, - tests::{ - multilinear::{BasicProver, BasicProverConfig}, - polynomials::Polynomial, - BenchStream, F19, - }, - }; - use ark_poly::multivariate; - - #[test] - fn sanity() { - const NUM_VARIABLES: usize = 16; - - // take an evaluation stream - let evaluation_stream: BenchStream = BenchStream::new(NUM_VARIABLES); - let claim = evaluation_stream.claimed_sum; - - // 1) blendy - let mut blendy_k3_prover = BlendyProver::>::new( - BlendyProverConfig::new(3, NUM_VARIABLES, evaluation_stream.clone()), - ); - let blendy_prover_transcript = Sumcheck::::prove::< - BenchStream, - BlendyProver>, - >(&mut blendy_k3_prover, &mut ark_std::test_rng()); - - // 2) time_prover_variablewise - let mut time_prover_variablewise = TimeProver::>::new(, - > as Prover>::ProverConfig::new( - NUM_VARIABLES, - evaluation_stream.clone(), - ReduceMode::Variablewise, - )); - let time_prover_variablewise_transcript = - Sumcheck::::prove::, TimeProver>>( - &mut time_prover_variablewise, - &mut ark_std::test_rng(), - ); - - // 3) basic prover - let s_evaluations: Vec = (0..1 << NUM_VARIABLES) - .map(|i| evaluation_stream.evaluation(i)) - .collect(); - let p = as Polynomial< - F19, - >>::from_hypercube_evaluations(s_evaluations); - let mut basic_prover = - BasicProver::::new(BasicProverConfig::new(claim, NUM_VARIABLES, p)); - let basic_prover_transcript = Sumcheck::::prove::, BasicProver>( - &mut basic_prover, - &mut ark_std::test_rng(), - ); - - // ensure all transcripts (1, 2, 3) identical - assert_eq!( - time_prover_variablewise_transcript.prover_messages, - blendy_prover_transcript.prover_messages - ); - assert_eq!( - time_prover_variablewise_transcript.prover_messages, - basic_prover_transcript.prover_messages - ); - - // time_prover_pairwise: this should pass but I have nothing to compare it with - let mut time_prover_pairwise = TimeProver::>::new(, - > as Prover>::ProverConfig::default( - NUM_VARIABLES, - evaluation_stream, - )); - let _time_prover_pairwise_transcript = - Sumcheck::::prove::, TimeProver>>( - &mut time_prover_pairwise, - &mut ark_std::test_rng(), - ); - } -} diff --git a/src/multilinear_product/mod.rs b/src/multilinear_product/mod.rs deleted file mode 100644 index 91a1d5b5..00000000 --- a/src/multilinear_product/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod provers; -mod sumcheck; - -pub use provers::{ - blendy::{BlendyProductProver, BlendyProductProverConfig}, - space::{SpaceProductProver, SpaceProductProverConfig}, - time::{TimeProductProver, TimeProductProverConfig}, -}; -pub use sumcheck::ProductSumcheck; diff --git a/src/multilinear_product/provers/blendy/config.rs b/src/multilinear_product/provers/blendy/config.rs deleted file mode 100644 index 2477360e..00000000 --- a/src/multilinear_product/provers/blendy/config.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::marker::PhantomData; - -use ark_ff::Field; - -use crate::{prover::ProductProverConfig, streams::Stream}; - -const DEFAULT_NUM_STAGES: usize = 2; - -pub struct BlendyProductProverConfig -where - F: Field, - S: Stream, -{ - pub num_stages: usize, - pub num_variables: usize, - pub streams: Vec, - _f: PhantomData, -} - -impl BlendyProductProverConfig -where - F: Field, - S: Stream, -{ - pub fn new(num_stages: usize, num_variables: usize, streams: Vec) -> Self { - Self { - num_stages, - num_variables, - streams, - _f: PhantomData::, - } - } -} - -impl> ProductProverConfig for BlendyProductProverConfig { - fn default(num_variables: usize, streams: Vec) -> Self { - Self { - num_stages: DEFAULT_NUM_STAGES, - num_variables, - streams, - _f: PhantomData::, - } - } -} diff --git a/src/multilinear_product/provers/blendy/core.rs b/src/multilinear_product/provers/blendy/core.rs deleted file mode 100644 index efbdc1e2..00000000 --- a/src/multilinear_product/provers/blendy/core.rs +++ /dev/null @@ -1,269 +0,0 @@ -use crate::{ - hypercube::Hypercube, - interpolation::LagrangePolynomial, - messages::VerifierMessages, - multilinear_product::TimeProductProver, - order_strategy::{GraycodeOrder, MSBOrder}, - streams::{Stream, StreamIterator}, -}; -use ark_ff::Field; -use ark_std::vec::Vec; -use std::collections::BTreeSet; - -pub struct BlendyProductProver> { - pub current_round: usize, - pub streams: Vec, - pub stream_iterators: Vec>, - pub num_stages: usize, - pub num_variables: usize, - pub last_round_phase1: usize, - pub verifier_messages: VerifierMessages, - pub verifier_messages_round_comp: VerifierMessages, - pub x_table: Vec, - pub y_table: Vec, - pub j_prime_table: Vec>, - pub stage_size: usize, - pub inverse_four: F, - pub prev_table_round_num: usize, - pub prev_table_size: usize, - pub state_comp_set: BTreeSet, - pub switched_to_vsbw: bool, - pub vsbw_prover: TimeProductProver, -} - -impl> BlendyProductProver { - pub fn is_initial_round(&self) -> bool { - self.current_round == 0 - } - - pub fn total_rounds(&self) -> usize { - self.num_variables - } - - pub fn init_round_vars(&mut self) { - let n = self.num_variables; - let j = self.current_round + 1; - - if let Some(&prev_round) = self.state_comp_set.range(..=j).next_back() { - self.prev_table_round_num = prev_round; - if let Some(&next_round) = self.state_comp_set.range((j + 1)..).next() { - self.prev_table_size = next_round - prev_round; - } else { - self.prev_table_size = n + 1 - prev_round; - } - } else { - self.prev_table_round_num = 0; - self.prev_table_size = 0; - } - } - - pub fn compute_round(&mut self) -> (F, F, F) { - let mut sum_0 = F::ZERO; - let mut sum_1 = F::ZERO; - let mut sum_half = F::ZERO; - - // in the last rounds, we switch to the memory intensive prover - if self.switched_to_vsbw { - (sum_0, sum_1, sum_half) = self.vsbw_prover.vsbw_evaluate(); - } - // if first few rounds, then no table is computed, need to compute sums from the streams - else if self.current_round < self.last_round_phase1 { - // Lag Poly - let mut sequential_lag_poly: LagrangePolynomial = - LagrangePolynomial::new(&self.verifier_messages_round_comp); - let lag_polys_len = Hypercube::::stop_value(self.current_round); - let mut lag_polys: Vec = vec![F::ONE; lag_polys_len]; - - // reset the streams - self.stream_iterators - .iter_mut() - .for_each(|stream_it| stream_it.reset()); - - for (x_index, _) in - Hypercube::::new(self.num_variables - self.current_round - 1) - { - // can avoid unnecessary additions for first round since there is no lag poly: gives a small speedup - if self.is_initial_round() { - let p0 = self.stream_iterators[0].next().unwrap(); - let p1 = self.stream_iterators[0].next().unwrap(); - let q0 = self.stream_iterators[1].next().unwrap(); - let q1 = self.stream_iterators[1].next().unwrap(); - sum_0 += p0 * q0; - sum_1 += p1 * q1; - sum_half += (p0 + p1) * (q0 + q1); - } else { - let mut partial_sum_p_0 = F::ZERO; - let mut partial_sum_p_1 = F::ZERO; - let mut partial_sum_q_0 = F::ZERO; - let mut partial_sum_q_1 = F::ZERO; - for (b_index, _) in Hypercube::::new(self.current_round) { - if x_index == 0 { - lag_polys[b_index] = sequential_lag_poly.next().unwrap(); - } - let lag_poly = lag_polys[b_index]; - partial_sum_p_0 += self.stream_iterators[0].next().unwrap() * lag_poly; - partial_sum_q_0 += self.stream_iterators[1].next().unwrap() * lag_poly; - } - for (b_index, _) in Hypercube::::new(self.current_round) { - let lag_poly = lag_polys[b_index]; - partial_sum_p_1 += self.stream_iterators[0].next().unwrap() * lag_poly; - partial_sum_q_1 += self.stream_iterators[1].next().unwrap() * lag_poly; - } - - sum_0 += partial_sum_p_0 * partial_sum_q_0; - sum_1 += partial_sum_p_1 * partial_sum_q_1; - sum_half += - (partial_sum_p_0 + partial_sum_p_1) * (partial_sum_q_0 + partial_sum_q_1); - } - } - sum_half *= self.inverse_four; - } else { - // computing evaluations from the cross product tables - - // things to help iterating - let b_prime_num_vars = self.current_round + 1 - self.prev_table_round_num; - let v_num_vars: usize = - self.prev_table_size + self.prev_table_round_num - self.current_round - 2; - let b_prime_index_left_shift = v_num_vars + 1; - - // Lag Poly - let mut sequential_lag_poly: LagrangePolynomial = - LagrangePolynomial::new(&self.verifier_messages_round_comp); - let lag_polys_len = Hypercube::::stop_value(b_prime_num_vars); - let mut lag_polys: Vec = vec![F::ONE; lag_polys_len]; - - // Sums - for (b_prime_index, _) in Hypercube::::new(b_prime_num_vars) { - for (b_prime_prime_index, _) in Hypercube::::new(b_prime_num_vars) { - // doing it like this, for each hypercube member lag_poly is computed exactly once - if b_prime_index == 0 { - lag_polys[b_prime_prime_index] = sequential_lag_poly.next().unwrap(); - } - - let lag_poly_1 = lag_polys[b_prime_index]; - let lag_poly_2 = lag_polys[b_prime_prime_index]; - let lag_poly = lag_poly_1 * lag_poly_2; - for (v_index, _) in Hypercube::::new(v_num_vars) { - let b_prime_0_v = - b_prime_index << b_prime_index_left_shift | 0 << v_num_vars | v_index; - let b_prime_prime_0_v = b_prime_prime_index << b_prime_index_left_shift - | 0 << v_num_vars - | v_index; - let b_prime_1_v = - b_prime_index << b_prime_index_left_shift | 1 << v_num_vars | v_index; - let b_prime_prime_1_v = b_prime_prime_index << b_prime_index_left_shift - | 1 << v_num_vars - | v_index; - - sum_0 += lag_poly * self.j_prime_table[b_prime_0_v][b_prime_prime_0_v]; - sum_1 += lag_poly * self.j_prime_table[b_prime_1_v][b_prime_prime_1_v]; - sum_half += lag_poly - * (self.j_prime_table[b_prime_0_v][b_prime_prime_0_v] - + self.j_prime_table[b_prime_0_v][b_prime_prime_1_v] - + self.j_prime_table[b_prime_1_v][b_prime_prime_0_v] - + self.j_prime_table[b_prime_1_v][b_prime_prime_1_v]); - } - } - } - sum_half *= self.inverse_four; - } - (sum_0, sum_1, sum_half) - } - - pub fn compute_state(&mut self) { - let j = self.current_round + 1; - let p = self.state_comp_set.contains(&j); - let is_largest = self.state_comp_set.range((j + 1)..).next().is_none(); - if p && !is_largest { - let j_prime = self.prev_table_round_num; - let t = self.prev_table_size; - - // zero out the table - let table_len = Hypercube::::stop_value(t); - self.j_prime_table = vec![vec![F::ZERO; table_len]; table_len]; - - // basically, this needs to get "zeroed" out at the beginning of state computation - self.verifier_messages_round_comp = VerifierMessages::new_from_self( - &self.verifier_messages, - j_prime - 1, - self.verifier_messages.messages.len(), - ); - - // some stuff for iterating - let b_num_vars: usize = self.num_variables + 1 - j_prime - t; - let x_num_vars = j_prime - 1; - - // Lag Poly - let mut sequential_lag_poly: LagrangePolynomial = - LagrangePolynomial::new(&self.verifier_messages); - - assert!(x_num_vars == self.verifier_messages.messages.len()); - let lag_polys_len = Hypercube::::stop_value(x_num_vars); - let mut lag_polys: Vec = vec![F::ONE; lag_polys_len]; - - for (x_index, _) in Hypercube::::new(x_num_vars) { - lag_polys[x_index] = sequential_lag_poly.next().unwrap(); - } - - // reset the streams - self.stream_iterators - .iter_mut() - .for_each(|stream_it| stream_it.reset()); - - // Ensure x_table and y_table are initialized with the correct size - self.x_table = vec![F::ZERO; Hypercube::::stop_value(t)]; - self.y_table = vec![F::ZERO; Hypercube::::stop_value(t)]; - - for (_, _) in Hypercube::::new(b_num_vars) { - for (b_prime_index, _) in Hypercube::::new(t) { - self.x_table[b_prime_index] = F::ZERO; - self.y_table[b_prime_index] = F::ZERO; - - for (x_index, _) in Hypercube::::new(x_num_vars) { - self.x_table[b_prime_index] += - lag_polys[x_index] * self.stream_iterators[0].next().unwrap(); - self.y_table[b_prime_index] += - lag_polys[x_index] * self.stream_iterators[1].next().unwrap(); - } - } - for (b_prime_index, _) in Hypercube::::new(t) { - for (b_prime_prime_index, _) in Hypercube::::new(t) { - self.j_prime_table[b_prime_index][b_prime_prime_index] += - self.x_table[b_prime_index] * self.y_table[b_prime_prime_index]; - } - } - } - } else if p && is_largest { - // switch to the memory intensive sumcheck on the last round computation - let num_variables_new = self.num_variables - j + 1; - self.switched_to_vsbw = true; - - // reset the streams - self.stream_iterators - .iter_mut() - .for_each(|stream_it| stream_it.reset()); - - // initialize the evaluations for the memory-intensive implementation for the final rounds of the protocol - let mut evaluations_p = vec![F::ZERO; 1 << num_variables_new]; - let mut evaluations_q = vec![F::ZERO; 1 << num_variables_new]; - - for (b_prime_index, _) in Hypercube::::new(num_variables_new) { - let mut sequential_lag_poly: LagrangePolynomial = - LagrangePolynomial::new(&self.verifier_messages); - for (_, _) in Hypercube::::new(j - 1) { - let lag_poly = sequential_lag_poly.next().unwrap(); - evaluations_p[b_prime_index] += - lag_poly * self.stream_iterators[0].next().unwrap(); - evaluations_q[b_prime_index] += - lag_poly * self.stream_iterators[1].next().unwrap(); - } - } - self.vsbw_prover.evaluations[0] = Some(evaluations_p); - self.vsbw_prover.evaluations[1] = Some(evaluations_q); - } else if self.switched_to_vsbw { - let verifier_message = self.verifier_messages.messages[self.current_round - 1]; - self.vsbw_prover - .vsbw_reduce_evaluations(verifier_message, F::ONE - verifier_message); - } - } -} diff --git a/src/multilinear_product/provers/blendy/mod.rs b/src/multilinear_product/provers/blendy/mod.rs deleted file mode 100644 index 0c1898d1..00000000 --- a/src/multilinear_product/provers/blendy/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod config; -mod core; -mod prover; - -pub use config::BlendyProductProverConfig; -pub use core::BlendyProductProver; diff --git a/src/multilinear_product/provers/blendy/prover.rs b/src/multilinear_product/provers/blendy/prover.rs deleted file mode 100644 index 49c1a72b..00000000 --- a/src/multilinear_product/provers/blendy/prover.rs +++ /dev/null @@ -1,179 +0,0 @@ -use ark_ff::Field; -use std::collections::BTreeSet; - -use crate::{ - messages::VerifierMessages, - multilinear::ReduceMode, - multilinear_product::{BlendyProductProver, BlendyProductProverConfig, TimeProductProver}, - order_strategy::MSBOrder, - prover::Prover, - streams::{Stream, StreamIterator}, -}; - -impl> Prover for BlendyProductProver { - type ProverConfig = BlendyProductProverConfig; - type ProverMessage = Option<(F, F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - let num_variables: usize = prover_config.num_variables; - let num_stages: usize = prover_config.num_stages; - let stage_size: usize = num_variables / num_stages; - let max_rounds_phase2: usize = num_variables.div_ceil(2 * num_stages); - - let last_round_phase1: usize = 2; - let last_round_phase3: usize = num_variables - num_variables.div_ceil(num_stages); - - let state_comp_set: BTreeSet = { - let mut current_round: usize = last_round_phase1 + 1; - let mut state_comp_set: BTreeSet = BTreeSet::new(); - while current_round <= last_round_phase3 { - state_comp_set.insert(current_round); - current_round = - std::cmp::min(current_round + max_rounds_phase2, current_round * 2 - 1); // the minus one is a time-efficiency optimization - current_round = std::cmp::max(current_round, 2); - } - state_comp_set - }; - assert!(!state_comp_set.is_empty()); - - let last_round: usize = *state_comp_set.iter().max().unwrap(); - let vsbw_prover = TimeProductProver:: { - current_round: 0, - evaluations: vec![None; 2], - streams: None, - num_variables: num_variables - last_round + 1, - inverse_four: F::from(4_u32).inverse().unwrap(), - reduce_mode: ReduceMode::Variablewise, - }; - - let stream_iterators = prover_config - .streams - .iter() - .cloned() - .map(|s| StreamIterator::::new(s)) - .collect(); - - // return the BlendyProver instance - Self { - current_round: 0, - streams: prover_config.streams, - stream_iterators, - num_stages, - num_variables, - last_round_phase1, - verifier_messages: VerifierMessages::new(&vec![]), - verifier_messages_round_comp: VerifierMessages::new(&vec![]), - x_table: vec![], - y_table: vec![], - j_prime_table: vec![], - stage_size, - inverse_four: F::from(4_u32).inverse().unwrap(), - prev_table_round_num: 0, - prev_table_size: 0, - state_comp_set, - switched_to_vsbw: false, - vsbw_prover, - } - } - - fn next_message(&mut self, verifier_message: Self::VerifierMessage) -> Self::ProverMessage { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - if !self.is_initial_round() { - // this holds everything - self.verifier_messages - .receive_message(verifier_message.unwrap()); - // this holds the randomness for between state computation r2 - self.verifier_messages_round_comp - .receive_message(verifier_message.unwrap()); - } - - self.init_round_vars(); - - self.compute_state(); - - let sums: (F, F, F) = self.compute_round(); - - // Increment the round counter - self.current_round += 1; - if self.switched_to_vsbw { - self.vsbw_prover.current_round += 1; - } - // Return the computed polynomial sums - Some(sums) - } -} - -#[cfg(test)] -mod tests { - use ark_poly::multivariate::{SparsePolynomial, SparseTerm}; - - use crate::{ - multilinear_product::{BlendyProductProver, BlendyProductProverConfig}, - order_strategy::MSBOrder, - prover::{ProductProverConfig, Prover}, - streams::{multivariate_product_claim, MemoryStream, Stream}, - tests::{ - multilinear_product::{BasicProductProver, BasicProductProverConfig}, - polynomials::Polynomial, - BenchStream, F64, - }, - ProductSumcheck, - }; - - // the stream has to be in SigBit order for this to work - // #[test] - // fn parity_with_basic_prover() { - // consistency_test::, BlendyProductProver>>(); - // } - - #[test] - fn consistency_test_with_next_iterator() { - // get evals in lexicographic order - let num_variables = 8; - let s_tmp: BenchStream = BenchStream::::new(num_variables); - let mut evals: Vec = Vec::with_capacity(1 << num_variables); - for i in 0..(1 << num_variables) { - evals.push(s_tmp.evaluation(i)); - } - - // create the stream in SigBit order - let s: MemoryStream = MemoryStream::new_from_lex::(evals.clone()); - let claim: F64 = multivariate_product_claim(vec![s.clone(), s.clone()]); - - // get transcript from Blendy prover - let prover_transcript: ProductSumcheck = ProductSumcheck::::prove::< - MemoryStream, - BlendyProductProver>, - >( - &mut Prover::::new(BlendyProductProverConfig::default( - num_variables, - vec![s.clone(), s.clone()], - )), - &mut ark_std::test_rng(), - ); - - // get transcript from SanityProver - let p: SparsePolynomial = - as Polynomial>::from_hypercube_evaluations( - s.evaluations.clone(), - ); - let mut sanity_prover = BasicProductProver::::new(BasicProductProverConfig::new( - claim, - num_variables, - p.clone(), - p, - )); - let sanity_prover_transcript = ProductSumcheck::::prove::< - MemoryStream, - BasicProductProver, - >(&mut sanity_prover, &mut ark_std::test_rng()); - - // ensure the transcript is identical - assert_eq!(prover_transcript, sanity_prover_transcript); - } -} diff --git a/src/multilinear_product/provers/mod.rs b/src/multilinear_product/provers/mod.rs deleted file mode 100644 index 7e2173bd..00000000 --- a/src/multilinear_product/provers/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod blendy; -pub mod space; -pub mod time; diff --git a/src/multilinear_product/provers/space/config.rs b/src/multilinear_product/provers/space/config.rs deleted file mode 100644 index 89e1dac8..00000000 --- a/src/multilinear_product/provers/space/config.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::marker::PhantomData; - -use ark_ff::Field; - -use crate::{prover::ProductProverConfig, streams::Stream}; - -pub struct SpaceProductProverConfig -where - F: Field, - S: Stream, -{ - pub num_variables: usize, - pub streams: Vec, - _f: PhantomData, -} - -impl SpaceProductProverConfig -where - F: Field, - S: Stream, -{ - pub fn new(num_variables: usize, streams: Vec) -> Self { - Self { - num_variables, - streams, - _f: PhantomData::, - } - } -} - -impl> ProductProverConfig for SpaceProductProverConfig { - fn default(num_variables: usize, streams: Vec) -> Self { - Self { - num_variables, - streams, - _f: PhantomData::, - } - } -} diff --git a/src/multilinear_product/provers/space/core.rs b/src/multilinear_product/provers/space/core.rs deleted file mode 100644 index 5a0ae45d..00000000 --- a/src/multilinear_product/provers/space/core.rs +++ /dev/null @@ -1,71 +0,0 @@ -use ark_ff::Field; - -use crate::{ - hypercube::Hypercube, - interpolation::LagrangePolynomial, - messages::VerifierMessages, - order_strategy::MSBOrder, - streams::{Stream, StreamIterator}, -}; - -pub struct SpaceProductProver> { - pub current_round: usize, - pub stream_iterators: Vec>, - pub num_variables: usize, - pub verifier_messages: VerifierMessages, - pub inverse_four: F, -} - -impl> SpaceProductProver { - pub fn cty_evaluate(&mut self) -> (F, F, F) { - let mut sum_0: F = F::ZERO; - let mut sum_1: F = F::ZERO; - let mut sum_half: F = F::ZERO; - - // reset the streams - self.stream_iterators - .iter_mut() - .for_each(|stream_it| stream_it.reset()); - - for (_, _) in Hypercube::::new(self.num_variables - self.current_round - 1) { - // can avoid unnecessary additions for first round since there is no lag poly: gives a small speedup - if self.current_round == 0 { - let p0 = self.stream_iterators[0].next().unwrap(); - let p1 = self.stream_iterators[0].next().unwrap(); - let q0 = self.stream_iterators[1].next().unwrap(); - let q1 = self.stream_iterators[1].next().unwrap(); - sum_0 += p0 * q0; - sum_1 += p1 * q1; - sum_half += (p0 + p1) * (q0 + q1); - } else { - let mut partial_sum_p_0 = F::ZERO; - let mut partial_sum_p_1 = F::ZERO; - let mut partial_sum_q_0 = F::ZERO; - let mut partial_sum_q_1 = F::ZERO; - - let mut sequential_lag_poly: LagrangePolynomial = - LagrangePolynomial::new(&self.verifier_messages); - for (_, _) in Hypercube::::new(self.current_round) { - let lag_poly = sequential_lag_poly.next().unwrap(); - partial_sum_p_0 += self.stream_iterators[0].next().unwrap() * lag_poly; - partial_sum_q_0 += self.stream_iterators[1].next().unwrap() * lag_poly; - } - - let mut sequential_lag_poly: LagrangePolynomial = - LagrangePolynomial::new(&self.verifier_messages); - for (_, _) in Hypercube::::new(self.current_round) { - let lag_poly = sequential_lag_poly.next().unwrap(); - partial_sum_p_1 += self.stream_iterators[0].next().unwrap() * lag_poly; - partial_sum_q_1 += self.stream_iterators[1].next().unwrap() * lag_poly; - } - - sum_0 += partial_sum_p_0 * partial_sum_q_0; - sum_1 += partial_sum_p_1 * partial_sum_q_1; - sum_half += - (partial_sum_p_0 + partial_sum_p_1) * (partial_sum_q_0 + partial_sum_q_1); - } - } - sum_half *= self.inverse_four; - (sum_0, sum_1, sum_half) - } -} diff --git a/src/multilinear_product/provers/space/mod.rs b/src/multilinear_product/provers/space/mod.rs deleted file mode 100644 index 5ad9937c..00000000 --- a/src/multilinear_product/provers/space/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod config; -mod core; -mod prover; - -pub use config::SpaceProductProverConfig; -pub use core::SpaceProductProver; diff --git a/src/multilinear_product/provers/space/prover.rs b/src/multilinear_product/provers/space/prover.rs deleted file mode 100644 index 45da2194..00000000 --- a/src/multilinear_product/provers/space/prover.rs +++ /dev/null @@ -1,67 +0,0 @@ -use ark_ff::Field; - -use crate::{ - messages::VerifierMessages, - multilinear_product::{SpaceProductProver, SpaceProductProverConfig}, - order_strategy::MSBOrder, - prover::Prover, - streams::{Stream, StreamIterator}, -}; - -impl> Prover for SpaceProductProver { - type ProverConfig = SpaceProductProverConfig; - type ProverMessage = Option<(F, F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - let stream_iterators = prover_config - .streams - .iter() - .cloned() - .map(|s| StreamIterator::::new(s)) - .collect(); - - Self { - stream_iterators, - verifier_messages: VerifierMessages::new(&vec![]), - current_round: 0, - num_variables: prover_config.num_variables, - inverse_four: F::from(4_u32).inverse().unwrap(), - } - } - - fn next_message(&mut self, verifier_message: Self::VerifierMessage) -> Self::ProverMessage { - // Ensure the current round is within bounds - if self.current_round >= self.num_variables { - return None; - } - - // If it's not the first round, add the verifier message to verifier_messages - if self.current_round != 0 { - self.verifier_messages - .receive_message(verifier_message.unwrap()); - } - - // evaluate using cty - let sums: (F, F, F) = self.cty_evaluate(); - - // don't forget to increment the round - self.current_round += 1; - - Some(sums) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - multilinear_product::SpaceProductProver, - streams::MemoryStream, - tests::{multilinear_product::sanity_test, F19}, - }; - - #[test] - fn sumcheck() { - sanity_test::, SpaceProductProver>>(); - } -} diff --git a/src/multilinear_product/provers/time/config.rs b/src/multilinear_product/provers/time/config.rs deleted file mode 100644 index f181d2b5..00000000 --- a/src/multilinear_product/provers/time/config.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::marker::PhantomData; - -use ark_ff::Field; - -use crate::{multilinear::ReduceMode, prover::ProductProverConfig, streams::Stream}; - -pub struct TimeProductProverConfig -where - F: Field, - S: Stream, -{ - pub num_variables: usize, - pub streams: Vec, - pub reduce_mode: ReduceMode, - _f: PhantomData, -} - -impl TimeProductProverConfig -where - F: Field, - S: Stream, -{ - pub fn new(num_variables: usize, streams: Vec, reduce_mode: ReduceMode) -> Self { - Self { - num_variables, - streams, - reduce_mode, - _f: PhantomData::, - } - } -} - -impl> ProductProverConfig for TimeProductProverConfig { - fn default(num_variables: usize, streams: Vec) -> Self { - Self { - num_variables, - streams, - reduce_mode: ReduceMode::Variablewise, - _f: PhantomData::, - } - } -} diff --git a/src/multilinear_product/provers/time/core.rs b/src/multilinear_product/provers/time/core.rs deleted file mode 100644 index 549f6d7c..00000000 --- a/src/multilinear_product/provers/time/core.rs +++ /dev/null @@ -1,100 +0,0 @@ -use ark_ff::Field; -use ark_std::vec::Vec; - -use crate::{ - multilinear::{ - reductions::{pairwise, variablewise}, - ReduceMode, - }, - multilinear_product::provers::time::reductions::{ - pairwise::{pairwise_product_evaluate, pairwise_product_evaluate_from_stream}, - variablewise::{variablewise_product_evaluate, variablewise_product_evaluate_from_stream}, - }, - streams::Stream, -}; - -pub struct TimeProductProver> { - pub current_round: usize, - pub evaluations: Vec>>, - pub streams: Option>, - pub num_variables: usize, - pub inverse_four: F, - pub reduce_mode: ReduceMode, -} - -impl> TimeProductProver { - pub fn total_rounds(&self) -> usize { - self.num_variables - } - pub fn num_free_variables(&self) -> usize { - self.num_variables - self.current_round - } - /* - * Note in evaluate() there's an optimization for the first round where we read directly - * from the streams (instead of the tables), which reduces max memory usage by 1/2 - */ - pub fn vsbw_evaluate(&self) -> (F, F, F) { - match &self.evaluations[0] { - None => match self.reduce_mode { - ReduceMode::Variablewise => variablewise_product_evaluate_from_stream( - &self.streams.clone().unwrap(), - self.inverse_four, - ), - ReduceMode::Pairwise => { - pairwise_product_evaluate_from_stream(&self.streams.clone().unwrap()) - } - }, - Some(_evals) => { - let evals: Vec> = self - .evaluations - .iter() - .filter_map(|opt| opt.clone()) // keep only Some(&Vec) - .collect(); - let evals_slice: &[Vec] = &evals; - match self.reduce_mode { - ReduceMode::Variablewise => { - variablewise_product_evaluate(evals_slice, self.inverse_four) - } - ReduceMode::Pairwise => pairwise_product_evaluate(evals_slice), - } - } - } - } - pub fn vsbw_reduce_evaluations(&mut self, verifier_message: F, verifier_message_hat: F) { - match &self.evaluations[0] { - None => { - let len = self.streams.clone().unwrap().len(); - for i in 0..len { - self.evaluations[i] = Some(vec![]); - match self.reduce_mode { - ReduceMode::Variablewise => variablewise::reduce_evaluations_from_stream( - &self.streams.as_mut().unwrap()[i], - self.evaluations[i].as_mut().unwrap(), - verifier_message, - verifier_message_hat, - ), - ReduceMode::Pairwise => pairwise::reduce_evaluations_from_stream( - &self.streams.as_mut().unwrap()[i], - self.evaluations[i].as_mut().unwrap(), - verifier_message, - ), - } - } - } - Some(_a) => { - for table in &mut self.evaluations { - match self.reduce_mode { - ReduceMode::Variablewise => variablewise::reduce_evaluations( - table.as_mut().unwrap(), - verifier_message, - verifier_message_hat, - ), - ReduceMode::Pairwise => { - pairwise::reduce_evaluations(table.as_mut().unwrap(), verifier_message) - } - } - } - } - } - } -} diff --git a/src/multilinear_product/provers/time/mod.rs b/src/multilinear_product/provers/time/mod.rs deleted file mode 100644 index 590a8d21..00000000 --- a/src/multilinear_product/provers/time/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod config; -mod core; -mod prover; -pub mod reductions; - -pub use config::TimeProductProverConfig; -pub use core::TimeProductProver; diff --git a/src/multilinear_product/provers/time/prover.rs b/src/multilinear_product/provers/time/prover.rs deleted file mode 100644 index f5ad2a42..00000000 --- a/src/multilinear_product/provers/time/prover.rs +++ /dev/null @@ -1,63 +0,0 @@ -use ark_ff::Field; - -use crate::{ - multilinear_product::{TimeProductProver, TimeProductProverConfig}, - prover::Prover, - streams::Stream, -}; - -impl> Prover for TimeProductProver { - type ProverConfig = TimeProductProverConfig; - type ProverMessage = Option<(F, F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - let num_variables = prover_config.num_variables; - Self { - current_round: 0, - evaluations: vec![None; prover_config.streams.len()], - streams: Some(prover_config.streams), - num_variables, - inverse_four: F::from(4_u32).inverse().unwrap(), - reduce_mode: prover_config.reduce_mode, - } - } - - fn next_message(&mut self, verifier_message: Option) -> Option<(F, F, F)> { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - // If it's not the first round, reduce the evaluations table - if self.current_round != 0 { - // update the evaluations table by absorbing leftmost variable assigned to verifier_message - self.vsbw_reduce_evaluations( - verifier_message.unwrap(), - F::ONE - verifier_message.unwrap(), - ); - } - - // evaluate using vsbw - let sums = self.vsbw_evaluate(); - - // Increment the round counter - self.current_round += 1; - - // Return the computed polynomial - Some(sums) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - multilinear_product::TimeProductProver, - tests::{multilinear_product::consistency_test, BenchStream, F64}, - }; - - #[test] - fn parity_with_basic_prover() { - consistency_test::, TimeProductProver>>(); - } -} diff --git a/src/multilinear_product/provers/time/reductions/mod.rs b/src/multilinear_product/provers/time/reductions/mod.rs deleted file mode 100644 index 3e0fccfb..00000000 --- a/src/multilinear_product/provers/time/reductions/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod pairwise; -pub mod variablewise; diff --git a/src/multilinear_product/provers/time/reductions/pairwise.rs b/src/multilinear_product/provers/time/reductions/pairwise.rs deleted file mode 100644 index fc482a80..00000000 --- a/src/multilinear_product/provers/time/reductions/pairwise.rs +++ /dev/null @@ -1,74 +0,0 @@ -use ark_ff::Field; -use ark_std::cfg_into_iter; - -#[cfg(feature = "parallel")] -use rayon::iter::{IntoParallelIterator, ParallelIterator}; - -use crate::streams::Stream; - -pub fn pairwise_product_evaluate(src: &[Vec]) -> (F, F, F) { - let half_len = src[0].len() / 2; - let sum00: F = cfg_into_iter!(0..half_len) - .map(|k| { - let i = 2 * k; - let p0 = src[0][i]; - let q0 = src[1][i]; - p0 * q0 - }) - .sum(); - - let sum11: F = cfg_into_iter!(0..half_len) - .map(|k| { - let i = 2 * k; - let p1 = src[0][i + 1]; - let q1 = src[1][i + 1]; - p1 * q1 - }) - .sum(); - - let sum0110: F = cfg_into_iter!(0..half_len) - .map(|k| { - let i = 2 * k; - let p0 = src[0][i]; - let p1 = src[0][i + 1]; - let q0 = src[1][i]; - let q1 = src[1][i + 1]; - p0 * q1 + p1 * q0 - }) - .sum(); - (sum00, sum11, sum0110) -} - -pub fn pairwise_product_evaluate_from_stream>(src: &[S]) -> (F, F, F) { - let len = 1usize << src[0].num_variables(); - let half_len = len / 2; - let sum00: F = cfg_into_iter!(0..half_len) - .map(|k| { - let i = 2 * k; - let p0 = src[0].evaluation(i); - let q0 = src[1].evaluation(i); - p0 * q0 - }) - .sum(); - - let sum11: F = cfg_into_iter!(0..half_len) - .map(|k| { - let i = 2 * k; - let p1 = src[0].evaluation(i + 1); - let q1 = src[1].evaluation(i + 1); - p1 * q1 - }) - .sum(); - - let sum0110: F = cfg_into_iter!(0..half_len) - .map(|k| { - let i = 2 * k; - let p0 = src[0].evaluation(i); - let p1 = src[0].evaluation(i + 1); - let q0 = src[1].evaluation(i); - let q1 = src[1].evaluation(i + 1); - p0 * q1 + p1 * q0 - }) - .sum(); - (sum00, sum11, sum0110) -} diff --git a/src/multilinear_product/provers/time/reductions/variablewise.rs b/src/multilinear_product/provers/time/reductions/variablewise.rs deleted file mode 100644 index 17a4bef6..00000000 --- a/src/multilinear_product/provers/time/reductions/variablewise.rs +++ /dev/null @@ -1,104 +0,0 @@ -use ark_ff::Field; -use ark_std::cfg_into_iter; - -#[cfg(feature = "parallel")] -use rayon::iter::{IntoParallelIterator, ParallelIterator}; - -use crate::streams::Stream; - -pub fn variablewise_product_evaluate(src: &[Vec], inverse_four: F) -> (F, F, F) { - let len = src[0].len(); - let second_half_bit: usize = len / 2; - - let p_evals = &src[0]; - let q_evals = &src[1]; - - let acc00: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p0 = p_evals[i]; - let q0 = q_evals[i]; - p0 * q0 - }) - .sum(); - - let acc11: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p1 = p_evals[i | second_half_bit]; - let q1 = q_evals[i | second_half_bit]; - p1 * q1 - }) - .sum(); - - let acc01: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p0 = p_evals[i]; - let q1 = q_evals[i | second_half_bit]; - p0 * q1 - }) - .sum(); - - let acc10: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p1 = p_evals[i | second_half_bit]; - let q0 = q_evals[i]; - p1 * q0 - }) - .sum(); - - let sum_0 = acc00; - let sum_1 = acc11; - let mut sum_half = acc00 + acc11 + acc01 + acc10; - sum_half *= inverse_four; - - (sum_0, sum_1, sum_half) -} - -pub fn variablewise_product_evaluate_from_stream>( - src: &[S], - inverse_four: F, -) -> (F, F, F) { - let len = 1usize << src[0].num_variables(); - let second_half_bit: usize = len / 2; - - let p_evals = &src[0]; - let q_evals = &src[1]; - - let acc00: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p0 = p_evals.evaluation(i); - let q0 = q_evals.evaluation(i); - p0 * q0 - }) - .sum(); - - let acc11: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p1 = p_evals.evaluation(i | second_half_bit); - let q1 = q_evals.evaluation(i | second_half_bit); - p1 * q1 - }) - .sum(); - - let acc01: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p0 = p_evals.evaluation(i); - let q1 = q_evals.evaluation(i | second_half_bit); - p0 * q1 - }) - .sum(); - - let acc10: F = cfg_into_iter!(0..second_half_bit) - .map(|i| { - let p1 = p_evals.evaluation(i | second_half_bit); - let q0 = q_evals.evaluation(i); - p1 * q0 - }) - .sum(); - - let sum_0 = acc00; - let sum_1 = acc11; - let mut sum_half = acc00 + acc11 + acc01 + acc10; - sum_half *= inverse_four; - - (sum_0, sum_1, sum_half) -} diff --git a/src/multilinear_product/sumcheck.rs b/src/multilinear_product/sumcheck.rs deleted file mode 100644 index a7493c7f..00000000 --- a/src/multilinear_product/sumcheck.rs +++ /dev/null @@ -1,73 +0,0 @@ -use ark_ff::Field; -use ark_std::{rand::Rng, vec::Vec}; - -use crate::{ - interpolation::LagrangePolynomial, order_strategy::GraycodeOrder, prover::Prover, - streams::Stream, -}; - -#[derive(Debug, PartialEq)] -pub struct ProductSumcheck { - pub prover_messages: Vec<(F, F, F)>, - pub verifier_messages: Vec, -} - -impl ProductSumcheck { - pub fn prove(prover: &mut P, rng: &mut impl Rng) -> Self - where - S: Stream, - P: Prover, ProverMessage = Option<(F, F, F)>>, - { - // Initialize vectors to store prover and verifier messages - let mut prover_messages: Vec<(F, F, F)> = vec![]; - let mut verifier_messages: Vec = vec![]; - - // Run the protocol - let mut verifier_message: Option = None; - while let Some(message) = prover.next_message(verifier_message) { - let round_sum = message.0 + message.1; - let is_round_accepted = match verifier_message { - // If first round, compare to claimed_sum - None => true, // TODO (z-tech): give option to provide claim round_sum == prover.claim(), - Some(prev_verifier_message) => { - verifier_messages.push(prev_verifier_message); - let prev_prover_message = prover_messages.last().unwrap(); - round_sum - == LagrangePolynomial::::evaluate_from_three_points( - prev_verifier_message, - *prev_prover_message, - ) - } - }; - - // Handle how to proceed - prover_messages.push(message); - if !is_round_accepted { - break; - } - - verifier_message = Some(F::rand(rng)); - } - - // Return a Sumcheck struct with the collected messages and acceptance status - ProductSumcheck { - prover_messages, - verifier_messages, - } - } -} - -#[cfg(test)] -mod tests { - use crate::{ - multilinear_product::TimeProductProver, - tests::{multilinear_product::consistency_test, BenchStream, F64}, - }; - - #[test] - fn algorithm_consistency() { - consistency_test::, TimeProductProver>>(); - // should take ordering of the stream - // consistency_test::, BlendyProductProver>>(); - } -} diff --git a/src/multilinear_sumcheck.rs b/src/multilinear_sumcheck.rs index 09097a84..8134c20e 100644 --- a/src/multilinear_sumcheck.rs +++ b/src/multilinear_sumcheck.rs @@ -1,141 +1,331 @@ -//! Standard multilinear sumcheck protocol. +//! Standard multilinear sumcheck: `∑_x v(x)`. //! -//! Given evaluations `[p(0..0), p(0..1), ..., p(1..1)]` of a multilinear polynomial `p` -//! on the boolean hypercube `{0,1}^n`, the [`multilinear_sumcheck`] function executes `n` -//! rounds of the sumcheck protocol and returns the resulting [`Sumcheck`] transcript. +//! Half-split (MSB) layout with a fused fold+compute kernel. Round `i` +//! folds the top-most remaining variable — the round-0 split is +//! `v[0..L/2]` vs `v[L/2..L]`, *not* the adjacent pairs `(v[2k], v[2k+1])` +//! of a pair-split (LSB) layout. Callers whose upstream indexing assumed +//! pair-split semantics must reorder their inputs with a bit-reversal. //! -//! The function is parameterized by two field types: -//! - `BF` (base field): the field the evaluations live in -//! - `EF` (extension field): the field challenges are sampled from +//! Wire format per round: `(s0, s1)` where +//! - `s0 = q(0) = Σ v_lo` +//! - `s1 = q(1) = Σ v_hi` //! -//! When no extension field is needed, set `EF = BF`. +//! The round polynomial is degree 1: `q(X) = s0 + X·(s1 − s0)`. Consistency +//! invariant: `s0 + s1 == current_claim`. //! -//! # Example -//! -//! ```text -//! use efficient_sumcheck::{multilinear_sumcheck, Sumcheck}; -//! use efficient_sumcheck::transcript::SanityTranscript; -//! -//! // No extension field (BF = EF): -//! let mut evals = vec![F::from(1), F::from(2), F::from(3), F::from(4)]; -//! let mut transcript = SanityTranscript::new(&mut rng); -//! let result: Sumcheck = multilinear_sumcheck(&mut evals, &mut transcript); -//! ``` +//! The fused kernel rolls the round-`i` fold into the round-`(i+1)` compute: +//! 4 reads + 2 writes per quadruple (fused) vs. 6 reads + 2 writes +//! (fold + compute separately) — a ~33% memory-traffic reduction. use ark_ff::Field; +#[cfg(feature = "parallel")] +use rayon::join; +#[cfg(feature = "parallel")] +use rayon::prelude::*; -use crate::multilinear::reductions::pairwise; -use crate::transcript::Transcript; +use crate::transcript::ProverTranscript; -pub use crate::multilinear::Sumcheck; +/// Legacy return type for `multilinear_sumcheck`. +#[derive(Debug)] +pub struct Sumcheck { + pub prover_messages: Vec<(F, F)>, + pub verifier_messages: Vec, + pub final_evaluation: F, +} -/// Run the standard multilinear sumcheck protocol over an evaluation vector, -/// using a generic [`Transcript`] for Fiat-Shamir (or sanity/random challenges). -/// -/// `BF` is the base field of the evaluations, `EF` is the extension field for challenges. -/// When `BF = EF`, this is the standard single-field sumcheck. -/// When `BF ≠ EF`, round 0 evaluates in `BF` and lifts to `EF`, then subsequent -/// rounds work entirely in `EF`. -/// -/// Each round: -/// 1. Computes the round polynomial evaluations `(s(0), s(1))` via pairwise reduction. -/// 2. Writes them to the transcript (2 field elements). -/// 3. Reads the verifier's challenge from the transcript (1 field element). -/// 4. Reduces the evaluation vector by folding with the challenge. -pub fn multilinear_sumcheck>( - evaluations: &mut [BF], - transcript: &mut impl Transcript, -) -> Sumcheck { - // checks - assert!( - evaluations.len().count_ones() == 1, - "length must be a power of 2" - ); - assert!(evaluations.len() >= 2, "need at least 1 variable"); +// ─── Workload threshold ───────────────────────────────────────────────────── - let num_rounds = evaluations.len().trailing_zeros() as usize; - let mut prover_messages: Vec<(EF, EF)> = vec![]; - let mut verifier_messages: Vec = vec![]; +const fn workload_size() -> usize { + #[cfg(all(target_arch = "aarch64", target_os = "macos"))] + const CACHE_SIZE: usize = 1 << 17; + #[cfg(all( + target_arch = "aarch64", + any(target_os = "ios", target_os = "android", target_os = "linux") + ))] + const CACHE_SIZE: usize = 1 << 16; + #[cfg(target_arch = "x86_64")] + const CACHE_SIZE: usize = 1 << 15; + #[cfg(not(any( + all(target_arch = "aarch64", target_os = "macos"), + all( + target_arch = "aarch64", + any(target_os = "ios", target_os = "android", target_os = "linux") + ), + target_arch = "x86_64" + )))] + const CACHE_SIZE: usize = 1 << 15; - // ── Round 0: evaluate in BF, lift to EF, cross-field reduce ── - if num_rounds > 0 { - let msg_bf = pairwise::evaluate(evaluations); - let msg = (EF::from(msg_bf.0), EF::from(msg_bf.1)); + CACHE_SIZE / core::mem::size_of::() +} + +// ─── Scalar helpers ───────────────────────────────────────────────────────── + +fn sum_slice(v: &[F]) -> F { + #[cfg(feature = "parallel")] + if v.len() > workload_size::() { + return v.par_iter().copied().sum(); + } + v.iter().copied().sum() +} + +fn scalar_mul(v: &mut [F], w: F) { + for x in v.iter_mut() { + *x *= w; + } +} - prover_messages.push(msg); - transcript.write(msg.0); - transcript.write(msg.1); +// ─── Core algebra ─────────────────────────────────────────────────────────── - let chg = transcript.read(); - verifier_messages.push(chg); +/// `(s0, s1)` of the degree-1 round polynomial `q(X) = s0 + X·(s1 − s0)`. +/// +/// `values` is implicitly zero-extended to the next power of two. +/// - `s0 = Σ v[0..L/2]` (low half, possibly with tail contributions) +/// - `s1 = Σ v[L/2..L]` +pub fn compute_sumcheck_polynomial(values: &[F]) -> (F, F) { + fn recurse(lo: &[F], hi: &[F]) -> (F, F) { + debug_assert_eq!(lo.len(), hi.len()); - // Cross-field reduce: BF evaluations + EF challenge → Vec - let mut ef_evals = pairwise::cross_field_reduce(evaluations, chg); + #[cfg(feature = "parallel")] + if lo.len() * 2 > workload_size::() { + let mid = lo.len() / 2; + let (lol, lor) = lo.split_at(mid); + let (hil, hir) = hi.split_at(mid); + let (l, r) = join(|| recurse(lol, hil), || recurse(lor, hir)); + return (l.0 + r.0, l.1 + r.1); + } + let mut s0 = F::ZERO; + let mut s1 = F::ZERO; + for (&l, &h) in lo.iter().zip(hi) { + s0 += l; + s1 += h; + } + (s0, s1) + } - // Remaining rounds work in EF - for _ in 1..num_rounds { - let msg = pairwise::evaluate(&ef_evals); + if values.is_empty() { + return (F::ZERO, F::ZERO); + } + if values.len() == 1 { + // Implicit zero pad on the high half: (v[0], 0). + return (values[0], F::ZERO); + } - prover_messages.push(msg); - transcript.write(msg.0); - transcript.write(msg.1); + let half = values.len().next_power_of_two() >> 1; + let (lo, hi) = values.split_at(half); + debug_assert!(lo.len() >= hi.len()); + let (lo, lo_tail) = lo.split_at(hi.len()); + let (s0, s1) = recurse(lo, hi); - let chg = transcript.read(); - verifier_messages.push(chg); + // Tail (hi implicitly zero): contributes to s0 only. + let tail = sum_slice(lo_tail); + (s0 + tail, s1) +} - pairwise::reduce_evaluations(&mut ef_evals, chg); +/// In-place half-split (MSB) fold: `new[k] = v[k] + (v[k+L/2] − v[k]) · weight`. +/// +/// Implicit zero padding on the high half collapses the tail to `v[k] * (1 − w)`. +/// +/// SIMD-accelerated for Goldilocks base field on NEON and AVX-512 IFMA. +/// Falls back to a scalar recursive `rayon::join` fold for other fields. +pub fn fold(values: &mut Vec, weight: F) { + // SIMD fast path for base-field Goldilocks (MSB layout). + #[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + ))] + { + if crate::simd_sumcheck::dispatch::try_simd_reduce_msb(values, weight) { + values.shrink_to_fit(); + return; + } + } + fn recurse_both(low: &mut [F], high: &[F], weight: F) { + #[cfg(feature = "parallel")] + if low.len() > workload_size::() { + let split = low.len() / 2; + let (ll, lr) = low.split_at_mut(split); + let (hl, hr) = high.split_at(split); + join( + || recurse_both(ll, hl, weight), + || recurse_both(lr, hr, weight), + ); + return; + } + for (low, high) in low.iter_mut().zip(high) { + *low += (*high - *low) * weight; } } - Sumcheck { - verifier_messages, - prover_messages, + if values.len() <= 1 { + return; } + + let half = values.len().next_power_of_two() >> 1; + let (low, high) = values.split_at_mut(half); + debug_assert!(low.len() >= high.len()); + let (low, tail) = low.split_at_mut(high.len()); + recurse_both(low, high, weight); + + scalar_mul(tail, F::ONE - weight); + + values.truncate(half); + values.shrink_to_fit(); } -#[cfg(test)] -mod tests { - use super::*; - use ark_ff::UniformRand; - use ark_std::test_rng; +/// Two-pass fold-then-compute. Reference only. +pub fn fold_and_compute_polynomial(values: &mut Vec, weight: F) -> (F, F) { + fold(values, weight); + compute_sumcheck_polynomial(values) +} - use crate::tests::F64; +/// Fused fold + compute: folds `values` by `weight` *and* returns the +/// next-round `(s0, s1)` in one sweep over the quadruple +/// `(v[k], v[k+L/4], v[k+L/2], v[k+3L/4])`. +pub fn fused_fold_and_compute_polynomial(values: &mut Vec, weight: F) -> (F, F) { + let l = values.len(); + if !l.is_power_of_two() || l < 4 { + return fold_and_compute_polynomial(values, weight); + } - const NUM_VARS: usize = 4; // vectors of length 2^4 = 16 + fn kernel(v0: &mut [F], v1: &mut [F], v2: &[F], v3: &[F], weight: F) -> (F, F) { + debug_assert_eq!(v0.len(), v1.len()); + debug_assert_eq!(v0.len(), v2.len()); + debug_assert_eq!(v0.len(), v3.len()); - #[test] - fn test_multilinear_sumcheck_sanity() { - use crate::transcript::SanityTranscript; + #[cfg(feature = "parallel")] + if v0.len() * 2 > workload_size::() { + let mid = v0.len() / 2; + let (v0l, v0r) = v0.split_at_mut(mid); + let (v1l, v1r) = v1.split_at_mut(mid); + let (v2l, v2r) = v2.split_at(mid); + let (v3l, v3r) = v3.split_at(mid); + let (left, right) = join( + || kernel(v0l, v1l, v2l, v3l, weight), + || kernel(v0r, v1r, v2r, v3r, weight), + ); + return (left.0 + right.0, left.1 + right.1); + } - let mut rng = test_rng(); + let mut s0 = F::ZERO; + let mut s1 = F::ZERO; + for i in 0..v0.len() { + let x0 = v0[i]; + let x1 = v1[i]; + let x2 = v2[i]; + let x3 = v3[i]; - let n = 1 << NUM_VARS; - let mut evaluations: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let n_lo = x0 + (x2 - x0) * weight; + let n_hi = x1 + (x3 - x1) * weight; - let mut transcript = SanityTranscript::new(&mut rng); - let result = multilinear_sumcheck::(&mut evaluations, &mut transcript); + v0[i] = n_lo; + v1[i] = n_hi; - assert_eq!(result.prover_messages.len(), NUM_VARS); - assert_eq!(result.verifier_messages.len(), NUM_VARS); + s0 += n_lo; + s1 += n_hi; + } + (s0, s1) } - #[test] - fn test_multilinear_sumcheck_spongefish() { - use crate::transcript::SpongefishTranscript; + let quarter = l / 4; + let half = l / 2; + + let (first, second) = values.split_at_mut(half); + let (v0, v1) = first.split_at_mut(quarter); + let (v2, v3) = second.split_at(quarter); - let mut rng = test_rng(); + let result = kernel(v0, v1, v2, v3, weight); - let n = 1 << NUM_VARS; - let mut evaluations: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + values.truncate(half); + result +} + +// ─── Prover ───────────────────────────────────────────────────────────────── + +/// Runs `num_rounds` rounds on `values`, folding it in place. +/// +/// Transcript per round: writes `s0` then `s1`, invokes +/// `hook(round, transcript)`, then reads the verifier challenge. +/// +/// On return, if `num_rounds == log2(next_pow2(len))` then `values.len() == 1` +/// and `final_evaluation = values[0]`; otherwise `F::ZERO`. +pub fn multilinear_sumcheck_partial( + values: &mut Vec, + transcript: &mut T, + num_rounds: usize, + mut hook: H, +) -> Sumcheck +where + F: Field, + T: ProverTranscript, + H: FnMut(usize, &mut T), +{ + assert!( + num_rounds == 0 || values.len().next_power_of_two() >= 1 << num_rounds, + "num_rounds ({num_rounds}) exceeds log2 of next-pow2 of len ({})", + values.len(), + ); - let domsep = spongefish::domain_separator!("test-multilinear-sumcheck"; module_path!()) - .instance(b"test"); + let mut prover_messages: Vec<(F, F)> = Vec::with_capacity(num_rounds); + let mut verifier_messages: Vec = Vec::with_capacity(num_rounds); + let mut folding_randomness: Option = None; - let prover_state = domsep.std_prover(); - let mut transcript = SpongefishTranscript::new(prover_state); - let result = multilinear_sumcheck::(&mut evaluations, &mut transcript); + for round in 0..num_rounds { + let (s0, s1) = if let Some(w) = folding_randomness { + fused_fold_and_compute_polynomial(values, w) + } else { + compute_sumcheck_polynomial(values) + }; - assert_eq!(result.prover_messages.len(), NUM_VARS); - assert_eq!(result.verifier_messages.len(), NUM_VARS); + prover_messages.push((s0, s1)); + transcript.send(s0); + transcript.send(s1); + + hook(round, transcript); + + let r = transcript.challenge(); + verifier_messages.push(r); + folding_randomness = Some(r); + } + + if let Some(w) = folding_randomness { + fold(values, w); + } + + let final_evaluation = if values.len() == 1 { + values[0] + } else { + F::ZERO + }; + + Sumcheck { + prover_messages, + verifier_messages, + final_evaluation, } } + +/// Full sumcheck (`log2(next_pow2(len))` rounds) with a per-round hook. +pub fn multilinear_sumcheck( + values: &mut Vec, + transcript: &mut T, + hook: H, +) -> Sumcheck +where + F: Field, + T: ProverTranscript, + H: FnMut(usize, &mut T), +{ + let num_rounds = if values.is_empty() { + 0 + } else { + values.len().next_power_of_two().trailing_zeros() as usize + }; + multilinear_sumcheck_partial(values, transcript, num_rounds, hook) +} + +// ─── Verifier ─────────────────────────────────────────────────────────────── + +// Tests live in `tests/multilinear_sumcheck.rs` (integration target). diff --git a/src/order_strategy/ascending/ascending.rs b/src/order_strategy/ascending/ascending.rs deleted file mode 100644 index 3dfea44f..00000000 --- a/src/order_strategy/ascending/ascending.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::{hypercube::Hypercube, order_strategy::OrderStrategy}; - -pub struct AscendingOrder { - current_index: usize, - stop_value: usize, // exclusive - num_vars: usize, -} - -impl OrderStrategy for AscendingOrder { - fn new(num_vars: usize) -> Self { - Self { - current_index: 0, - stop_value: Hypercube::::stop_value(num_vars), // exclusive - num_vars, - } - } - - fn next_index(&mut self) -> Option { - if self.current_index < self.stop_value { - let this_index = Some(self.current_index); - self.current_index += 1; - this_index - } else { - None - } - } - - fn num_vars(&self) -> usize { - self.num_vars - } -} - -impl Iterator for AscendingOrder { - type Item = usize; - - fn next(&mut self) -> Option { - self.next_index() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sanity() { - let order_0 = AscendingOrder::new(0); - let indices_0: Vec = order_0.collect(); - assert_eq!(indices_0, vec![0]); - - let order_1 = AscendingOrder::new(1); - let indices_1: Vec = order_1.collect(); - assert_eq!(indices_1, vec![0, 1]); - - let order_2 = AscendingOrder::new(2); - let indices_2: Vec = order_2.collect(); - assert_eq!(indices_2, vec![0, 1, 2, 3]); - - let order_3 = AscendingOrder::new(3); - let indices_3: Vec = order_3.collect(); - assert_eq!(indices_3, vec![0, 1, 2, 3, 4, 5, 6, 7]); - } -} diff --git a/src/order_strategy/ascending/mod.rs b/src/order_strategy/ascending/mod.rs deleted file mode 100644 index 9ac41eb3..00000000 --- a/src/order_strategy/ascending/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[allow(clippy::module_inception)] -mod ascending; - -pub use ascending::AscendingOrder; diff --git a/src/order_strategy/core.rs b/src/order_strategy/core.rs deleted file mode 100644 index 981efd44..00000000 --- a/src/order_strategy/core.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub trait OrderStrategy: Iterator { - fn new(num_variables: usize) -> Self; - fn next_index(&mut self) -> Option; - fn num_vars(&self) -> usize; -} diff --git a/src/order_strategy/descending/descending.rs b/src/order_strategy/descending/descending.rs deleted file mode 100644 index ad2b6b94..00000000 --- a/src/order_strategy/descending/descending.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::{hypercube::Hypercube, order_strategy::OrderStrategy}; - -#[derive(PartialEq, Debug)] -pub struct DescendingOrder { - current_index: isize, - stop_value: isize, - num_vars: usize, -} - -impl OrderStrategy for DescendingOrder { - fn new(num_vars: usize) -> Self { - Self { - current_index: Hypercube::::stop_value(num_vars) as isize - 1, - stop_value: 0, - num_vars, - } - } - - fn next_index(&mut self) -> Option { - if self.current_index < self.stop_value { - return None; - } - - let this_index = self.current_index as usize; - self.current_index -= 1; - Some(this_index) - } - - fn num_vars(&self) -> usize { - self.num_vars - } -} - -impl Iterator for DescendingOrder { - type Item = usize; - - fn next(&mut self) -> Option { - self.next_index() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sanity() { - let order_0 = DescendingOrder::new(0); - let indices_0: Vec = order_0.collect(); - assert_eq!(indices_0, vec![0]); - - let order_1 = DescendingOrder::new(1); - let indices_1: Vec = order_1.collect(); - assert_eq!(indices_1, vec![1, 0]); - - let order_2 = DescendingOrder::new(2); - let indices_2: Vec = order_2.collect(); - assert_eq!(indices_2, vec![3, 2, 1, 0]); - - let order_3 = DescendingOrder::new(3); - let indices_3: Vec = order_3.collect(); - assert_eq!(indices_3, vec![7, 6, 5, 4, 3, 2, 1, 0]); - } -} diff --git a/src/order_strategy/descending/mod.rs b/src/order_strategy/descending/mod.rs deleted file mode 100644 index 53b0023e..00000000 --- a/src/order_strategy/descending/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[allow(clippy::module_inception)] -mod descending; - -pub use descending::DescendingOrder; diff --git a/src/order_strategy/graycode/graycode.rs b/src/order_strategy/graycode/graycode.rs deleted file mode 100644 index 6f82a253..00000000 --- a/src/order_strategy/graycode/graycode.rs +++ /dev/null @@ -1,74 +0,0 @@ -use crate::{hypercube::Hypercube, order_strategy::OrderStrategy}; - -pub struct GraycodeOrder { - current_index: usize, - stop_value: usize, // exclusive - num_vars: usize, -} - -impl GraycodeOrder { - pub fn next_gray_code(value: usize) -> usize { - let mask = match value.count_ones() & 1 == 0 { - true => 1, - false => 1 << (value.trailing_zeros() + 1), - }; - value ^ mask - } -} - -impl OrderStrategy for GraycodeOrder { - fn new(num_vars: usize) -> Self { - Self { - current_index: 0, - stop_value: Hypercube::::stop_value(num_vars), // exclusive - num_vars, - } - } - - fn next_index(&mut self) -> Option { - if self.current_index < self.stop_value { - let this_index = Some(self.current_index); - self.current_index = GraycodeOrder::next_gray_code(self.current_index); - this_index - } else { - None - } - } - - fn num_vars(&self) -> usize { - self.num_vars - } -} - -impl Iterator for GraycodeOrder { - type Item = usize; - - fn next(&mut self) -> Option { - self.next_index() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sanity() { - // https://docs.rs/gray-codes/latest/gray_codes/struct.GrayCode8.html#examples - let order_0 = GraycodeOrder::new(0); - let indices_0: Vec = order_0.collect(); - assert_eq!(indices_0, vec![0]); - - let order_1 = GraycodeOrder::new(1); - let indices_1: Vec = order_1.collect(); - assert_eq!(indices_1, vec![0, 1]); - - let order_2 = GraycodeOrder::new(2); - let indices_2: Vec = order_2.collect(); - assert_eq!(indices_2, vec![0, 1, 3, 2]); - - let order_3 = GraycodeOrder::new(3); - let indices_3: Vec = order_3.collect(); - assert_eq!(indices_3, vec![0, 1, 3, 2, 6, 7, 5, 4]); - } -} diff --git a/src/order_strategy/graycode/mod.rs b/src/order_strategy/graycode/mod.rs deleted file mode 100644 index 92f5ded0..00000000 --- a/src/order_strategy/graycode/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -#[allow(clippy::module_inception)] -mod graycode; -pub use graycode::GraycodeOrder; diff --git a/src/order_strategy/mod.rs b/src/order_strategy/mod.rs deleted file mode 100644 index 8ff1dd52..00000000 --- a/src/order_strategy/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod ascending; -mod core; -mod descending; -mod graycode; -mod msb; - -pub use ascending::AscendingOrder; -pub use core::OrderStrategy; -pub use descending::DescendingOrder; -pub use graycode::GraycodeOrder; -pub use msb::MSBOrder; diff --git a/src/order_strategy/msb/mod.rs b/src/order_strategy/msb/mod.rs deleted file mode 100644 index 674a8e7d..00000000 --- a/src/order_strategy/msb/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[allow(clippy::module_inception)] -mod msb; - -pub use msb::MSBOrder; diff --git a/src/order_strategy/msb/msb.rs b/src/order_strategy/msb/msb.rs deleted file mode 100644 index 9ea14853..00000000 --- a/src/order_strategy/msb/msb.rs +++ /dev/null @@ -1,83 +0,0 @@ -use crate::{hypercube::Hypercube, order_strategy::OrderStrategy}; - -pub struct MSBOrder { - current_index: usize, - stop_value: usize, // exclusive - num_vars: usize, -} - -// we're using the usize like a vec, so we can't just reverse the whole thing .reverse_bits() -impl MSBOrder { - pub fn next_value_in_msb_order(x: usize, n: u32) -> usize { - let mut result = x; - for i in (0..n).rev() { - result ^= 1 << i; - if result >> i == 1 { - break; - } - } - result - } -} - -impl OrderStrategy for MSBOrder { - fn new(num_vars: usize) -> Self { - Self { - current_index: 0, - stop_value: Hypercube::::stop_value(num_vars), // exclusive - num_vars, - } - } - - fn next_index(&mut self) -> Option { - if self.current_index < self.stop_value { - let old_index = self.current_index; - self.current_index = - MSBOrder::next_value_in_msb_order(self.current_index, self.num_vars as u32); - if self.current_index == 0 { - // if the sequence rounds back to 0, we need to stop - self.current_index = self.stop_value; - } - Some(old_index) - } else { - None - } - } - - fn num_vars(&self) -> usize { - self.num_vars - } -} - -impl Iterator for MSBOrder { - type Item = usize; - - fn next(&mut self) -> Option { - self.next_index() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sanity() { - // https://docs.rs/gray-codes/latest/gray_codes/struct.GrayCode8.html#examples - let order_0 = MSBOrder::new(0); - let indices_0: Vec = order_0.collect(); - assert_eq!(indices_0, vec![0]); - - let order_1 = MSBOrder::new(1); - let indices_1: Vec = order_1.collect(); - assert_eq!(indices_1, vec![0, 1]); - - let order_2 = MSBOrder::new(2); - let indices_2: Vec = order_2.collect(); - assert_eq!(indices_2, vec![0, 2, 1, 3]); - - let order_3 = MSBOrder::new(3); - let indices_3: Vec = order_3.collect(); - assert_eq!(indices_3, vec![0, 4, 2, 6, 1, 5, 3, 7]); - } -} diff --git a/src/poly_ops.rs b/src/poly_ops.rs new file mode 100644 index 00000000..1d359dc6 --- /dev/null +++ b/src/poly_ops.rs @@ -0,0 +1,257 @@ +//! Zero-allocation polynomial arithmetic on coefficient slices. +//! +//! All functions operate on `&[F]` or `&mut [F]` in ascending degree order +//! (same layout as `DensePolynomial::coeffs`). The caller owns the memory — +//! stack arrays, pre-allocated buffers, or flat fold buffers all work. +//! +//! Designed to eventually upstream into `ark-poly::DensePolynomial` as +//! in-place methods. + +use ark_ff::Field; +use ark_poly::univariate::DensePolynomial; + +/// Schoolbook polynomial multiplication: `out = a * b`. +/// +/// `out` must have length ≥ `a.len() + b.len() - 1`. +/// Zeroes `out` before writing. +/// +/// # Panics +/// +/// Panics if `out` is too short, or if either input is empty. +#[inline] +pub fn mul_into(out: &mut [F], a: &[F], b: &[F]) { + let n = a.len() + b.len() - 1; + debug_assert!( + out.len() >= n, + "out.len()={} but need {} for deg {} × deg {}", + out.len(), + n, + a.len() - 1, + b.len() - 1 + ); + for o in out[..n].iter_mut() { + *o = F::ZERO; + } + for (i, &ai) in a.iter().enumerate() { + if ai.is_zero() { + continue; + } + for (j, &bj) in b.iter().enumerate() { + out[i + j] += ai * bj; + } + } +} + +/// Fused multiply-accumulate: `out += a * b`. +/// +/// `out` must have length ≥ `a.len() + b.len() - 1`. +/// Does NOT zero `out` — accumulates into existing values. +#[inline] +pub fn mul_add_into(out: &mut [F], a: &[F], b: &[F]) { + let n = a.len() + b.len() - 1; + debug_assert!(out.len() >= n); + for (i, &ai) in a.iter().enumerate() { + if ai.is_zero() { + continue; + } + for (j, &bj) in b.iter().enumerate() { + out[i + j] += ai * bj; + } + } +} + +/// In-place addition: `a += b`. +/// +/// `a` must have length ≥ `b.len()`. +#[inline] +pub fn add_assign(a: &mut [F], b: &[F]) { + debug_assert!(a.len() >= b.len()); + for (ai, &bi) in a.iter_mut().zip(b) { + *ai += bi; + } +} + +/// In-place subtraction: `a -= b`. +/// +/// `a` must have length ≥ `b.len()`. +#[inline] +pub fn sub_assign(a: &mut [F], b: &[F]) { + debug_assert!(a.len() >= b.len()); + for (ai, &bi) in a.iter_mut().zip(b) { + *ai -= bi; + } +} + +/// Subtraction into buffer: `out = a - b`. +/// +/// `out` must have length ≥ `max(a.len(), b.len())`. +#[inline] +pub fn sub_into(out: &mut [F], a: &[F], b: &[F]) { + let n = a.len().max(b.len()); + debug_assert!(out.len() >= n); + for i in 0..n { + let ai = if i < a.len() { a[i] } else { F::ZERO }; + let bi = if i < b.len() { b[i] } else { F::ZERO }; + out[i] = ai - bi; + } +} + +/// Fused scale-and-add: `a += s * b`. +/// +/// `a` must have length ≥ `b.len()`. +#[inline] +pub fn add_scaled(a: &mut [F], s: F, b: &[F]) { + debug_assert!(a.len() >= b.len()); + if s.is_zero() { + return; + } + if s.is_one() { + add_assign(a, b); + return; + } + for (ai, &bi) in a.iter_mut().zip(b) { + *ai += s * bi; + } +} + +/// In-place scaling: `a *= s`. +#[inline] +pub fn scale(a: &mut [F], s: F) { + for ai in a.iter_mut() { + *ai *= s; + } +} + +/// Evaluate polynomial at `x` via Horner's method. +/// +/// `coeffs[0] + coeffs[1]*x + coeffs[2]*x² + ...` +#[inline] +pub fn eval_at(coeffs: &[F], x: F) -> F { + if coeffs.is_empty() { + return F::ZERO; + } + let mut result = *coeffs.last().unwrap(); + for &c in coeffs.iter().rev().skip(1) { + result = result * x + c; + } + result +} + +/// Copy coefficients: `dst[..src.len()] = src`. +#[inline] +pub fn copy_into(dst: &mut [F], src: &[F]) { + debug_assert!(dst.len() >= src.len()); + dst[..src.len()].copy_from_slice(src); +} + +/// Zero a coefficient buffer. +#[inline] +pub fn zero(buf: &mut [F]) { + for b in buf.iter_mut() { + *b = F::ZERO; + } +} + +/// Convert a coefficient slice to `DensePolynomial`. +/// +/// This is the ONE place that allocates — use at the end when you need +/// to return a `DensePolynomial` to arkworks APIs. +pub fn to_dense_poly(coeffs: &[F]) -> DensePolynomial { + let mut v = coeffs.to_vec(); + // Trim trailing zeros (DensePolynomial invariant) + while v.last() == Some(&F::ZERO) && v.len() > 1 { + v.pop(); + } + DensePolynomial { coeffs: v } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::F64; + use ark_ff::{UniformRand, Zero}; + use ark_poly::{DenseUVPolynomial, Polynomial}; + use ark_std::{rand::RngCore, test_rng}; + + #[test] + fn test_mul_into_matches_naive_mul() { + let mut rng = test_rng(); + for _ in 0..100 { + let deg_a = (rng.next_u32() % 8) as usize; + let deg_b = (rng.next_u32() % 8) as usize; + let a: Vec = (0..=deg_a).map(|_| F64::rand(&mut rng)).collect(); + let b: Vec = (0..=deg_b).map(|_| F64::rand(&mut rng)).collect(); + + let expected = DensePolynomial::from_coefficients_vec(a.clone()) + .naive_mul(&DensePolynomial::from_coefficients_vec(b.clone())); + + let mut out = vec![F64::zero(); a.len() + b.len() - 1]; + mul_into(&mut out, &a, &b); + + for (i, (&e, &o)) in expected.coeffs.iter().zip(out.iter()).enumerate() { + assert_eq!(e, o, "mul_into mismatch at coeff {i}"); + } + } + } + + #[test] + fn test_mul_add_into_accumulates() { + let a = [F64::from(1u64), F64::from(2u64)]; // 1 + 2x + let b = [F64::from(3u64), F64::from(4u64)]; // 3 + 4x + // a*b = 3 + 10x + 8x² + + let mut out = [F64::from(10u64), F64::zero(), F64::zero()]; // start with 10 + mul_add_into(&mut out, &a, &b); + // out should be [13, 10, 8] + assert_eq!(out[0], F64::from(13u64)); + assert_eq!(out[1], F64::from(10u64)); + assert_eq!(out[2], F64::from(8u64)); + } + + #[test] + fn test_add_scaled() { + let mut a = [F64::from(1u64), F64::from(2u64), F64::from(3u64)]; + let b = [F64::from(10u64), F64::from(20u64)]; + let s = F64::from(5u64); + + add_scaled(&mut a, s, &b); + // a = [1+50, 2+100, 3] = [51, 102, 3] + assert_eq!(a[0], F64::from(51u64)); + assert_eq!(a[1], F64::from(102u64)); + assert_eq!(a[2], F64::from(3u64)); + } + + #[test] + fn test_eval_at_matches_polynomial() { + let mut rng = test_rng(); + for _ in 0..100 { + let deg = (rng.next_u32() % 10) as usize; + let coeffs: Vec = (0..=deg).map(|_| F64::rand(&mut rng)).collect(); + let x = F64::rand(&mut rng); + + let expected = DensePolynomial::from_coefficients_vec(coeffs.clone()).evaluate(&x); + let got = eval_at(&coeffs, x); + assert_eq!(expected, got); + } + } + + #[test] + fn test_sub_into() { + let a = [F64::from(10u64), F64::from(20u64), F64::from(30u64)]; + let b = [F64::from(1u64), F64::from(2u64), F64::from(3u64)]; + let mut out = [F64::zero(); 3]; + sub_into(&mut out, &a, &b); + assert_eq!(out[0], F64::from(9u64)); + assert_eq!(out[1], F64::from(18u64)); + assert_eq!(out[2], F64::from(27u64)); + } + + #[test] + fn test_to_dense_poly_trims_zeros() { + let coeffs = [F64::from(1u64), F64::from(2u64), F64::zero(), F64::zero()]; + let p = to_dense_poly(&coeffs); + assert_eq!(p.coeffs.len(), 2); + assert_eq!(p.coeffs[0], F64::from(1u64)); + assert_eq!(p.coeffs[1], F64::from(2u64)); + } +} diff --git a/src/polynomial/dense.rs b/src/polynomial/dense.rs new file mode 100644 index 00000000..1ee51ca0 --- /dev/null +++ b/src/polynomial/dense.rs @@ -0,0 +1,97 @@ +//! Zero-allocation dense polynomial arithmetic on coefficient slices. + +use crate::field::SumcheckField; + +/// Multiply polynomials `a` and `b`, writing the result into `out`. +/// +/// `out` must have length `>= a.len() + b.len() - 1`. Existing contents +/// are overwritten. No heap allocation. +/// +/// `a = [a_0, a_1, ..., a_m]`, `b = [b_0, b_1, ..., b_n]`. +/// `out[k] = Σ_{i+j=k} a_i · b_j`. +pub fn mul_into(out: &mut [F], a: &[F], b: &[F]) { + if a.is_empty() || b.is_empty() { + for o in out.iter_mut() { + *o = F::ZERO; + } + return; + } + let result_len = a.len() + b.len() - 1; + debug_assert!(out.len() >= result_len); + + for o in out.iter_mut() { + *o = F::ZERO; + } + for (i, &ai) in a.iter().enumerate() { + if ai.is_zero() { + continue; + } + for (j, &bj) in b.iter().enumerate() { + out[i + j] += ai * bj; + } + } +} + +/// Add `scalar * p` to `out` in place: `out[i] += scalar · p[i]`. +/// +/// If `p` is longer than `out`, the extra terms are ignored. +/// No allocation. +pub fn add_scaled(out: &mut [F], scalar: F, p: &[F]) { + let len = out.len().min(p.len()); + for i in 0..len { + out[i] += scalar * p[i]; + } +} + +/// Evaluate polynomial from coefficients at `x` via Horner's method. +/// +/// Alias for [`eval_horner`](super::eval_horner). +#[inline] +pub fn eval_at(coeffs: &[F], x: F) -> F { + super::eval_horner(coeffs, x) +} + +#[cfg(test)] +#[cfg(feature = "arkworks")] +mod tests { + use super::*; + use crate::tests::F64; + + #[test] + fn mul_linear_times_linear() { + // (1 + 2x)(3 + 4x) = 3 + 10x + 8x² + let a = [F64::from(1u64), F64::from(2u64)]; + let b = [F64::from(3u64), F64::from(4u64)]; + let mut out = [F64::ZERO; 3]; + mul_into(&mut out, &a, &b); + assert_eq!(out[0], F64::from(3u64)); + assert_eq!(out[1], F64::from(10u64)); + assert_eq!(out[2], F64::from(8u64)); + } + + #[test] + fn mul_by_zero() { + let a = [F64::from(1u64), F64::from(2u64)]; + let b: [F64; 0] = []; + let mut out = [F64::from(99u64); 3]; + mul_into(&mut out, &a, &b); + assert!(out.iter().all(|&x| x == F64::ZERO)); + } + + #[test] + fn add_scaled_basic() { + let mut out = [F64::from(1u64), F64::from(2u64), F64::from(3u64)]; + let p = [F64::from(10u64), F64::from(20u64)]; + add_scaled(&mut out, F64::from(3u64), &p); + assert_eq!(out[0], F64::from(31u64)); // 1 + 3*10 + assert_eq!(out[1], F64::from(62u64)); // 2 + 3*20 + assert_eq!(out[2], F64::from(3u64)); // unchanged + } + + #[test] + fn eval_at_matches_horner() { + let coeffs = [F64::from(1u64), F64::from(2u64), F64::from(3u64)]; + let x = F64::from(4u64); + assert_eq!(eval_at(&coeffs, x), super::super::eval_horner(&coeffs, x)); + } +} diff --git a/src/polynomial/eval.rs b/src/polynomial/eval.rs new file mode 100644 index 00000000..989f007b --- /dev/null +++ b/src/polynomial/eval.rs @@ -0,0 +1,214 @@ +//! Polynomial evaluation: Horner's method and barycentric Lagrange interpolation. + +extern crate alloc; +use crate::field::SumcheckField; +use alloc::vec; +use alloc::vec::Vec; + +/// Evaluate a polynomial from its coefficients at point `x` via Horner's method. +/// +/// `coeffs = [c_0, c_1, ..., c_d]` represents `p(X) = c_0 + c_1·X + ... + c_d·X^d`. +/// +/// Cost: `d` multiplications + `d` additions. Zero allocation. +#[inline] +pub fn eval_horner(coeffs: &[F], x: F) -> F { + if coeffs.is_empty() { + return F::ZERO; + } + let mut result = coeffs[coeffs.len() - 1]; + for i in (0..coeffs.len() - 1).rev() { + result = result * x + coeffs[i]; + } + result +} + +/// Evaluate a polynomial from its evaluations at `{0, 1, ..., d}` at +/// an arbitrary point `x` via barycentric Lagrange interpolation. +/// +/// `evals = [p(0), p(1), ..., p(d)]`. +/// +/// Cost: O(d) with precomputed [`BarycentricWeights`], O(d²) without. +/// For repeated evaluations at the same degree, precompute weights once. +pub fn eval_from_evals(evals: &[F], x: F) -> F { + let d = evals.len(); + if d == 0 { + return F::ZERO; + } + if d == 1 { + return evals[0]; + } + if d == 2 { + // Linear: p(x) = p(0) + x·(p(1) − p(0)). + return evals[0] + x * (evals[1] - evals[0]); + } + BarycentricWeights::new(d - 1).eval(evals, x) +} + +/// Precomputed barycentric weights for Lagrange interpolation at `{0, 1, ..., d}`. +/// +/// Compute once per degree, reuse across rounds. The verifier calls this +/// once and evaluates O(d) per round instead of O(d²). +/// +/// Weight `w_i = 1 / Π_{j≠i} (i − j)` for `i, j ∈ {0, ..., d}`. +/// For consecutive integer nodes these are `(-1)^{d-i} / (i! · (d-i)!)`. +pub struct BarycentricWeights { + /// Precomputed `w_i` for each node `i ∈ {0, ..., d}`. + weights: Vec, +} + +impl BarycentricWeights { + /// Precompute weights for interpolation at `{0, 1, ..., degree}`. + pub fn new(degree: usize) -> Self { + let d = degree + 1; // number of nodes + let mut weights = Vec::with_capacity(d); + for i in 0..d { + let mut w = F::ONE; + for j in 0..d { + if j != i { + let diff = i as i64 - j as i64; + if diff > 0 { + w *= F::from_u64(diff as u64); + } else { + w *= -F::from_u64((-diff) as u64); + } + } + } + // w_i = 1 / Π_{j≠i} (i - j) + weights.push(w.inverse().unwrap_or(F::ZERO)); + } + Self { weights } + } + + /// Number of interpolation nodes (degree + 1). + pub fn num_nodes(&self) -> usize { + self.weights.len() + } + + /// Evaluate the interpolated polynomial at `x`. + /// + /// `evals` must have length `num_nodes()`. + /// + /// Uses the "first form" of the barycentric formula: + /// `p(x) = Σ_i w_i · L(x) / (x - i) · f(i)` + /// where `L(x) = Π_j (x - j)`. + /// + /// Cost: O(d) multiplications + O(d) additions. + pub fn eval(&self, evals: &[F], x: F) -> F { + let d = self.weights.len(); + debug_assert_eq!(evals.len(), d); + + // Check if x is one of the nodes (avoid division by zero). + for (i, &eval) in evals.iter().enumerate() { + let node = F::from_u64(i as u64); + if x == node { + return eval; + } + } + + // Compute (x - 0)(x - 1)...(x - d+1) via prefix/suffix products. + let x_minus: Vec = (0..d).map(|j| x - F::from_u64(j as u64)).collect(); + + let mut prefix = vec![F::ONE; d + 1]; + for i in 0..d { + prefix[i + 1] = prefix[i] * x_minus[i]; + } + let mut suffix = vec![F::ONE; d + 1]; + for i in (0..d).rev() { + suffix[i] = suffix[i + 1] * x_minus[i]; + } + + let mut result = F::ZERO; + for i in 0..d { + // numerator = Π_{j≠i} (x - j) = prefix[i] · suffix[i+1] + let numerator = prefix[i] * suffix[i + 1]; + result += evals[i] * numerator * self.weights[i]; + } + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // A simple field-like wrapper for f64 won't work with SumcheckField + // (needs Copy + all the ops). Tests use the arkworks F64 type. + #[cfg(feature = "arkworks")] + mod ark_tests { + use super::*; + use crate::tests::F64; + + #[test] + fn horner_constant() { + let coeffs = [F64::from(7u64)]; + assert_eq!(eval_horner(&coeffs, F64::from(42u64)), F64::from(7u64)); + } + + #[test] + fn horner_linear() { + // p(x) = 3 + 5x + let coeffs = [F64::from(3u64), F64::from(5u64)]; + // p(2) = 3 + 10 = 13 + assert_eq!(eval_horner(&coeffs, F64::from(2u64)), F64::from(13u64)); + } + + #[test] + fn horner_quadratic() { + // p(x) = 1 + 2x + 3x² + let coeffs = [F64::from(1u64), F64::from(2u64), F64::from(3u64)]; + // p(4) = 1 + 8 + 48 = 57 + assert_eq!(eval_horner(&coeffs, F64::from(4u64)), F64::from(57u64)); + } + + #[test] + fn horner_empty() { + let coeffs: [F64; 0] = []; + assert_eq!(eval_horner(&coeffs, F64::from(5u64)), F64::ZERO); + } + + #[test] + fn barycentric_matches_horner() { + // p(x) = 1 + 2x + 3x² + // p(0) = 1, p(1) = 6, p(2) = 17 + let evals = [F64::from(1u64), F64::from(6u64), F64::from(17u64)]; + let coeffs = [F64::from(1u64), F64::from(2u64), F64::from(3u64)]; + + // Evaluate at several points and compare. + for x_val in [0u64, 1, 2, 3, 5, 10, 100] { + let x = F64::from(x_val); + let from_coeffs = eval_horner(&coeffs, x); + let from_evals = eval_from_evals(&evals, x); + assert_eq!(from_coeffs, from_evals, "mismatch at x={x_val}"); + } + } + + #[test] + fn barycentric_linear() { + // p(0) = 3, p(1) = 7 → p(x) = 3 + 4x → p(5) = 23 + let evals = [F64::from(3u64), F64::from(7u64)]; + assert_eq!(eval_from_evals(&evals, F64::from(5u64)), F64::from(23u64)); + } + + #[test] + fn barycentric_at_nodes() { + let evals = [F64::from(10u64), F64::from(20u64), F64::from(30u64)]; + assert_eq!(eval_from_evals(&evals, F64::from(0u64)), F64::from(10u64)); + assert_eq!(eval_from_evals(&evals, F64::from(1u64)), F64::from(20u64)); + assert_eq!(eval_from_evals(&evals, F64::from(2u64)), F64::from(30u64)); + } + + #[test] + fn precomputed_weights_reuse() { + let evals = [F64::from(1u64), F64::from(6u64), F64::from(17u64)]; + let weights = BarycentricWeights::new(2); + + // Reuse weights for multiple evaluations. + let v3 = weights.eval(&evals, F64::from(3u64)); + let v5 = weights.eval(&evals, F64::from(5u64)); + + let coeffs = [F64::from(1u64), F64::from(2u64), F64::from(3u64)]; + assert_eq!(v3, eval_horner(&coeffs, F64::from(3u64))); + assert_eq!(v5, eval_horner(&coeffs, F64::from(5u64))); + } + } +} diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs new file mode 100644 index 00000000..f38de81a --- /dev/null +++ b/src/polynomial/mod.rs @@ -0,0 +1,30 @@ +//! Polynomial evaluation and arithmetic for sumcheck protocols. +//! +//! Generic over [`SumcheckField`](crate::field::SumcheckField) — available without the `arkworks` feature. +//! +//! # Evaluation +//! +//! - [`eval_horner`]: evaluate from coefficients at a point. O(d). +//! - [`eval_from_evals`]: evaluate from evaluations at `{0, 1, ..., d}` at +//! an arbitrary point via barycentric Lagrange interpolation. O(d). +//! - [`BarycentricWeights`]: precompute weights once per degree, reuse +//! across rounds for O(d) evaluation instead of O(d²). +//! +//! # Sequential Lagrange polynomial +//! +//! - [`SequentialLagrange`]: maintains `eq(r, x) = Π_j (r_j · x_j + (1-r_j)(1-x_j))` +//! incrementally as you iterate over the hypercube. Composes with +//! [`Ascending`](crate::hypercube::Ascending) for cache-friendly streaming. +//! +//! # Dense polynomial arithmetic +//! +//! Zero-allocation operations on coefficient slices: +//! - [`mul_into`], [`add_scaled`], [`eval_at`] (alias for `eval_horner`). + +mod dense; +mod eval; +mod sequential_lagrange; + +pub use dense::{add_scaled, eval_at, mul_into}; +pub use eval::{eval_from_evals, eval_horner, BarycentricWeights}; +pub use sequential_lagrange::SequentialLagrange; diff --git a/src/polynomial/sequential_lagrange.rs b/src/polynomial/sequential_lagrange.rs new file mode 100644 index 00000000..ab6fded7 --- /dev/null +++ b/src/polynomial/sequential_lagrange.rs @@ -0,0 +1,208 @@ +//! Sequential Lagrange polynomial over the Boolean hypercube. +//! +//! Maintains `eq(r, x) = Π_j (r_j · x_j + (1 − r_j) · (1 − x_j))` +//! incrementally as you iterate over `{0,1}^v`. +//! +//! Designed to compose with [`Ascending`](crate::hypercube::Ascending): +//! call [`advance_to`](SequentialLagrange::advance_to) with each successive +//! index. The XOR of consecutive indices tells us which bits flipped; +//! each flipped bit requires one multiply+divide to update the product. +//! +//! For ascending order, bit 0 flips every step, bit 1 every 2 steps, etc. +//! The amortized cost per step is O(1) (geometric series: 1 + 1/2 + 1/4 + ... = 2). + +extern crate alloc; +use crate::field::SumcheckField; +use alloc::vec::Vec; + +/// Sequential Lagrange polynomial `eq(r, ·)` with incremental updates. +/// +/// # Usage +/// +/// ```ignore +/// use effsc::polynomial::SequentialLagrange; +/// use effsc::hypercube::Ascending; +/// +/// let point = vec![r0, r1, r2]; +/// let mut lag = SequentialLagrange::new(&point); +/// +/// for p in Ascending::new(3) { +/// lag.advance_to(p.index); +/// let eq_val = lag.value(); +/// // use eq_val... +/// } +/// ``` +pub struct SequentialLagrange { + /// Precomputed factors: `factor_one[j] = r_j`, `factor_zero[j] = 1 − r_j`. + factor_one: Vec, + factor_zero: Vec, + /// Current product value `Π_j factor(x_j)`. + current_value: F, + /// Current hypercube index (which bits are set). + current_index: usize, + /// Number of variables. + num_vars: usize, +} + +impl SequentialLagrange { + /// Initialize at the origin (index 0): `eq(r, 0) = Π_j (1 − r_j)`. + pub fn new(point: &[F]) -> Self { + let num_vars = point.len(); + let factor_one: Vec = point.to_vec(); + let factor_zero: Vec = point.iter().map(|&r| F::ONE - r).collect(); + + // Initial value: eq(r, 0...0) = Π_j (1 - r_j) + let current_value = factor_zero.iter().copied().fold(F::ONE, |acc, f| acc * f); + + Self { + factor_one, + factor_zero, + current_value, + current_index: 0, + num_vars, + } + } + + /// Current value of `eq(r, x)` at the current hypercube point. + #[inline] + pub fn value(&self) -> F { + self.current_value + } + + /// Current hypercube index. + #[inline] + pub fn index(&self) -> usize { + self.current_index + } + + /// Advance to a new hypercube index. + /// + /// Updates the product by flipping the factors for each bit that + /// changed between the current index and `next_index`. + /// + /// For ascending iteration (0, 1, 2, ...), this is amortized O(1): + /// bit 0 flips every step (cost 1), bit 1 every 2 steps (cost 1/2), + /// bit 2 every 4 steps (cost 1/4), etc. Total: Σ 1/2^k = 2. + pub fn advance_to(&mut self, next_index: usize) { + let diff = self.current_index ^ next_index; + let mut bits = diff; + while bits != 0 { + let j = bits.trailing_zeros() as usize; + debug_assert!(j < self.num_vars); + + // Determine if bit j flipped 0→1 or 1→0. + let was_one = (self.current_index >> j) & 1 == 1; + if was_one { + // 1→0: replace factor_one[j] with factor_zero[j] + // new = old / factor_one[j] * factor_zero[j] + if let Some(inv) = self.factor_one[j].inverse() { + self.current_value *= self.factor_zero[j] * inv; + } + } else { + // 0→1: replace factor_zero[j] with factor_one[j] + // new = old / factor_zero[j] * factor_one[j] + if let Some(inv) = self.factor_zero[j].inverse() { + self.current_value *= self.factor_one[j] * inv; + } + } + + bits &= bits - 1; // clear lowest set bit + } + self.current_index = next_index; + } + + /// Reset to index 0. + pub fn reset(&mut self) { + self.current_value = self + .factor_zero + .iter() + .copied() + .fold(F::ONE, |acc, f| acc * f); + self.current_index = 0; + } +} + +#[cfg(test)] +#[cfg(feature = "arkworks")] +mod tests { + use super::*; + use crate::tests::F64; + + /// Compute eq(r, x) directly for reference. + fn eq_direct(point: &[F64], index: usize) -> F64 { + let num_vars = point.len(); + (0..num_vars).fold(F64::from(1u64), |acc, j| { + let bit = F64::from(((index >> j) & 1) as u64); + acc * (point[j] * bit + (F64::from(1u64) - point[j]) * (F64::from(1u64) - bit)) + }) + } + + #[test] + fn sequential_ascending_matches_direct() { + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + let mut rng = StdRng::seed_from_u64(42); + let num_vars = 4; + let point: Vec = (0..num_vars).map(|_| F64::rand(&mut rng)).collect(); + + let mut lag = SequentialLagrange::new(&point); + assert_eq!(lag.value(), eq_direct(&point, 0)); + + for i in 1..(1 << num_vars) { + lag.advance_to(i); + let expected = eq_direct(&point, i); + assert_eq!(lag.value(), expected, "mismatch at index {i}"); + } + } + + #[test] + fn sequential_random_access() { + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + let mut rng = StdRng::seed_from_u64(99); + let num_vars = 3; + let point: Vec = (0..num_vars).map(|_| F64::rand(&mut rng)).collect(); + + let mut lag = SequentialLagrange::new(&point); + + // Jump around non-sequentially. + for &idx in &[5, 3, 7, 0, 6, 1, 4, 2] { + lag.advance_to(idx); + assert_eq!(lag.value(), eq_direct(&point, idx), "at index {idx}"); + } + } + + #[test] + fn sequential_reset() { + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + let mut rng = StdRng::seed_from_u64(7); + let point: Vec = (0..3).map(|_| F64::rand(&mut rng)).collect(); + + let mut lag = SequentialLagrange::new(&point); + lag.advance_to(5); + lag.reset(); + assert_eq!(lag.value(), eq_direct(&point, 0)); + assert_eq!(lag.index(), 0); + } + + #[test] + fn sequential_composes_with_ascending() { + use crate::hypercube::Ascending; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + let mut rng = StdRng::seed_from_u64(123); + let num_vars = 5; + let point: Vec = (0..num_vars).map(|_| F64::rand(&mut rng)).collect(); + + let mut lag = SequentialLagrange::new(&point); + for p in Ascending::new(num_vars) { + lag.advance_to(p.index); + assert_eq!(lag.value(), eq_direct(&point, p.index)); + } + } +} diff --git a/src/proof.rs b/src/proof.rs new file mode 100644 index 00000000..6f89069d --- /dev/null +++ b/src/proof.rs @@ -0,0 +1,74 @@ +//! Sumcheck proof and error types. + +#![allow(unused_imports)] + +extern crate alloc; + +use crate::field::SumcheckField; +use alloc::vec::Vec; +use core::fmt; + +/// Output of the sumcheck protocol (Thaler Proposition 4.1). +/// +/// Contains the prover's round polynomials, the verifier's challenges, +/// and the prover's claimed final evaluation. The verifier reconstructs +/// consistency checks from this data; the final oracle check (verifying +/// `final_value == g(r_1, ..., r_v)`) is the caller's responsibility. +#[derive(Clone, Debug)] +pub struct SumcheckProof { + /// Round polynomial evaluations: `round_polys[j]` contains + /// `g_j(0), g_j(1), ..., g_j(degree)`. + pub round_polys: Vec>, + + /// Verifier challenges `r_1, ..., r_v`. + pub challenges: Vec, + + /// Prover's claimed value `g(r_1, ..., r_v)`. + pub final_value: F, +} + +/// Sumcheck verification error. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SumcheckError { + /// Round `j` consistency check failed: `g_j(0) + g_j(1) != claim`. + ConsistencyCheck { round: usize }, + /// Round polynomial has wrong degree. + DegreeMismatch { + round: usize, + expected: usize, + got: usize, + }, + /// Final evaluation mismatch. + FinalEvaluation, + /// Transcript error (e.g., malformed prover message). + TranscriptError { round: usize }, + /// Per-round hook failed (e.g., proof-of-work verification). + HookError { round: usize }, +} + +impl fmt::Display for SumcheckError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SumcheckError::ConsistencyCheck { round } => { + write!(f, "round {round}: consistency check failed") + } + SumcheckError::DegreeMismatch { + round, + expected, + got, + } => write!( + f, + "round {round}: degree mismatch: expected <= {expected}, got {got}" + ), + SumcheckError::FinalEvaluation => { + write!(f, "final evaluation mismatch") + } + SumcheckError::TranscriptError { round } => { + write!(f, "round {round}: transcript error") + } + SumcheckError::HookError { round } => { + write!(f, "round {round}: hook error") + } + } + } +} diff --git a/src/prover/core.rs b/src/prover/core.rs deleted file mode 100644 index fb438a93..00000000 --- a/src/prover/core.rs +++ /dev/null @@ -1,22 +0,0 @@ -use ark_ff::Field; - -use crate::streams::Stream; -pub trait ProverConfig> { - fn default(num_variables: usize, stream: S) -> Self; -} - -pub trait BatchProverConfig> { - fn default(num_variables: usize, streams: Vec) -> Self; -} - -pub trait ProductProverConfig> { - fn default(num_variables: usize, steams: Vec) -> Self; -} - -pub trait Prover { - type ProverConfig; - type ProverMessage; - type VerifierMessage; - fn new(prover_config: Self::ProverConfig) -> Self; - fn next_message(&mut self, verifier_message: Self::VerifierMessage) -> Self::ProverMessage; -} diff --git a/src/prover/mod.rs b/src/prover/mod.rs deleted file mode 100644 index 4cf9982a..00000000 --- a/src/prover/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod core; -pub use core::{BatchProverConfig, ProductProverConfig, Prover, ProverConfig}; diff --git a/src/provers/coefficient.rs b/src/provers/coefficient.rs new file mode 100644 index 00000000..b7c46e6e --- /dev/null +++ b/src/provers/coefficient.rs @@ -0,0 +1,392 @@ +//! MSB (half-split) coefficient sumcheck prover: arbitrary degree d. +//! +//! Folds the *most-significant* variable each round: pairs `(f[k], f[k+L/2])`. +//! This is optimal for in-memory and random-access-streaming workloads. +//! +//! For sequential streaming (Jolt-style), use +//! [`CoefficientProverLSB`](super::coefficient_lsb::CoefficientProverLSB). + +use ark_ff::Field; + +use crate::coefficient_sumcheck::RoundPolyEvaluator; +use crate::sumcheck_prover::SumcheckProver; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// MSB coefficient sumcheck prover (arbitrary degree d, half-split layout). +pub struct CoefficientProver<'a, F: Field, E: RoundPolyEvaluator> { + evaluator: &'a E, + tablewise: Vec>>, + pairwise: Vec>, + n_tw: usize, + n_pw: usize, + deg: usize, +} + +impl<'a, F: Field, E: RoundPolyEvaluator> CoefficientProver<'a, F, E> { + pub fn new(evaluator: &'a E, tablewise: Vec>>, pairwise: Vec>) -> Self { + let n_tw = tablewise.len(); + let n_pw = pairwise.len(); + let deg = evaluator.degree(); + Self { + evaluator, + tablewise, + pairwise, + n_tw, + n_pw, + deg, + } + } + + pub fn tablewise(&self) -> &[Vec>] { + &self.tablewise + } + + pub fn pairwise(&self) -> &[Vec] { + &self.pairwise + } + + fn half(&self) -> usize { + if self.n_tw > 0 { + self.tablewise[0].len() / 2 + } else if self.n_pw > 0 { + self.pairwise[0].len() / 2 + } else { + 0 + } + } + + /// Compute round polynomial coefficients using MSB (half-split) pairing. + fn evaluate_coefficients(&self) -> Vec { + let half = self.half(); + let n_coeffs = self.deg + 1; + + if self.evaluator.parallelize() && half > 0 { + msb_parallel_evaluate( + self.evaluator, + &self.tablewise, + &self.pairwise, + self.n_tw, + self.n_pw, + half, + n_coeffs, + ) + } else { + let mut coeffs = vec![F::ZERO; n_coeffs]; + msb_sequential_evaluate_into( + self.evaluator, + &self.tablewise, + &self.pairwise, + self.n_tw, + self.n_pw, + half, + &mut coeffs, + ); + coeffs + } + } + + /// MSB half-split reduce: `new[k] = v[k] + c*(v[k+half] - v[k])`. + fn reduce(&mut self, challenge: F) { + // Pairwise tables: MSB fold. + for table in self.pairwise.iter_mut() { + #[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + ))] + if crate::simd_sumcheck::dispatch::try_simd_reduce_msb(table, challenge) { + continue; + } + msb_fold_vec(table, challenge); + } + // Tablewise tables: MSB fold each row-vector. + for table in self.tablewise.iter_mut() { + msb_fold_tablewise(table, challenge); + } + } +} + +// ─── MSB fold helpers ────────────────────────────────────────────────────── + +/// In-place MSB fold for a flat vector: `new[k] = v[k] + c*(v[k+half] - v[k])`. +fn msb_fold_vec(v: &mut Vec, challenge: F) { + if v.len() <= 1 { + return; + } + let half = v.len() / 2; + for k in 0..half { + v[k] = v[k] + challenge * (v[k + half] - v[k]); + } + v.truncate(half); +} + +/// MSB fold for tablewise: each row-vector is folded by pairing +/// `(table[k], table[k+half])` and producing a new row. +fn msb_fold_tablewise(table: &mut Vec>, challenge: F) { + if table.len() <= 1 { + return; + } + let half = table.len() / 2; + for k in 0..half { + // Split to get non-overlapping mutable + shared references. + let (lo_part, hi_part) = table.split_at(half); + let new_row: Vec = lo_part[k] + .iter() + .zip(&hi_part[k]) + .map(|(&lo, &hi)| lo + challenge * (hi - lo)) + .collect(); + table[k] = new_row; + } + table.truncate(half); +} + +// ─── MSB evaluate helpers ────────────────────────────────────────────────── + +/// MSB pairing: pair index `k` with `k + half` (not `2k` with `2k+1`). +fn msb_sequential_evaluate_into( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + half: usize, + coeffs_out: &mut [F], +) { + for c in coeffs_out.iter_mut() { + *c = F::ZERO; + } + let mut tw_buf: [(&[F], &[F]); 16] = [(&[], &[]); 16]; + let mut pw_buf: [(F, F); 16] = [(F::ZERO, F::ZERO); 16]; + + for k in 0..half { + for (i, table) in tablewise.iter().enumerate() { + tw_buf[i] = (&table[k], &table[k + half]); + } + for (i, table) in pairwise.iter().enumerate() { + pw_buf[i] = (table[k], table[k + half]); + } + evaluator.accumulate_pair(coeffs_out, &tw_buf[..n_tw], &pw_buf[..n_pw]); + } +} + +#[cfg(feature = "parallel")] +fn msb_parallel_evaluate( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + half: usize, + n_coeffs: usize, +) -> Vec { + (0..half) + .into_par_iter() + .fold_with(vec![F::ZERO; n_coeffs], |mut acc, k| { + let mut tw_buf: [(&[F], &[F]); 16] = [(&[], &[]); 16]; + let mut pw_buf: [(F, F); 16] = [(F::ZERO, F::ZERO); 16]; + for (i, table) in tablewise.iter().enumerate() { + tw_buf[i] = (&table[k], &table[k + half]); + } + for (i, table) in pairwise.iter().enumerate() { + pw_buf[i] = (table[k], table[k + half]); + } + evaluator.accumulate_pair(&mut acc, &tw_buf[..n_tw], &pw_buf[..n_pw]); + acc + }) + .reduce_with(|mut a, b| { + for (ai, bi) in a.iter_mut().zip(&b) { + *ai += *bi; + } + a + }) + .unwrap_or_else(|| vec![F::ZERO; n_coeffs]) +} + +#[cfg(not(feature = "parallel"))] +fn msb_parallel_evaluate( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + half: usize, + n_coeffs: usize, +) -> Vec { + let mut coeffs = vec![F::ZERO; n_coeffs]; + msb_sequential_evaluate_into( + evaluator, + tablewise, + pairwise, + n_tw, + n_pw, + half, + &mut coeffs, + ); + coeffs +} + +// ─── Horner evaluation ───────────────────────────────────────────────────── + +#[inline] +fn eval_poly_at(coeffs: &[F], x: F) -> F { + if coeffs.is_empty() { + return F::ZERO; + } + let mut result = coeffs[coeffs.len() - 1]; + for i in (0..coeffs.len() - 1).rev() { + result = result * x + coeffs[i]; + } + result +} + +// ─── SumcheckProver impl ─────────────────────────────────────────────────── + +#[cfg(feature = "arkworks")] +impl<'a, F, E> SumcheckProver for CoefficientProver<'a, F, E> +where + F: ark_ff::Field, + E: RoundPolyEvaluator, +{ + fn degree(&self) -> usize { + self.deg + } + + fn round(&mut self, challenge: Option) -> Vec { + if let Some(c) = challenge { + self.reduce(c); + } + + let coeffs = self.evaluate_coefficients(); + + let mut evals = Vec::with_capacity(self.deg + 1); + for i in 0..=self.deg { + evals.push(eval_poly_at(&coeffs, F::from(i as u64))); + } + evals + } + + fn finalize(&mut self, last_challenge: F) { + self.reduce(last_challenge); + } + + fn final_value(&self) -> F { + // After full reduction, each pairwise table has 1 element and + // each tablewise table has 1 row. The final value is the + // evaluation of the user's polynomial at this single point. + // + // For degree-1 single-pairwise (multilinear shape): just pairwise[0][0]. + // General case: evaluate the round polynomial at the single "pair" + // (which is really just one element — lo with implicit zero hi). + if self.n_pw == 1 && self.n_tw == 0 && self.pairwise[0].len() == 1 { + return self.pairwise[0][0]; + } + + // General: build the evaluator's output from the singleton tables. + let mut coeffs = vec![F::ZERO; self.deg + 1]; + let mut tw_buf: [(&[F], &[F]); 16] = [(&[], &[]); 16]; + let mut pw_buf: [(F, F); 16] = [(F::ZERO, F::ZERO); 16]; + for (i, table) in self.tablewise.iter().enumerate() { + if table.len() == 1 { + tw_buf[i] = (&table[0], &[]); + } + } + for (i, table) in self.pairwise.iter().enumerate() { + if table.len() == 1 { + pw_buf[i] = (table[0], F::ZERO); + } + } + self.evaluator + .accumulate_pair(&mut coeffs, &tw_buf[..self.n_tw], &pw_buf[..self.n_pw]); + eval_poly_at(&coeffs, F::ZERO) + eval_poly_at(&coeffs, F::ONE) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runner::sumcheck; + use crate::tests::F64; + use crate::transcript::SanityTranscript; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + struct Degree1Eval; + impl RoundPolyEvaluator for Degree1Eval { + fn degree(&self) -> usize { + 1 + } + fn accumulate_pair(&self, coeffs: &mut [F64], _tw: &[(&[F64], &[F64])], pw: &[(F64, F64)]) { + let (lo, hi) = pw[0]; + coeffs[0] += lo; + coeffs[1] += hi - lo; + } + } + + /// MSB CoefficientProver produces valid round polynomials. + #[test] + fn msb_degree1_consistency() { + let mut rng = StdRng::seed_from_u64(42); + let n = 1 << 4; + let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = evals.iter().copied().sum(); + + let evaluator = Degree1Eval; + let mut prover = CoefficientProver::new(&evaluator, vec![], vec![evals]); + let mut trng = StdRng::seed_from_u64(99); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, 4, &mut t, |_, _| {}); + + assert_eq!( + proof.round_polys[0][0] + proof.round_polys[0][1], + claimed_sum + ); + + let mut claim = claimed_sum; + for (rp, &r) in proof.round_polys.iter().zip(&proof.challenges) { + assert_eq!(rp[0] + rp[1], claim); + claim = rp[0] + r * (rp[1] - rp[0]); + } + assert_eq!(proof.final_value, claim); + } + + /// MSB CoefficientProver matches MSB MultilinearProver for degree 1. + #[test] + fn msb_matches_multilinear_prover() { + use crate::provers::multilinear::MultilinearProver; + + let mut rng = StdRng::seed_from_u64(42); + let n = 1 << 4; + let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + + // MultilinearProver (MSB). + let mut ml = MultilinearProver::new(evals.clone()); + let mut trng = StdRng::seed_from_u64(99); + let mut t1 = SanityTranscript::new(&mut trng); + let ml_proof = sumcheck(&mut ml, 4, &mut t1, |_, _| {}); + + // CoefficientProver (MSB) with degree-1 evaluator. + let evaluator = Degree1Eval; + let mut cp = CoefficientProver::new(&evaluator, vec![], vec![evals]); + let mut trng2 = StdRng::seed_from_u64(99); + let mut t2 = SanityTranscript::new(&mut trng2); + let cp_proof = sumcheck(&mut cp, 4, &mut t2, |_, _| {}); + + // Same challenges (same transcript seed). + assert_eq!(ml_proof.challenges, cp_proof.challenges); + + // Same round polynomial evaluations. + for (i, (ml_rp, cp_rp)) in ml_proof + .round_polys + .iter() + .zip(&cp_proof.round_polys) + .enumerate() + { + assert_eq!(ml_rp[0], cp_rp[0], "round {i}: q(0) mismatch"); + assert_eq!(ml_rp[1], cp_rp[1], "round {i}: q(1) mismatch"); + } + } +} diff --git a/src/provers/coefficient_lsb.rs b/src/provers/coefficient_lsb.rs new file mode 100644 index 00000000..13b7a556 --- /dev/null +++ b/src/provers/coefficient_lsb.rs @@ -0,0 +1,508 @@ +//! LSB (pair-split) coefficient sumcheck prover: arbitrary degree d. +//! +//! Folds the *least-significant* variable each round: pairs `(f[2k], f[2k+1])`. +//! This is the natural layout for **sequential streaming** where evaluations +//! arrive in index order. +//! +//! Use this prover for Jolt-style workloads. For in-memory or random-access +//! workloads, prefer [`CoefficientProver`](super::coefficient::CoefficientProver) +//! (MSB layout). + +use ark_ff::Field; + +use crate::coefficient_sumcheck::RoundPolyEvaluator; +use crate::reductions::{pairwise, tablewise}; +use crate::sumcheck_prover::SumcheckProver; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Coefficient sumcheck prover (arbitrary degree d). +/// +/// Wraps tablewise and pairwise evaluation tables with a user-provided +/// [`RoundPolyEvaluator`] that defines the round polynomial shape. +/// +/// # Construction +/// +/// ```ignore +/// let evaluator = MyEvaluator; +/// let mut prover = CoefficientProverLSB::new( +/// &evaluator, +/// tablewise_tables, // Vec>> +/// pairwise_tables, // Vec> +/// num_rounds, +/// ); +/// let proof = sumcheck(&mut prover, num_rounds, &mut transcript, |_, _| {}); +/// ``` +pub struct CoefficientProverLSB<'a, F: Field, E: RoundPolyEvaluator> { + evaluator: &'a E, + tablewise: Vec>>, + pairwise: Vec>, + n_tw: usize, + n_pw: usize, + deg: usize, + /// Cached degree-1 SIMD evaluation from fused reduce+evaluate. + pending_degree1_eval: Option>, + /// Whether this is the degree-1, single-pairwise, no-tablewise fast path. + is_degree1_simd_path: bool, +} + +impl<'a, F: Field, E: RoundPolyEvaluator> CoefficientProverLSB<'a, F, E> { + pub fn new(evaluator: &'a E, tablewise: Vec>>, pairwise: Vec>) -> Self { + let n_tw = tablewise.len(); + let n_pw = pairwise.len(); + let deg = evaluator.degree(); + let is_degree1_simd_path = deg == 1 && n_pw == 1 && n_tw == 0; + Self { + evaluator, + tablewise, + pairwise, + n_tw, + n_pw, + deg, + pending_degree1_eval: None, + is_degree1_simd_path, + } + } + + /// Access the (possibly reduced) tablewise tables. + pub fn tablewise(&self) -> &[Vec>] { + &self.tablewise + } + + /// Access the (possibly reduced) pairwise tables. + pub fn pairwise(&self) -> &[Vec] { + &self.pairwise + } + + fn n_pairs(&self) -> usize { + if self.n_tw > 0 { + self.tablewise[0].len() / 2 + } else if self.n_pw > 0 { + self.pairwise[0].len() / 2 + } else { + 0 + } + } + + /// Compute round polynomial coefficients. + fn evaluate_coefficients(&mut self) -> Vec { + if let Some(cached) = self.pending_degree1_eval.take() { + return cached; + } + if self.is_degree1_simd_path { + return simd_evaluate_degree1(&self.pairwise[0]); + } + + let n_pairs = self.n_pairs(); + let n_coeffs = self.deg + 1; + + if self.evaluator.parallelize() { + parallel_evaluate( + self.evaluator, + &self.tablewise, + &self.pairwise, + self.n_tw, + self.n_pw, + n_pairs, + n_coeffs, + ) + } else { + let mut coeffs = vec![F::ZERO; n_coeffs]; + sequential_evaluate_into( + self.evaluator, + &self.tablewise, + &self.pairwise, + self.n_tw, + self.n_pw, + n_pairs, + &mut coeffs, + ); + coeffs + } + } + + /// Reduce all tables by folding with the challenge. + fn reduce(&mut self, challenge: F, is_last_round: bool) { + for table in self.tablewise.iter_mut() { + tablewise::reduce_evaluations(table, challenge); + } + + if self.is_degree1_simd_path && !is_last_round { + if let Some(next) = try_simd_fused_reduce_evaluate(&mut self.pairwise[0], challenge) { + self.pending_degree1_eval = Some(next); + return; + } + } + + for table in self.pairwise.iter_mut() { + #[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + ))] + if crate::simd_sumcheck::dispatch::try_simd_reduce(table, challenge) { + continue; + } + pairwise::reduce_evaluations(table, challenge); + } + } +} + +// ─── SumcheckProver impl ─────────────────────────────────────────────────── + +#[cfg(feature = "arkworks")] +impl<'a, F, E> SumcheckProver for CoefficientProverLSB<'a, F, E> +where + F: ark_ff::Field, + E: RoundPolyEvaluator, +{ + fn degree(&self) -> usize { + self.deg + } + + fn round(&mut self, challenge: Option) -> Vec { + // Reduce with previous challenge (if any). + if let Some(c) = challenge { + self.reduce(c, false); + } + + // Compute coefficient representation. + let coeffs = self.evaluate_coefficients(); + + // Convert coefficients → evaluations at {0, 1, ..., degree}. + let mut evals = Vec::with_capacity(self.deg + 1); + for i in 0..=self.deg { + evals.push(eval_poly_at(&coeffs, F::from(i as u64))); + } + evals + } + + fn finalize(&mut self, last_challenge: F) { + self.reduce(last_challenge, true); + } + + fn final_value(&self) -> F { + // After all rounds, each pairwise table should have 1 element, + // each tablewise table should have 1 row. The final value is + // the evaluator applied to these singletons. + let mut coeffs = vec![F::ZERO; self.deg + 1]; + let n_pairs = self.n_pairs(); + if n_pairs > 0 { + sequential_evaluate_into( + self.evaluator, + &self.tablewise, + &self.pairwise, + self.n_tw, + self.n_pw, + n_pairs, + &mut coeffs, + ); + } + // final_value = h(0) + h(1) = sum of evaluations at 0 and 1 + // Actually: the "final value" for coefficient sumcheck is the + // claimed sum at the last point, which is eval of the polynomial + // at the last challenge. But after finalize(), the tables are + // fully reduced — there's only 1 "pair" left (of size 1). + // The claim is h(0) + h(1) from the last round's perspective, + // but that's the *next* claim, not the evaluation. + // + // For consistency with the other provers: final_value should be + // the polynomial evaluated at the random point. After full + // reduction, the single remaining element in pairwise[0] IS + // the evaluation (for degree-1 single-pairwise case). + if self.is_degree1_simd_path && !self.pairwise.is_empty() && self.pairwise[0].len() == 1 { + return self.pairwise[0][0]; + } + // General case: sum the contributions. + eval_poly_at(&coeffs, F::ZERO) + eval_poly_at(&coeffs, F::ONE) + } +} + +// ─── Horner evaluation ───────────────────────────────────────────────────── + +/// Evaluate polynomial with coefficients `coeffs` at point `x` via Horner's method. +#[inline] +fn eval_poly_at(coeffs: &[F], x: F) -> F { + if coeffs.is_empty() { + return F::ZERO; + } + let mut result = coeffs[coeffs.len() - 1]; + for i in (0..coeffs.len() - 1).rev() { + result = result * x + coeffs[i]; + } + result +} + +// ─── Evaluate strategies (same as coefficient_sumcheck.rs) ───────────────── + +fn simd_evaluate_degree1(pw: &[F]) -> Vec { + #[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + ))] + { + if let Some(coeffs) = crate::simd_sumcheck::dispatch::try_simd_evaluate_degree1(pw) { + return coeffs; + } + } + let mut s0 = F::ZERO; + let mut s1 = F::ZERO; + for chunk in pw.chunks_exact(2) { + s0 += chunk[0]; + s1 += chunk[1]; + } + vec![s0, s1 - s0] +} + +fn try_simd_fused_reduce_evaluate(pw: &mut Vec, challenge: F) -> Option> { + #[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + ))] + { + crate::simd_sumcheck::dispatch::try_simd_fused_reduce_evaluate_degree1(pw, challenge) + } + #[cfg(not(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) + )))] + { + let _ = (pw, challenge); + None + } +} + +#[cfg(feature = "parallel")] +fn parallel_evaluate( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + n_pairs: usize, + n_coeffs: usize, +) -> Vec { + (0..n_pairs) + .into_par_iter() + .fold_with(vec![F::ZERO; n_coeffs], |mut acc, pair_idx| { + let mut tw_buf: [(&[F], &[F]); 16] = [(&[], &[]); 16]; + let mut pw_buf: [(F, F); 16] = [(F::ZERO, F::ZERO); 16]; + for (i, table) in tablewise.iter().enumerate() { + tw_buf[i] = (&table[2 * pair_idx], &table[2 * pair_idx + 1]); + } + for (i, table) in pairwise.iter().enumerate() { + pw_buf[i] = (table[2 * pair_idx], table[2 * pair_idx + 1]); + } + evaluator.accumulate_pair(&mut acc, &tw_buf[..n_tw], &pw_buf[..n_pw]); + acc + }) + .reduce_with(|mut a, b| { + for (ai, bi) in a.iter_mut().zip(&b) { + *ai += *bi; + } + a + }) + .unwrap_or_else(|| vec![F::ZERO; n_coeffs]) +} + +#[cfg(not(feature = "parallel"))] +fn parallel_evaluate( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + n_pairs: usize, + n_coeffs: usize, +) -> Vec { + let mut coeffs = vec![F::ZERO; n_coeffs]; + sequential_evaluate_into( + evaluator, + tablewise, + pairwise, + n_tw, + n_pw, + n_pairs, + &mut coeffs, + ); + coeffs +} + +fn sequential_evaluate_into( + evaluator: &impl RoundPolyEvaluator, + tablewise: &[Vec>], + pairwise: &[Vec], + n_tw: usize, + n_pw: usize, + n_pairs: usize, + coeffs_out: &mut [F], +) { + for c in coeffs_out.iter_mut() { + *c = F::ZERO; + } + let mut tw_buf: [(&[F], &[F]); 16] = [(&[], &[]); 16]; + let mut pw_buf: [(F, F); 16] = [(F::ZERO, F::ZERO); 16]; + + for pair_idx in 0..n_pairs { + for (i, table) in tablewise.iter().enumerate() { + tw_buf[i] = (&table[2 * pair_idx], &table[2 * pair_idx + 1]); + } + for (i, table) in pairwise.iter().enumerate() { + pw_buf[i] = (table[2 * pair_idx], table[2 * pair_idx + 1]); + } + evaluator.accumulate_pair(coeffs_out, &tw_buf[..n_tw], &pw_buf[..n_pw]); + } +} + +// ─── Tests ───────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::runner::sumcheck; + use crate::tests::F64; + use crate::transcript::SanityTranscript; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + struct Degree1Eval; + impl RoundPolyEvaluator for Degree1Eval { + fn degree(&self) -> usize { + 1 + } + fn accumulate_pair(&self, coeffs: &mut [F64], _tw: &[(&[F64], &[F64])], pw: &[(F64, F64)]) { + let (even, odd) = pw[0]; + coeffs[0] += even; + coeffs[1] += odd - even; + } + } + + struct Degree2Eval; + impl RoundPolyEvaluator for Degree2Eval { + fn degree(&self) -> usize { + 2 + } + fn accumulate_pair(&self, coeffs: &mut [F64], _tw: &[(&[F64], &[F64])], pw: &[(F64, F64)]) { + let (s0, s1) = pw[0]; + let s2 = s0 + s1; + coeffs[0] += s0; + coeffs[1] += (-F64::from(3u64) * s0 + F64::from(4u64) * s1 - s2) / F64::from(2u64); + coeffs[2] += (s0 - F64::from(2u64) * s1 + s2) / F64::from(2u64); + } + } + + /// CoefficientProverLSB produces valid sumcheck proofs for degree 1. + #[test] + fn degree1_consistency() { + let mut rng = StdRng::seed_from_u64(42); + let n = 1 << 4; + let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = evals.iter().copied().sum(); + + let evaluator = Degree1Eval; + let pairwise = vec![evals]; + let tablewise: Vec>> = vec![]; + + let mut prover = CoefficientProverLSB::new(&evaluator, tablewise, pairwise); + let mut trng = StdRng::seed_from_u64(99); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, 4, &mut t, |_, _| {}); + + // Round-0 consistency. + assert_eq!( + proof.round_polys[0][0] + proof.round_polys[0][1], + claimed_sum + ); + + // All-round consistency via Lagrange. + let mut claim = claimed_sum; + for (rp, &r) in proof.round_polys.iter().zip(&proof.challenges) { + assert_eq!(rp[0] + rp[1], claim, "consistency check failed"); + // degree 1: q(r) = q(0) + r*(q(1) - q(0)) + claim = rp[0] + r * (rp[1] - rp[0]); + } + assert_eq!(proof.final_value, claim); + } + + /// CoefficientProverLSB matches the old coefficient_sumcheck function. + #[test] + fn matches_legacy_coefficient_sumcheck() { + let mut rng = StdRng::seed_from_u64(42); + let n = 1 << 4; + let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + + // Old API. + let mut old_pw = vec![evals.clone()]; + let mut old_tw: Vec>> = vec![]; + let mut trng = StdRng::seed_from_u64(99); + let mut t_old = SanityTranscript::new(&mut trng); + let old_result = crate::coefficient_sumcheck::coefficient_sumcheck( + &Degree1Eval, + &mut old_tw, + &mut old_pw, + 4, + &mut t_old, + ); + + // New API. + let evaluator = Degree1Eval; + let mut prover = CoefficientProverLSB::new(&evaluator, vec![], vec![evals]); + let mut trng2 = StdRng::seed_from_u64(99); + let mut t_new = SanityTranscript::new(&mut trng2); + let new_result = sumcheck(&mut prover, 4, &mut t_new, |_, _| {}); + + // Same challenges (same transcript seed). + assert_eq!(old_result.verifier_messages, new_result.challenges); + + // Round polynomials: old is coefficients, new is evaluations. + // Verify consistency: old coeffs evaluated at {0,1} should match + // new evals[0] and evals[1]. + for (i, (old_poly, new_evals)) in old_result + .prover_messages + .iter() + .zip(&new_result.round_polys) + .enumerate() + { + use ark_poly::Polynomial; + let old_at_0 = old_poly.evaluate(&F64::from(0u64)); + let old_at_1 = old_poly.evaluate(&F64::from(1u64)); + assert_eq!(old_at_0, new_evals[0], "round {i}: q(0) mismatch"); + assert_eq!(old_at_1, new_evals[1], "round {i}: q(1) mismatch"); + } + } + + /// Degree-2 CoefficientProverLSB round polynomial has 3 evaluations. + #[test] + fn degree2_structure() { + let mut rng = StdRng::seed_from_u64(42); + let n = 1 << 3; + let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = evals.iter().copied().sum(); + + let evaluator = Degree2Eval; + let mut prover = CoefficientProverLSB::new(&evaluator, vec![], vec![evals]); + let mut trng = StdRng::seed_from_u64(99); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, 3, &mut t, |_, _| {}); + + for rp in &proof.round_polys { + assert_eq!(rp.len(), 3, "degree-2 should have 3 evaluations"); + } + assert_eq!( + proof.round_polys[0][0] + proof.round_polys[0][1], + claimed_sum + ); + } +} diff --git a/src/provers/gkr.rs b/src/provers/gkr.rs new file mode 100644 index 00000000..63f95c41 --- /dev/null +++ b/src/provers/gkr.rs @@ -0,0 +1,352 @@ +//! GKR round sumcheck prover (degree 2). +//! +//! Implements [`SumcheckProver`] for the GKR round polynomial: +//! +//! ```text +//! f_r(b, c) = add_i(r, b, c) · (W(b) + W(c)) + mult_i(r, b, c) · (W(b) · W(c)) +//! ``` +//! +//! where `add_i` and `mult_i` are gate predicates (partially evaluated at the +//! previous layer's random point `r`), and `W` is the witness for the next +//! layer. +//! +//! # Construction +//! +//! The prover takes three evaluation tables: +//! - `add_evals`: `add_i(r, b, c)` over `{0,1}^{2k}`, `2^{2k}` entries +//! - `mult_evals`: `mult_i(r, b, c)` over `{0,1}^{2k}`, `2^{2k}` entries +//! - `w_evals`: `W(x)` over `{0,1}^k`, `2^k` entries +//! +//! The sumcheck runs over `2k` variables. After all rounds and finalization, +//! [`claimed_w_values()`](GkrProver::claimed_w_values) returns `(W(b*), W(c*))` +//! for the reduce-to-one sub-protocol. +//! +//! # Example +//! +//! ```ignore +//! let mut prover = GkrProver::new(add_evals, mult_evals, w_evals); +//! let proof = sumcheck(&mut prover, 2 * k, &mut transcript, noop_hook); +//! let (w_b, w_c) = prover.claimed_w_values(); +//! ``` + +use crate::field::SumcheckField; +use crate::inner_product_sumcheck as ip; +use crate::sumcheck_prover::SumcheckProver; + +extern crate alloc; +use alloc::vec; +use alloc::vec::Vec; + +/// GKR round sumcheck prover (degree 2). +/// +/// See [module docs](self) for details. +pub struct GkrProver { + /// Gate add predicate: `add_i(r, b, c)`, `2^{2k}` entries. + add_evals: Vec, + /// Gate mult predicate: `mult_i(r, b, c)`, `2^{2k}` entries. + mult_evals: Vec, + /// Witness `W(b)` broadcast over c: `w_b[b * 2^k + c] = W(b)`. + w_b: Vec, + /// Witness `W(c)` broadcast over b: `w_c[b * 2^k + c] = W(c)`. + w_c: Vec, +} + +impl GkrProver { + /// Construct from gate predicates and witness evaluations. + /// + /// - `add_evals`: `add_i(r, b, c)` for all `(b, c) in {0,1}^{2k}`. + /// - `mult_evals`: `mult_i(r, b, c)` for all `(b, c) in {0,1}^{2k}`. + /// - `w_evals`: `W(x)` for all `x in {0,1}^k`. + /// + /// The gate tables must have length `w_evals.len()^2`. + pub fn new(add_evals: Vec, mult_evals: Vec, w_evals: Vec) -> Self { + let n = w_evals.len(); + let n_bc = n * n; + assert_eq!(add_evals.len(), n_bc, "add_evals must have len w^2"); + assert_eq!(mult_evals.len(), n_bc, "mult_evals must have len w^2"); + + // Expand witness into broadcast tables over the (b, c) hypercube. + let mut w_b = vec![F::ZERO; n_bc]; + let mut w_c = vec![F::ZERO; n_bc]; + for b in 0..n { + for c in 0..n { + let idx = b * n + c; + w_b[idx] = w_evals[b]; + w_c[idx] = w_evals[c]; + } + } + + Self { + add_evals, + mult_evals, + w_b, + w_c, + } + } + + /// After full sumcheck: the claimed witness evaluations `(W(b*), W(c*))`. + /// + /// These are the inputs to the reduce-to-one sub-protocol (Thaler §4.5.2). + pub fn claimed_w_values(&self) -> (F, F) { + if self.w_b.len() == 1 { + (self.w_b[0], self.w_c[0]) + } else { + (F::ZERO, F::ZERO) + } + } +} + +#[cfg(feature = "arkworks")] +impl SumcheckProver for GkrProver +where + F: ark_ff::Field, +{ + fn degree(&self) -> usize { + 2 + } + + fn round(&mut self, challenge: Option) -> Vec { + // Fold all four tables with the previous challenge. + if let Some(w) = challenge { + ip::fold(&mut self.add_evals, w); + ip::fold(&mut self.mult_evals, w); + ip::fold(&mut self.w_b, w); + ip::fold(&mut self.w_c, w); + } + + let n = self.add_evals.len(); + if n <= 1 { + let v = if n == 1 { + let wb = self.w_b[0]; + let wc = self.w_c[0]; + self.add_evals[0] * (wb + wc) + self.mult_evals[0] * (wb * wc) + } else { + F::ZERO + }; + return vec![v, F::ZERO, F::ZERO]; + } + + let half = n.next_power_of_two() >> 1; + let (add_lo, add_hi) = self.add_evals.split_at(half); + let (mult_lo, mult_hi) = self.mult_evals.split_at(half); + let (wb_lo, wb_hi) = self.w_b.split_at(half); + let (wc_lo, wc_hi) = self.w_c.split_at(half); + + let paired = add_hi.len(); + + let mut q0 = F::ZERO; + let mut q1 = F::ZERO; + let mut q2 = F::ZERO; + + for i in 0..paired { + let al = add_lo[i]; + let ah = add_hi[i]; + let ml = mult_lo[i]; + let mh = mult_hi[i]; + let wbl = wb_lo[i]; + let wbh = wb_hi[i]; + let wcl = wc_lo[i]; + let wch = wc_hi[i]; + + // q(0): all factors at t=0 (low half) + q0 += al * (wbl + wcl) + ml * (wbl * wcl); + + // q(1): all factors at t=1 (high half) + q1 += ah * (wbh + wch) + mh * (wbh * wch); + + // q(2): linear extension to t=2: val_2 = 2*hi - lo + let a2 = ah + ah - al; + let m2 = mh + mh - ml; + let wb2 = wbh + wbh - wbl; + let wc2 = wch + wch - wcl; + q2 += a2 * (wb2 + wc2) + m2 * (wb2 * wc2); + } + + // Tail: hi is implicitly zero, so at t=2 each factor is -lo. + // add term: (-al)*(-wbl + -wcl) = al*(wbl + wcl) [even number of negations] + // mult term: (-ml)*(-wbl)*(-wcl) = -ml*wbl*wcl [odd number of negations] + for i in paired..half.min(n) { + let al = add_lo[i]; + let ml = mult_lo[i]; + let wbl = wb_lo[i]; + let wcl = wc_lo[i]; + + q0 += al * (wbl + wcl) + ml * (wbl * wcl); + // q(1) += 0 (hi is zero) + q2 += al * (wbl + wcl) - ml * (wbl * wcl); + } + + vec![q0, q1, q2] + } + + fn finalize(&mut self, last_challenge: F) { + ip::fold(&mut self.add_evals, last_challenge); + ip::fold(&mut self.mult_evals, last_challenge); + ip::fold(&mut self.w_b, last_challenge); + ip::fold(&mut self.w_c, last_challenge); + } + + fn final_value(&self) -> F { + if self.add_evals.len() == 1 { + let wb = self.w_b[0]; + let wc = self.w_c[0]; + self.add_evals[0] * (wb + wc) + self.mult_evals[0] * (wb * wc) + } else { + F::ZERO + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::polynomial::eval_from_evals; + use crate::runner::sumcheck; + use crate::tests::F64; + use crate::transcript::SanityTranscript; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + /// Run GKR prover and verify the proof for a given k. + fn run_gkr_test(k: usize, seed: u64) { + let mut rng = StdRng::seed_from_u64(seed); + let n = 1 << k; + let n_bc = n * n; + + let add_evals: Vec = (0..n_bc).map(|_| F64::rand(&mut rng)).collect(); + let mult_evals: Vec = (0..n_bc).map(|_| F64::rand(&mut rng)).collect(); + let w_evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + + // Compute claimed sum directly. + let mut expected_sum = F64::ZERO; + for b in 0..n { + for c in 0..n { + let idx = b * n + c; + let wb = w_evals[b]; + let wc = w_evals[c]; + expected_sum += add_evals[idx] * (wb + wc) + mult_evals[idx] * (wb * wc); + } + } + + // Run the prover. + let mut prover = GkrProver::new(add_evals.clone(), mult_evals.clone(), w_evals.clone()); + let num_rounds = 2 * k; + let mut trng = StdRng::seed_from_u64(99); + let mut transcript = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, num_rounds, &mut transcript, |_, _| {}); + + // Verify: q(0) + q(1) = claim each round, update via Lagrange. + let mut claim = expected_sum; + for (round, evals) in proof.round_polys.iter().enumerate() { + assert_eq!( + evals[0] + evals[1], + claim, + "k={k}: round {round}: q(0) + q(1) != claim" + ); + claim = eval_from_evals(evals, proof.challenges[round]); + } + assert_eq!(claim, proof.final_value, "k={k}: final claim mismatch"); + + // Verify claimed W values match multilinear extension. + let (w_b_star, w_c_star) = prover.claimed_w_values(); + assert_eq!( + w_b_star, + eval_mle(&w_evals, &proof.challenges[..k]), + "k={k}: W(b*) mismatch" + ); + assert_eq!( + w_c_star, + eval_mle(&w_evals, &proof.challenges[k..]), + "k={k}: W(c*) mismatch" + ); + } + + /// Evaluate the multilinear extension of `evals` at point `r`. + fn eval_mle(evals: &[F64], r: &[F64]) -> F64 { + let mut table = evals.to_vec(); + for &ri in r { + let half = table.len() / 2; + for j in 0..half { + table[j] = table[j] + ri * (table[j + half] - table[j]); + } + table.truncate(half); + } + table[0] + } + + #[test] + fn gkr_k1() { + run_gkr_test(1, 0x100); + } + + #[test] + fn gkr_k2() { + run_gkr_test(2, 0x200); + } + + #[test] + fn gkr_k3() { + run_gkr_test(3, 0x300); + } + + #[test] + fn gkr_k4() { + run_gkr_test(4, 0x400); + } + + /// All-zero gates: sum should be zero regardless of witness. + #[test] + fn gkr_zero_gates() { + let k = 2; + let n = 1 << k; + let n_bc = n * n; + let mut rng = StdRng::seed_from_u64(0x500); + + let add_evals = vec![F64::ZERO; n_bc]; + let mult_evals = vec![F64::ZERO; n_bc]; + let w_evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + + let mut prover = GkrProver::new(add_evals, mult_evals, w_evals); + let mut trng = StdRng::seed_from_u64(99); + let mut transcript = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, 2 * k, &mut transcript, |_, _| {}); + + // All round polynomials should be zero. + for evals in &proof.round_polys { + for &v in evals { + assert_eq!(v, F64::ZERO); + } + } + } + + /// Add-only circuit (no mult gates): degree is still 2 but + /// mult contributions are zero. + #[test] + fn gkr_add_only() { + let k = 2; + let n = 1 << k; + let n_bc = n * n; + let mut rng = StdRng::seed_from_u64(0x600); + + let add_evals: Vec = (0..n_bc).map(|_| F64::rand(&mut rng)).collect(); + let mult_evals = vec![F64::ZERO; n_bc]; + let w_evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + + let mut expected_sum = F64::ZERO; + for b in 0..n { + for c in 0..n { + expected_sum += add_evals[b * n + c] * (w_evals[b] + w_evals[c]); + } + } + + let mut prover = GkrProver::new(add_evals, mult_evals, w_evals); + let mut trng = StdRng::seed_from_u64(99); + let mut transcript = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, 2 * k, &mut transcript, |_, _| {}); + + assert_eq!( + proof.round_polys[0][0] + proof.round_polys[0][1], + expected_sum, + ); + } +} diff --git a/src/provers/inner_product.rs b/src/provers/inner_product.rs new file mode 100644 index 00000000..2c62ba33 --- /dev/null +++ b/src/provers/inner_product.rs @@ -0,0 +1,218 @@ +//! Inner-product sumcheck prover: `g = f_tilde * g_tilde`, degree 2. +//! +//! Implements [`SumcheckProver`] for the quadratic sumcheck `∑_x f(x)·g(x)`. + +use crate::field::SumcheckField; +use crate::inner_product_sumcheck as ip; +use crate::sumcheck_prover::SumcheckProver; + +/// Inner-product sumcheck prover (degree 2). +/// +/// Computes `∑_x f(x)·g(x)` where `f` and `g` are multilinear polynomials +/// specified by their evaluations over the Boolean hypercube. +/// +/// Wire format: evaluations `[q(0), q(1), q(2)]` where `q` is the degree-2 +/// round polynomial. +/// +/// # Construction +/// +/// ```ignore +/// let mut prover = InnerProductProver::new(a, b); +/// let proof = sumcheck(&mut prover, num_rounds, &mut transcript, |_, _| {}); +/// let (f_eval, g_eval) = prover.final_evaluations(); +/// ``` +pub struct InnerProductProver { + a: Vec, + b: Vec, +} + +impl InnerProductProver { + /// Time strategy prover: holds both evaluation vectors in memory. + pub fn new(a: Vec, b: Vec) -> Self { + assert_eq!(a.len(), b.len(), "a and b must have equal length"); + Self { a, b } + } + + /// Access the (possibly folded) evaluation vectors. + pub fn evaluations(&self) -> (&[F], &[F]) { + (&self.a, &self.b) + } + + /// After full sumcheck: the final evaluations `(f(r), g(r))`. + pub fn final_evaluations(&self) -> (F, F) { + if self.a.len() == 1 { + (self.a[0], self.b[0]) + } else { + (F::ZERO, F::ZERO) + } + } +} + +// NOTE: The `ark_ff::Field` bound is temporary — required because the +// underlying functions in `inner_product_sumcheck.rs` use `F: Field`. +// It will be removed when those functions are ported to `SumcheckField`. +#[cfg(feature = "arkworks")] +impl SumcheckProver for InnerProductProver +where + F: ark_ff::Field, +{ + fn degree(&self) -> usize { + 2 + } + + fn round(&mut self, challenge: Option) -> Vec { + // Fold with previous challenge (if any). + if let Some(w) = challenge { + ip::fold(&mut self.a, w); + ip::fold(&mut self.b, w); + } + + // Compute round polynomial evaluations at {0, 1, 2}. + // + // The round polynomial q(X) = Σ_{x'} f(r, X, x') · g(r, X, x'). + // With MSB half-split: + // f(r, X, x') = (1−X)·a_lo[x'] + X·a_hi[x'] + // + // q(0) = dot(a_lo, b_lo) + // q(1) = dot(a_hi, b_hi) + // q(2) = dot(2·a_hi − a_lo, 2·b_hi − b_lo) + let n = self.a.len(); + if n <= 1 { + let v = if n == 1 { + self.a[0] * self.b[0] + } else { + F::ZERO + }; + return vec![v, F::ZERO, F::ZERO]; + } + + let half = n.next_power_of_two() >> 1; + let (a_lo, a_hi) = self.a.split_at(half); + let (b_lo, b_hi) = self.b.split_at(half); + + // a_lo may be longer than a_hi (non-pow2 input, implicit zero padding). + let paired = a_hi.len(); + let a_lo_paired = &a_lo[..paired]; + let b_lo_paired = &b_lo[..paired]; + let a_lo_tail = &a_lo[paired..]; + let b_lo_tail = &b_lo[paired..]; + + let mut q0 = F::ZERO; + let mut q1 = F::ZERO; + let mut q2 = F::ZERO; + for i in 0..paired { + let al = a_lo_paired[i]; + let ah = a_hi[i]; + let bl = b_lo_paired[i]; + let bh = b_hi[i]; + q0 += al * bl; + q1 += ah * bh; + // f(2) = 2·ah − al, g(2) = 2·bh − bl + let a2 = ah + ah - al; + let b2 = bh + bh - bl; + q2 += a2 * b2; + } + + // Tail (hi is implicitly zero): contributes to q0 only. + // q(0) += dot(tail_a, tail_b), q(1) += 0, q(2) += dot(-tail_a, -tail_b) = dot(tail_a, tail_b). + let tail_dot: F = a_lo_tail.iter().zip(b_lo_tail).map(|(&a, &b)| a * b).sum(); + q0 += tail_dot; + q2 += tail_dot; + + vec![q0, q1, q2] + } + + fn finalize(&mut self, last_challenge: F) { + ip::fold(&mut self.a, last_challenge); + ip::fold(&mut self.b, last_challenge); + } + + fn final_value(&self) -> F { + if self.a.len() == 1 { + self.a[0] * self.b[0] + } else { + F::ZERO + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runner::sumcheck; + use crate::tests::F64; + use crate::transcript::SanityTranscript; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + /// New `InnerProductProver` matches the old `inner_product_sumcheck`. + /// + /// The old API sends `(c0, c2)` (difference form); the new API sends + /// `[q(0), q(1), q(2)]` (evaluation form). We verify the underlying + /// values are consistent and the final evaluation matches. + #[test] + fn matches_legacy_inner_product_sumcheck() { + let mut rng = StdRng::seed_from_u64(42); + let a: Vec = (0..16).map(|_| F64::rand(&mut rng)).collect(); + let b: Vec = (0..16).map(|_| F64::rand(&mut rng)).collect(); + + // Old API. + let mut old_a = a.clone(); + let mut old_b = b.clone(); + let mut trng = StdRng::seed_from_u64(99); + let mut t_old = SanityTranscript::new(&mut trng); + let old_result = crate::inner_product_sumcheck::inner_product_sumcheck( + &mut old_a, + &mut old_b, + &mut t_old, + |_, _| {}, + ); + + // New API. + let mut prover = InnerProductProver::new(a, b); + let num_rounds = 4; // log2(16) + let mut trng2 = StdRng::seed_from_u64(99); + let mut t_new = SanityTranscript::new(&mut trng2); + let new_result = sumcheck(&mut prover, num_rounds, &mut t_new, |_, _| {}); + + // Compare round-by-round consistency. + assert_eq!( + old_result.prover_messages.len(), + new_result.round_polys.len() + ); + for (i, (old_msg, new_evals)) in old_result + .prover_messages + .iter() + .zip(&new_result.round_polys) + .enumerate() + { + let (c0, c2) = *old_msg; + // Old API: c0 = q(0), and q(0) + q(1) = claim. + // New API: [q(0), q(1), q(2)]. + assert_eq!(c0, new_evals[0], "round {i}: q(0) mismatch"); + // c2 from old = x² coefficient. Verify via: + // q(X) = c0 + c1·X + c2·X² + // q(2) = c0 + 2·c1 + 4·c2 + // c1 = q(1) - c0 - c2 + let q1 = new_evals[1]; + let c1_derived = q1 - c0 - c2; + let q2_expected = c0 + c1_derived.double() + c2.double().double(); + assert_eq!( + q2_expected, new_evals[2], + "round {i}: q(2) inconsistent with (c0, c2)" + ); + } + + // Compare challenges (should be identical since same transcript seed). + assert_eq!(old_result.verifier_messages, new_result.challenges); + + // Compare final evaluations. + let (fa, fb) = prover.final_evaluations(); + let (old_fa, old_fb) = old_result.final_evaluations; + assert_eq!(fa, old_fa, "f(r) mismatch"); + assert_eq!(fb, old_fb, "g(r) mismatch"); + + // Final value = f(r) * g(r). + assert_eq!(new_result.final_value, fa * fb); + } +} diff --git a/src/provers/inner_product_lsb.rs b/src/provers/inner_product_lsb.rs new file mode 100644 index 00000000..8121c4ff --- /dev/null +++ b/src/provers/inner_product_lsb.rs @@ -0,0 +1,268 @@ +//! LSB (pair-split) inner-product sumcheck prover: `g = f_tilde * g_tilde`, degree 2. +//! +//! Folds the *least-significant* variable each round: pairs `(f[2k], f[2k+1])`. +//! This is the natural layout for **sequential streaming** where evaluations +//! arrive in index order. +//! +//! Use this prover for Jolt-style workloads. For in-memory or random-access +//! workloads, prefer [`InnerProductProver`](super::inner_product::InnerProductProver) +//! (MSB layout). + +extern crate alloc; +use crate::field::SumcheckField; +#[cfg(feature = "arkworks")] +use crate::sumcheck_prover::SumcheckProver; +use alloc::vec::Vec; + +/// LSB inner-product sumcheck prover (degree 2, pair-split layout). +/// +/// Computes `sum_x f(x) * g(x)` by folding the least-significant variable +/// each round. +/// +/// ```ignore +/// let mut prover = InnerProductProverLSB::new(a, b); +/// let proof = sumcheck(&mut prover, num_rounds, &mut transcript, |_, _| {}); +/// let (f_r, g_r) = prover.final_evaluations(); +/// ``` +pub struct InnerProductProverLSB { + a: Vec, + b: Vec, +} + +impl InnerProductProverLSB { + pub fn new(a: Vec, b: Vec) -> Self { + assert_eq!(a.len(), b.len(), "a and b must have equal length"); + Self { a, b } + } + + pub fn evaluations(&self) -> (&[F], &[F]) { + (&self.a, &self.b) + } + + pub fn final_evaluations(&self) -> (F, F) { + if self.a.len() == 1 { + (self.a[0], self.b[0]) + } else { + (F::ZERO, F::ZERO) + } + } +} + +// ─── LSB fold and compute ────────────────────────────────────────────────── + +/// Compute round polynomial evaluations at {0, 1, 2} from LSB pair-split layout. +/// +/// q(0) = sum a[2k] * b[2k] +/// q(1) = sum a[2k+1] * b[2k+1] +/// q(2) = sum (2*a[2k+1] - a[2k]) * (2*b[2k+1] - b[2k]) +fn compute_lsb(a: &[F], b: &[F]) -> (F, F, F) { + debug_assert_eq!(a.len(), b.len()); + if a.is_empty() { + return (F::ZERO, F::ZERO, F::ZERO); + } + if a.len() == 1 { + return (a[0] * b[0], F::ZERO, F::ZERO); + } + + let mut q0 = F::ZERO; + let mut q1 = F::ZERO; + let mut q2 = F::ZERO; + for i in (0..a.len()).step_by(2) { + let a_even = a[i]; + let a_odd = a[i + 1]; + let b_even = b[i]; + let b_odd = b[i + 1]; + q0 += a_even * b_even; + q1 += a_odd * b_odd; + // f(2) = 2*a_odd - a_even, g(2) = 2*b_odd - b_even + let a2 = a_odd + a_odd - a_even; + let b2 = b_odd + b_odd - b_even; + q2 += a2 * b2; + } + (q0, q1, q2) +} + +/// In-place LSB fold: `new[k] = f[2k] + w * (f[2k+1] - f[2k])`. +fn fold_lsb(v: &mut Vec, weight: F) { + if v.len() <= 1 { + return; + } + let new_len = v.len() / 2; + for i in 0..new_len { + let a = v[2 * i]; + let b = v[2 * i + 1]; + v[i] = a + weight * (b - a); + } + v.truncate(new_len); +} + +/// Fused fold + compute: fold both vectors with `weight`, then compute +/// the next round's (q0, q1, q2) in one pass over quads. +fn fused_fold_and_compute_lsb( + a: &mut Vec, + b: &mut Vec, + weight: F, +) -> (F, F, F) { + let n = a.len(); + debug_assert_eq!(n, b.len()); + if n < 4 { + fold_lsb(a, weight); + fold_lsb(b, weight); + return compute_lsb(a, b); + } + + let new_len = n / 2; + let mut q0 = F::ZERO; + let mut q1 = F::ZERO; + let mut q2 = F::ZERO; + + // Process quads: indices (4k, 4k+1, 4k+2, 4k+3) + // Fold produces: new_a[2k] = a[4k] + w*(a[4k+1] - a[4k]) + // new_a[2k+1] = a[4k+2] + w*(a[4k+3] - a[4k+2]) + // Next round's LSB pairs are (new[2k], new[2k+1]). + let quads = n / 4; + for q in 0..quads { + let i = 4 * q; + let na_even = a[i] + weight * (a[i + 1] - a[i]); + let na_odd = a[i + 2] + weight * (a[i + 3] - a[i + 2]); + let nb_even = b[i] + weight * (b[i + 1] - b[i]); + let nb_odd = b[i + 2] + weight * (b[i + 3] - b[i + 2]); + + a[2 * q] = na_even; + a[2 * q + 1] = na_odd; + b[2 * q] = nb_even; + b[2 * q + 1] = nb_odd; + + q0 += na_even * nb_even; + q1 += na_odd * nb_odd; + let a2 = na_odd + na_odd - na_even; + let b2 = nb_odd + nb_odd - nb_even; + q2 += a2 * b2; + } + + // Handle remainder if new_len is odd. + if new_len > 2 * quads { + let i = 4 * quads; + let na = a[i] + weight * (a[i + 1] - a[i]); + let nb = b[i] + weight * (b[i + 1] - b[i]); + a[2 * quads] = na; + b[2 * quads] = nb; + q0 += na * nb; + } + + a.truncate(new_len); + b.truncate(new_len); + (q0, q1, q2) +} + +// ─── SumcheckProver impl ─────────────────────────────────────────────────── + +#[cfg(feature = "arkworks")] +impl SumcheckProver for InnerProductProverLSB +where + F: ark_ff::Field, +{ + fn degree(&self) -> usize { + 2 + } + + fn round(&mut self, challenge: Option) -> Vec { + let (q0, q1, q2) = if let Some(w) = challenge { + fused_fold_and_compute_lsb(&mut self.a, &mut self.b, w) + } else { + compute_lsb(&self.a, &self.b) + }; + vec![q0, q1, q2] + } + + fn finalize(&mut self, last_challenge: F) { + fold_lsb(&mut self.a, last_challenge); + fold_lsb(&mut self.b, last_challenge); + } + + fn final_value(&self) -> F { + if self.a.len() == 1 { + self.a[0] * self.b[0] + } else { + F::ZERO + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provers::inner_product::InnerProductProver; + use crate::runner::sumcheck; + use crate::tests::F64; + use crate::transcript::SanityTranscript; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + #[test] + fn lsb_inner_product_completes_and_verifies() { + let mut rng = StdRng::seed_from_u64(42); + let n = 16; + let a: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let b: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = a.iter().zip(&b).map(|(&x, &y)| x * y).sum(); + + let mut prover = InnerProductProverLSB::new(a.clone(), b.clone()); + let num_vars = 4; + let mut trng = StdRng::seed_from_u64(99); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, num_vars, &mut t, |_, _| {}); + + // Round-0 consistency. + assert_eq!( + proof.round_polys[0][0] + proof.round_polys[0][1], + claimed_sum + ); + + // All-round consistency via Lagrange on {0, 1, 2}. + let mut claim = claimed_sum; + for (rp, &r) in proof.round_polys.iter().zip(&proof.challenges) { + assert_eq!(rp.len(), 3); + assert_eq!(rp[0] + rp[1], claim, "consistency check failed"); + let two = F64::from(2u64); + let l0 = (r - F64::from(1u64)) * (r - two) / two; + let l1 = -r * (r - two); + let l2 = r * (r - F64::from(1u64)) / two; + claim = rp[0] * l0 + rp[1] * l1 + rp[2] * l2; + } + assert_eq!(proof.final_value, claim); + + // Final evaluations. + let (fa, fb) = prover.final_evaluations(); + assert_eq!(proof.final_value, fa * fb); + } + + #[test] + fn lsb_and_msb_prove_same_sum() { + let mut rng = StdRng::seed_from_u64(42); + let n = 32; + let a: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let b: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = a.iter().zip(&b).map(|(&x, &y)| x * y).sum(); + + // LSB. + let mut lsb = InnerProductProverLSB::new(a.clone(), b.clone()); + let mut trng = StdRng::seed_from_u64(99); + let mut t = SanityTranscript::new(&mut trng); + let lsb_proof = sumcheck(&mut lsb, 5, &mut t, |_, _| {}); + assert_eq!( + lsb_proof.round_polys[0][0] + lsb_proof.round_polys[0][1], + claimed_sum + ); + + // MSB. + let mut msb = InnerProductProver::new(a, b); + let mut trng2 = StdRng::seed_from_u64(99); + let mut t2 = SanityTranscript::new(&mut trng2); + let msb_proof = sumcheck(&mut msb, 5, &mut t2, |_, _| {}); + assert_eq!( + msb_proof.round_polys[0][0] + msb_proof.round_polys[0][1], + claimed_sum + ); + } +} diff --git a/src/provers/mod.rs b/src/provers/mod.rs new file mode 100644 index 00000000..14e965ea --- /dev/null +++ b/src/provers/mod.rs @@ -0,0 +1,15 @@ +//! Concrete [`SumcheckProver`](crate::sumcheck_prover::SumcheckProver) +//! implementations for each polynomial shape. + +#[cfg(feature = "arkworks")] +pub mod coefficient; +#[cfg(feature = "arkworks")] +pub mod coefficient_lsb; +#[cfg(feature = "arkworks")] +pub mod gkr; +#[cfg(feature = "arkworks")] +pub mod inner_product; +pub mod inner_product_lsb; +#[cfg(feature = "arkworks")] +pub mod multilinear; +pub mod multilinear_lsb; diff --git a/src/provers/multilinear.rs b/src/provers/multilinear.rs new file mode 100644 index 00000000..f0d067d2 --- /dev/null +++ b/src/provers/multilinear.rs @@ -0,0 +1,137 @@ +//! Multilinear sumcheck prover: `g = f_tilde`, degree 1. +//! +//! Wraps the fused fold+compute kernel from `multilinear_sumcheck.rs` +//! behind the [`SumcheckProver`] trait. + +use crate::field::SumcheckField; +use crate::multilinear_sumcheck::{ + compute_sumcheck_polynomial, fold, fused_fold_and_compute_polynomial, +}; +use crate::sumcheck_prover::SumcheckProver; + +/// Multilinear sumcheck prover (degree 1). +/// +/// Computes `∑_x v(x)` where `v` is a multilinear polynomial specified +/// by its evaluations over the Boolean hypercube. +/// +/// # Construction +/// +/// ```ignore +/// // Time strategy: O(2^v) space, O(2^v) total time. +/// let mut prover = MultilinearProver::new(evals); +/// let proof = sumcheck(&mut prover, num_rounds, &mut transcript, |_, _| {}); +/// ``` +pub struct MultilinearProver { + evals: Vec, +} + +impl MultilinearProver { + /// Time strategy prover: holds all evaluations in memory. + pub fn new(evals: Vec) -> Self { + Self { evals } + } + + /// Number of variables (log2 of evaluation count, rounded up). + pub fn num_variables(&self) -> usize { + if self.evals.is_empty() { + 0 + } else { + self.evals.len().next_power_of_two().trailing_zeros() as usize + } + } + + /// Access the (possibly folded) evaluation table. + pub fn evals(&self) -> &[F] { + &self.evals + } +} + +// NOTE: The `ark_ff::Field` bound is temporary — required because the +// underlying functions in `multilinear_sumcheck.rs` use `F: Field`. +// It will be removed when those functions are ported to `SumcheckField`. +#[cfg(feature = "arkworks")] +impl SumcheckProver for MultilinearProver +where + F: ark_ff::Field, +{ + fn degree(&self) -> usize { + 1 + } + + fn round(&mut self, challenge: Option) -> Vec { + let (s0, s1) = if let Some(w) = challenge { + fused_fold_and_compute_polynomial(&mut self.evals, w) + } else { + compute_sumcheck_polynomial(&self.evals) + }; + vec![s0, s1] + } + + fn finalize(&mut self, last_challenge: F) { + fold(&mut self.evals, last_challenge); + } + + fn final_value(&self) -> F { + if self.evals.len() == 1 { + self.evals[0] + } else { + F::ZERO + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runner::sumcheck; + use crate::tests::F64; + use crate::transcript::SanityTranscript; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + /// New `MultilinearProver` API produces the same proof as the old + /// `multilinear_sumcheck` function. + #[test] + fn matches_legacy_multilinear_sumcheck() { + let mut rng = StdRng::seed_from_u64(42); + let evals: Vec = (0..16).map(|_| F64::rand(&mut rng)).collect(); + + // Old API. + let mut old_evals = evals.clone(); + let mut trng = StdRng::seed_from_u64(99); + let mut t_old = SanityTranscript::new(&mut trng); + let old_result = crate::multilinear_sumcheck::multilinear_sumcheck( + &mut old_evals, + &mut t_old, + |_, _| {}, + ); + + // New API. + let mut prover = MultilinearProver::new(evals); + let num_rounds = prover.num_variables(); + let mut trng2 = StdRng::seed_from_u64(99); + let mut t_new = SanityTranscript::new(&mut trng2); + let new_result = sumcheck(&mut prover, num_rounds, &mut t_new, |_, _| {}); + + // Compare round polynomials. + assert_eq!( + old_result.prover_messages.len(), + new_result.round_polys.len() + ); + for (i, (old_msg, new_evals)) in old_result + .prover_messages + .iter() + .zip(&new_result.round_polys) + .enumerate() + { + assert_eq!(old_msg.0, new_evals[0], "round {i}: s0 mismatch"); + assert_eq!(old_msg.1, new_evals[1], "round {i}: s1 mismatch"); + } + + // Compare challenges. + assert_eq!(old_result.verifier_messages, new_result.challenges); + + // Compare final value. + assert_eq!(old_result.final_evaluation, new_result.final_value); + } +} diff --git a/src/provers/multilinear_lsb.rs b/src/provers/multilinear_lsb.rs new file mode 100644 index 00000000..3fcd8157 --- /dev/null +++ b/src/provers/multilinear_lsb.rs @@ -0,0 +1,267 @@ +//! LSB (pair-split) multilinear sumcheck prover: `g = f_tilde`, degree 1. +//! +//! Folds the *least-significant* variable each round: pairs `(f[2k], f[2k+1])`. +//! This is the natural layout for **sequential streaming** where evaluations +//! arrive in index order — adjacent pairs are immediately available. +//! +//! Use this prover for Jolt-style workloads where the witness is generated +//! incrementally (CPU trace). For in-memory or random-access-streaming +//! workloads, prefer [`MultilinearProver`](super::multilinear::MultilinearProver) +//! (MSB layout). + +extern crate alloc; +use crate::field::SumcheckField; +use crate::sumcheck_prover::SumcheckProver; +use alloc::vec; +use alloc::vec::Vec; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// LSB multilinear sumcheck prover (degree 1, pair-split layout). +/// +/// Computes `sum_x v(x)` by folding the least-significant variable each round. +/// +/// # Construction +/// +/// ```ignore +/// let mut prover = MultilinearProverLSB::new(evals); +/// let proof = sumcheck(&mut prover, num_rounds, &mut transcript, |_, _| {}); +/// ``` +pub struct MultilinearProverLSB { + evals: Vec, +} + +impl MultilinearProverLSB { + /// Time strategy prover with LSB (pair-split) layout. + pub fn new(evals: Vec) -> Self { + Self { evals } + } + + /// Number of variables. + pub fn num_variables(&self) -> usize { + if self.evals.is_empty() { + 0 + } else { + self.evals.len().next_power_of_two().trailing_zeros() as usize + } + } + + /// Access the (possibly folded) evaluation table. + pub fn evals(&self) -> &[F] { + &self.evals + } +} + +// ─── LSB fold and compute ────────────────────────────────────────────────── + +/// Compute (s0, s1) for the LSB round polynomial from pair-split layout. +/// +/// `s0 = sum of even-indexed elements = sum f[2k]` +/// `s1 = sum of odd-indexed elements = sum f[2k+1]` +fn compute_lsb(evals: &[F]) -> (F, F) { + if evals.is_empty() { + return (F::ZERO, F::ZERO); + } + if evals.len() == 1 { + return (evals[0], F::ZERO); + } + + #[cfg(feature = "parallel")] + { + const PARALLEL_THRESHOLD: usize = 1 << 14; + if evals.len() > PARALLEL_THRESHOLD { + let (s0, s1) = evals + .par_chunks(2) + .map(|chunk| { + let a = chunk[0]; + let b = if chunk.len() > 1 { chunk[1] } else { F::ZERO }; + (a, b) + }) + .reduce( + || (F::ZERO, F::ZERO), + |(a0, a1), (b0, b1)| (a0 + b0, a1 + b1), + ); + return (s0, s1); + } + } + + let mut s0 = F::ZERO; + let mut s1 = F::ZERO; + for chunk in evals.chunks(2) { + s0 += chunk[0]; + if chunk.len() > 1 { + s1 += chunk[1]; + } + } + (s0, s1) +} + +/// In-place LSB (pair-split) fold: `new[k] = f[2k] + w * (f[2k+1] - f[2k])`. +fn fold_lsb(evals: &mut Vec, weight: F) { + if evals.len() <= 1 { + return; + } + let new_len = evals.len() / 2; + + #[cfg(feature = "parallel")] + { + const PARALLEL_THRESHOLD: usize = 1 << 14; + if evals.len() > PARALLEL_THRESHOLD { + let out: Vec = evals + .par_chunks(2) + .map(|chunk| chunk[0] + weight * (chunk[1] - chunk[0])) + .collect(); + *evals = out; + return; + } + } + + for i in 0..new_len { + let a = evals[2 * i]; + let b = evals[2 * i + 1]; + evals[i] = a + weight * (b - a); + } + evals.truncate(new_len); +} + +/// Fused fold + compute: fold with `weight`, then compute the next round's +/// (s0, s1) from the folded data. Single pass over pairs of pairs. +fn fused_fold_and_compute_lsb(evals: &mut Vec, weight: F) -> (F, F) { + let n = evals.len(); + if n < 4 { + fold_lsb(evals, weight); + return compute_lsb(evals); + } + + let new_len = n / 2; + let mut s0 = F::ZERO; + let mut s1 = F::ZERO; + + // Process quads: (f[4k], f[4k+1], f[4k+2], f[4k+3]) + // Fold produces: new[2k] = f[4k] + w*(f[4k+1]-f[4k]) + // new[2k+1] = f[4k+2] + w*(f[4k+3]-f[4k+2]) + // Next round: s0 += new[2k], s1 += new[2k+1] + let quads = n / 4; + for q in 0..quads { + let i = 4 * q; + let a = evals[i] + weight * (evals[i + 1] - evals[i]); + let b = evals[i + 2] + weight * (evals[i + 3] - evals[i + 2]); + evals[2 * q] = a; + evals[2 * q + 1] = b; + s0 += a; + s1 += b; + } + + // Handle remainder if new_len is odd (original n not divisible by 4). + if new_len > 2 * quads { + let i = 4 * quads; + let a = evals[i] + weight * (evals[i + 1] - evals[i]); + evals[2 * quads] = a; + s0 += a; + } + + evals.truncate(new_len); + (s0, s1) +} + +// ─── SumcheckProver impl ─────────────────────────────────────────────────── + +impl SumcheckProver for MultilinearProverLSB { + fn degree(&self) -> usize { + 1 + } + + fn round(&mut self, challenge: Option) -> Vec { + let (s0, s1) = if let Some(w) = challenge { + fused_fold_and_compute_lsb(&mut self.evals, w) + } else { + compute_lsb(&self.evals) + }; + vec![s0, s1] + } + + fn finalize(&mut self, last_challenge: F) { + fold_lsb(&mut self.evals, last_challenge); + } + + fn final_value(&self) -> F { + if self.evals.len() == 1 { + self.evals[0] + } else { + F::ZERO + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provers::multilinear::MultilinearProver; + use crate::runner::sumcheck; + use crate::tests::F64; + use crate::transcript::SanityTranscript; + use ark_ff::UniformRand; + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + /// LSB and MSB provers produce different round polynomials (they fold + /// different variables) but the same final value when evaluated at the + /// same random point via independent MLE evaluation. + #[test] + fn lsb_prover_completes_and_verifies() { + let mut rng = StdRng::seed_from_u64(42); + let evals: Vec = (0..16).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = evals.iter().copied().sum(); + + let mut prover = MultilinearProverLSB::new(evals); + let num_vars = prover.num_variables(); + let mut trng = StdRng::seed_from_u64(99); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, num_vars, &mut t, |_, _| {}); + + // Round-0 consistency. + assert_eq!( + proof.round_polys[0][0] + proof.round_polys[0][1], + claimed_sum + ); + + // All-round consistency. + let mut claim = claimed_sum; + for (rp, &r) in proof.round_polys.iter().zip(&proof.challenges) { + assert_eq!(rp[0] + rp[1], claim, "consistency check failed"); + claim = rp[0] + r * (rp[1] - rp[0]); + } + + // Final value matches claim after all rounds. + assert_eq!(proof.final_value, claim); + assert_eq!(prover.evals().len(), 1); + } + + /// LSB and MSB produce the same claimed sum (both prove the same statement). + #[test] + fn lsb_and_msb_prove_same_sum() { + let mut rng = StdRng::seed_from_u64(42); + let evals: Vec = (0..32).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = evals.iter().copied().sum(); + + // LSB. + let mut lsb = MultilinearProverLSB::new(evals.clone()); + let mut trng = StdRng::seed_from_u64(99); + let mut t = SanityTranscript::new(&mut trng); + let lsb_proof = sumcheck(&mut lsb, 5, &mut t, |_, _| {}); + assert_eq!( + lsb_proof.round_polys[0][0] + lsb_proof.round_polys[0][1], + claimed_sum + ); + + // MSB. + let mut msb = MultilinearProver::new(evals); + let mut trng2 = StdRng::seed_from_u64(99); + let mut t2 = SanityTranscript::new(&mut trng2); + let msb_proof = sumcheck(&mut msb, 5, &mut t2, |_, _| {}); + assert_eq!( + msb_proof.round_polys[0][0] + msb_proof.round_polys[0][1], + claimed_sum + ); + } +} diff --git a/src/reductions/mod.rs b/src/reductions/mod.rs new file mode 100644 index 00000000..fb607951 --- /dev/null +++ b/src/reductions/mod.rs @@ -0,0 +1,8 @@ +//! Reduction utilities for sumcheck prover implementations. +//! +//! Pairwise (LSB) and tablewise reductions used by the coefficient sumcheck. + +#[allow(dead_code)] +pub mod pairwise; +#[allow(dead_code)] +pub mod tablewise; diff --git a/src/reductions/pairwise.rs b/src/reductions/pairwise.rs new file mode 100644 index 00000000..0de86c99 --- /dev/null +++ b/src/reductions/pairwise.rs @@ -0,0 +1,123 @@ +use ark_ff::Field; +use ark_std::vec::Vec; +use ark_std::{cfg_chunks, cfg_into_iter}; +#[cfg(feature = "parallel")] +use rayon::{ + iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}, + prelude::ParallelSlice, +}; + +use crate::streams::Stream; + +pub fn evaluate(src: &[F]) -> (F, F) { + let even_sum = cfg_into_iter!(0..src.len()) + .step_by(2) + .map(|i| src[i]) + .sum(); + let odd_sum = cfg_into_iter!(1..src.len()) + .step_by(2) + .map(|i| src[i]) + .sum(); + (even_sum, odd_sum) +} + +pub fn evaluate_from_stream>(src: &S) -> (F, F) { + let len = 1usize << src.num_variables(); + let even_sum = cfg_into_iter!(0..len) + .step_by(2) + .map(|i| src.evaluation(i)) + .sum(); + let odd_sum = cfg_into_iter!(1..len) + .step_by(2) + .map(|i| src.evaluation(i)) + .sum(); + (even_sum, odd_sum) +} + +pub fn reduce_evaluations(src: &mut Vec, verifier_message: F) { + /// Below this input size, the serial in-place path wins: rayon's + /// fork/join overhead exceeds the actual compute, and we avoid the + /// `.collect()` allocation entirely. Above it, parallelism outpaces + /// serial even with the allocation cost. Chosen to match typical L1 + /// cache-blocking on modern SIMD hosts (~4K field elements). + const SERIAL_THRESHOLD: usize = 1 << 12; + + #[cfg(feature = "parallel")] + { + if src.len() > SERIAL_THRESHOLD { + // Parallel path: MSB pairing makes true in-place parallel + // impossible without unsafe (writer i writes src[i] while writer + // j reads src[2j] which may alias src[i]). Allocate a fresh Vec + // via rayon's parallel collect; `*src = out` swaps buffers + // without a copy. + let out: Vec = cfg_chunks!(src, 2) + .map(|chunk| chunk[0] + verifier_message * (chunk[1] - chunk[0])) + .collect(); + *src = out; + return; + } + } + + // Serial path: truly in-place. Writing src[i] while reading src[2i] and + // src[2i+1] is safe sequentially because 2i ≥ i always, so we never + // clobber a read we still need. Used for non-parallel builds and for + // small inputs where rayon overhead would dominate. + let new_len = src.len() / 2; + for i in 0..new_len { + let a = src[2 * i]; + let b = src[2 * i + 1]; + src[i] = a + verifier_message * (b - a); + } + src.truncate(new_len); +} + +pub fn reduce_evaluations_from_stream>( + src: &S, + dst: &mut Vec, + verifier_message: F, +) { + // compute from stream + let len = 1usize << src.num_variables(); + let out: Vec = cfg_into_iter!(0..len / 2) + .map(|i| { + let a = src.evaluation(2 * i); + let b = src.evaluation((2 * i) + 1); + a + verifier_message * (b - a) + }) + .collect(); + *dst = out; +} + +/// Pairwise product evaluate returning coefficients `(a, b)` of the degree-2 +/// round polynomial `q(x) = a + bx + cx²`: +/// - `a = Σ f_even · g_even` +/// - `b = Σ (f_even · g_odd + f_odd · g_even)` +pub fn pairwise_product_evaluate(src: &[Vec]) -> (F, F) { + let half_len = src[0].len() / 2; + let a: F = cfg_into_iter!(0..half_len) + .map(|k| { + let i = 2 * k; + src[0][i] * src[1][i] + }) + .sum(); + let b: F = cfg_into_iter!(0..half_len) + .map(|k| { + let i = 2 * k; + src[0][i] * src[1][i + 1] + src[0][i + 1] * src[1][i] + }) + .sum(); + (a, b) +} + +/// Cross-field reduce: fold `BF` evaluations with an `EF` challenge, producing `Vec`. +/// +/// For each adjacent pair `(a, b)` in `src`: `EF::from(a) + challenge * (EF::from(b) - EF::from(a))`. +pub fn cross_field_reduce>(src: &[BF], challenge: EF) -> Vec { + cfg_chunks!(src, 2) + .map(|chunk| { + let a = EF::from(chunk[0]); + let b = EF::from(chunk[1]); + a + challenge * (b - a) + }) + .collect() +} diff --git a/src/reductions/tablewise.rs b/src/reductions/tablewise.rs new file mode 100644 index 00000000..34d388b1 --- /dev/null +++ b/src/reductions/tablewise.rs @@ -0,0 +1,17 @@ +use ark_ff::Field; +use ark_std::{cfg_chunks, vec::Vec}; +#[cfg(feature = "parallel")] +use rayon::{iter::ParallelIterator, prelude::ParallelSlice}; + +pub fn reduce_evaluations(src: &mut Vec>, verifier_message: F) { + let out: Vec> = cfg_chunks!(src, 2) + .map(|chunk| { + chunk[0] + .iter() + .zip(&chunk[1]) + .map(|(&a, &b)| a + verifier_message * (b - a)) + .collect::>() + }) + .collect(); + *src = out; +} diff --git a/src/runner.rs b/src/runner.rs new file mode 100644 index 00000000..06efce0c --- /dev/null +++ b/src/runner.rs @@ -0,0 +1,69 @@ +//! Protocol runner for the sum-check protocol (Thaler Proposition 4.1). +//! +//! [`sumcheck()`] drives a [`SumcheckProver`] through `num_rounds` rounds, +//! writing round polynomials to the transcript, invoking a per-round hook, +//! and reading verifier challenges. It returns a [`SumcheckProof`] containing +//! the round polynomials, challenges, and the prover's final claimed value. +//! +//! Partial execution (`num_rounds < v`) supports composed protocols like +//! GKR (one sumcheck per layer) and WHIR (partial rounds interleaved with +//! commit/open). + +extern crate alloc; +use crate::field::SumcheckField; +use crate::proof::SumcheckProof; +use crate::sumcheck_prover::SumcheckProver; +use crate::transcript::ProverTranscript; +use alloc::vec::Vec; + +/// Run the sum-check protocol for `num_rounds` rounds. +/// +/// `hook` is called each round after the prover message is written and +/// before the verifier challenge is read. Pass `|_, _| {}` when no hook +/// is needed. +/// +/// On return the prover has been advanced through `num_rounds` folds. +/// If `num_rounds == v` (full execution), `proof.final_value` is the +/// prover's claimed evaluation at the random point. For partial execution, +/// the caller retains `prover` and can continue or inspect post-state. +pub fn sumcheck>( + prover: &mut impl SumcheckProver, + num_rounds: usize, + transcript: &mut T, + mut hook: impl FnMut(usize, &mut T), +) -> SumcheckProof { + let mut round_polys: Vec> = Vec::with_capacity(num_rounds); + let mut challenges: Vec = Vec::with_capacity(num_rounds); + let mut prev_challenge: Option = None; + + for round in 0..num_rounds { + let evals = prover.round(prev_challenge); + + // Send evaluations to transcript. + for &v in &evals { + transcript.send(v); + } + round_polys.push(evals); + + // Per-round hook (e.g., proof-of-work grinding for WHIR). + hook(round, transcript); + + // Squeeze verifier challenge. + let r = transcript.challenge(); + challenges.push(r); + prev_challenge = Some(r); + } + + // Apply the final challenge so final_value() is correct. + if let Some(r) = prev_challenge { + prover.finalize(r); + } + + let final_value = prover.final_value(); + + SumcheckProof { + round_polys, + challenges, + final_value, + } +} diff --git a/src/simd_fields/goldilocks/avx512.rs b/src/simd_fields/goldilocks/avx512.rs new file mode 100644 index 00000000..608ee8b8 --- /dev/null +++ b/src/simd_fields/goldilocks/avx512.rs @@ -0,0 +1,977 @@ +#![allow(dead_code)] +//! Montgomery-form Goldilocks AVX-512 IFMA backend. +//! +//! Operates directly on Montgomery-form values (as stored by arkworks `Fp64`), +//! enabling zero-cost `transmute` from `&[F64]` to `&[u64]`. +//! +//! Uses AVX-512 IFMA (52-bit multiply-accumulate) for true 8-wide vectorized +//! Montgomery multiplication. Unlike the NEON backend (which falls back to +//! scalar mont_mul per lane because NEON lacks 64x64->128 multiply), this +//! backend decomposes operands into 52-bit limbs and uses `vpmadd52luq` / +//! `vpmadd52huq` for a fully vectorized schoolbook multiply. Montgomery +//! reduction exploits the Goldilocks prime structure (P = 2^64 - 2^32 + 1) +//! to avoid additional IFMA multiplies — only shifts, adds, and subtracts. + +use core::arch::x86_64::*; + +use super::super::SimdBaseField; + +/// Goldilocks modulus: P = 2^64 - 2^32 + 1. +const P: u64 = 0xFFFF_FFFF_0000_0001; + +/// ε = 2^64 mod P = 2^32 - 1 (used for add/sub overflow correction). +const EPSILON: u64 = 0xFFFF_FFFF; + +/// Montgomery constant: INV = -P^{-1} mod 2^64. +const INV: u64 = 0xFFFF_FFFE_FFFF_FFFF; + +/// Montgomery ONE = R mod P = 2^64 mod P = EPSILON. +const MONT_ONE: u64 = EPSILON; + +/// Montgomery ZERO = 0 (same in both domains). +const MONT_ZERO: u64 = 0; + +/// Mask for lower 52 bits (IFMA operand width). +const MASK52: u64 = (1u64 << 52) - 1; + +#[derive(Copy, Clone)] +pub struct GoldilocksAvx512; + +impl SimdBaseField for GoldilocksAvx512 { + type Scalar = u64; + type Packed = __m512i; + const LANES: usize = 8; + const MODULUS: u64 = P; + const ZERO: u64 = MONT_ZERO; + const ONE: u64 = MONT_ONE; + + #[inline(always)] + fn splat(val: u64) -> __m512i { + unsafe { _mm512_set1_epi64(val as i64) } + } + + #[inline(always)] + unsafe fn load(ptr: *const u64) -> __m512i { + unsafe { _mm512_loadu_si512(ptr.cast()) } + } + + #[inline(always)] + unsafe fn store(ptr: *mut u64, v: __m512i) { + unsafe { _mm512_storeu_si512(ptr.cast(), v) } + } + + // Add/sub are identical in canonical and Montgomery domain. + + #[inline(always)] + fn add(a: __m512i, b: __m512i) -> __m512i { + unsafe { + let sum = _mm512_add_epi64(a, b); + let p_vec = _mm512_set1_epi64(P as i64); + let eps_vec = _mm512_set1_epi64(EPSILON as i64); + + // Detect unsigned overflow: sum < a means carry occurred + let carry = _mm512_cmplt_epu64_mask(sum, a); + // Detect sum >= P (only relevant when no carry) + let ge_p = !_mm512_cmplt_epu64_mask(sum, p_vec); // >= is NOT < + + // Carry path: sum + ε (2^64 ≡ ε mod P, result guaranteed < P) + let result = _mm512_mask_add_epi64(sum, carry, sum, eps_vec); + // No-carry, >= P path: sum - P + let need_sub = ge_p & !carry; + _mm512_mask_sub_epi64(result, need_sub, result, p_vec) + } + } + + #[inline(always)] + fn sub(a: __m512i, b: __m512i) -> __m512i { + unsafe { + let diff = _mm512_sub_epi64(a, b); + let p_vec = _mm512_set1_epi64(P as i64); + // Borrow when a < b (unsigned) + let borrow = _mm512_cmplt_epu64_mask(a, b); + _mm512_mask_add_epi64(diff, borrow, diff, p_vec) + } + } + + #[inline(always)] + fn mul(a: __m512i, b: __m512i) -> __m512i { + // True 8-wide Montgomery multiplication via IFMA 52-bit decomposition. + // + // 1. Schoolbook 64×64→128 product using 52-bit limbs + IFMA + // 2. Montgomery reduction factor m via Goldilocks structure: + // INV = -(2^32+1) mod 2^64, so m = -(lo + lo<<32) — no multiply + // 3. m*P via P = 2^64 - 2^32 + 1 — shifts and subtracts only + // 4. result = (product + m*P) >> 64, conditional subtract P + unsafe { avx512_mont_mul(a, b) } + } + + #[inline(always)] + fn add_wrapping(a: __m512i, b: __m512i) -> __m512i { + unsafe { _mm512_add_epi64(a, b) } + } + + #[inline(always)] + fn carry_mask(sum: __m512i, a_before: __m512i) -> __m512i { + unsafe { + let carry = _mm512_cmplt_epu64_mask(sum, a_before); + _mm512_maskz_set1_epi64(carry, 1) + } + } + + #[inline(always)] + fn reduce_carry(sum: __m512i, carry_count: __m512i) -> __m512i { + // Each carry represents 2^64 ≡ EPSILON (mod P). + // correction = carry_count * EPSILON (fits in u64 for reasonable counts). + unsafe { + let eps_vec = _mm512_set1_epi64(EPSILON as i64); + let correction = _mm512_mullo_epi64(carry_count, eps_vec); + Self::add(sum, correction) + } + } + + #[inline(always)] + unsafe fn load_deinterleaved(ptr: *const u64) -> (__m512i, __m512i) { + unsafe { + let v0 = _mm512_loadu_si512(ptr.cast()); // [a0,b0,a1,b1,a2,b2,a3,b3] + let v1 = _mm512_loadu_si512(ptr.add(8).cast()); // [a4,b4,a5,b5,a6,b6,a7,b7] + let idx_even = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); + let idx_odd = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); + let evens = _mm512_permutex2var_epi64(v0, idx_even, v1); + let odds = _mm512_permutex2var_epi64(v0, idx_odd, v1); + (evens, odds) + } + } + + #[inline(always)] + fn scalar_add(a: u64, b: u64) -> u64 { + let (sum, carry) = a.overflowing_add(b); + if carry { + sum + EPSILON + } else if sum >= P { + sum - P + } else { + sum + } + } + + #[inline(always)] + fn scalar_sub(a: u64, b: u64) -> u64 { + if a >= b { + a - b + } else { + a.wrapping_sub(b).wrapping_add(P) + } + } + + #[inline(always)] + fn scalar_mul(a: u64, b: u64) -> u64 { + mont_mul(a, b) + } +} + +/// AVX-512 IFMA Montgomery multiplication (8-wide). +/// +/// Decomposes each 64-bit operand into two 52-bit limbs, performs a +/// schoolbook multiply using `vpmadd52luq`/`vpmadd52huq` (6 IFMA ops), +/// then reduces via the Goldilocks prime structure using only shifts, +/// adds, and masked operations — no additional multiplies needed. +#[inline(always)] +unsafe fn avx512_mont_mul(a: __m512i, b: __m512i) -> __m512i { + let zero = _mm512_setzero_si512(); + let mask52_vec = _mm512_set1_epi64(MASK52 as i64); + let p_vec = _mm512_set1_epi64(P as i64); + let ones = _mm512_set1_epi64(1); + + // ── Decompose into 52-bit limbs ── + let a0 = _mm512_and_si512(a, mask52_vec); // low 52 bits + let a1 = _mm512_srli_epi64(a, 52); // high 12 bits + let b0 = _mm512_and_si512(b, mask52_vec); + let b1 = _mm512_srli_epi64(b, 52); + + // ── Schoolbook multiply in base-2^52 (6 IFMA ops) ── + // Limb 0: lo52(a0*b0) — exactly 52 bits + let c0 = _mm512_madd52lo_epu64(zero, a0, b0); + + // Limb 1: hi52(a0*b0) + lo52(a0*b1) + lo52(a1*b0) — up to ~54 bits + let c1 = _mm512_madd52hi_epu64(zero, a0, b0); + let c1 = _mm512_madd52lo_epu64(c1, a0, b1); + let c1 = _mm512_madd52lo_epu64(c1, a1, b0); + + // Limb 2: hi52(a0*b1) + hi52(a1*b0) + lo52(a1*b1) — up to ~25 bits + let c2 = _mm512_madd52hi_epu64(zero, a0, b1); + let c2 = _mm512_madd52hi_epu64(c2, a1, b0); + let c2 = _mm512_madd52lo_epu64(c2, a1, b1); + + // ── Carry propagation: c1 → c2 ── + let carry = _mm512_srli_epi64(c1, 52); + let c1 = _mm512_and_si512(c1, mask52_vec); // now exactly 52 bits + let c2 = _mm512_add_epi64(c2, carry); + + // ── Reconstruct (lo64, hi64) of the 128-bit product ── + // lo64 = c0[0:51] | c1[0:11] << 52 + let lo = _mm512_or_si512(c0, _mm512_slli_epi64(c1, 52)); + // hi64 = c1[12:51] | c2 << 40 (non-overlapping since c1>>12 is 40 bits) + let hi = _mm512_or_si512(_mm512_srli_epi64(c1, 12), _mm512_slli_epi64(c2, 40)); + + // ── Montgomery reduction using Goldilocks structure ── + // + // m = lo * INV mod 2^64 + // INV = -(2^32 + 1) mod 2^64, so m = -(lo + lo<<32) — no multiply! + let lo_shl32 = _mm512_slli_epi64(lo, 32); + let temp = _mm512_add_epi64(lo, lo_shl32); + let m = _mm512_sub_epi64(zero, temp); + + // m * P where P = 2^64 - 2^32 + 1: + // m*P = m*2^64 + m*(1 - 2^32) + // lo(m*P) = (m - m<<32) mod 2^64 + // hi(m*P) = m - (m>>32) - borrow_from_lo + // + // The m*2^32 term spans two 64-bit words: hi = m>>32, lo = m<<32. + let m_shl32 = _mm512_slli_epi64(m, 32); + let m_shr32 = _mm512_srli_epi64(m, 32); + let borrow_mask = _mm512_cmplt_epu64_mask(m, m_shl32); + let hi_mp = _mm512_sub_epi64(m, m_shr32); + let hi_mp = _mm512_mask_sub_epi64(hi_mp, borrow_mask, hi_mp, ones); + + // result = (product + m*P) >> 64 + // Since lo + lo(m*P) ≡ 0 mod 2^64 by construction, the carry is (lo != 0). + let lo_nonzero = !_mm512_cmpeq_epu64_mask(lo, zero); + let carry_from_lo = _mm512_maskz_set1_epi64(lo_nonzero, 1); + + // r = hi + hi(m*P) + carry + let r1 = _mm512_add_epi64(hi, hi_mp); + let c2_mask = _mm512_cmplt_epu64_mask(r1, hi); // overflow from first add + + let r2 = _mm512_add_epi64(r1, carry_from_lo); + let c3_mask = _mm512_cmplt_epu64_mask(r2, r1); // overflow from second add + + // ── Final reduction: subtract P if carry or result >= P ── + let ge_p = !_mm512_cmplt_epu64_mask(r2, p_vec); + let need_sub = c2_mask | c3_mask | ge_p; + _mm512_mask_sub_epi64(r2, need_sub, r2, p_vec) +} + +// ── Extension field arithmetic ── +// +// Extension field SIMD multiplication is not part of the SimdBaseField trait — +// it's implemented as free functions because the nonresidue `w` is a runtime +// value (extracted from the arkworks extension field config during dispatch). + +/// Degree-2 Karatsuba: (a0 + a1·X)(b0 + b1·X) mod (X² - w) +/// 3 base muls + 1 mul-by-w + adds. +#[inline(always)] +pub fn ext2_mul(a: [__m512i; 2], b: [__m512i; 2], w: __m512i) -> [__m512i; 2] { + let v0 = GoldilocksAvx512::mul(a[0], b[0]); + let v1 = GoldilocksAvx512::mul(a[1], b[1]); + let c0 = GoldilocksAvx512::add(v0, GoldilocksAvx512::mul(w, v1)); + let a_sum = GoldilocksAvx512::add(a[0], a[1]); + let b_sum = GoldilocksAvx512::add(b[0], b[1]); + let c1 = GoldilocksAvx512::sub( + GoldilocksAvx512::sub(GoldilocksAvx512::mul(a_sum, b_sum), v0), + v1, + ); + [c0, c1] +} + +/// Degree-2 Karatsuba (scalar version for tail processing). +#[inline(always)] +pub fn ext2_scalar_mul(a: [u64; 2], b: [u64; 2], w: u64) -> [u64; 2] { + let v0 = mont_mul(a[0], b[0]); + let v1 = mont_mul(a[1], b[1]); + let c0 = GoldilocksAvx512::scalar_add(v0, mont_mul(w, v1)); + let a_sum = GoldilocksAvx512::scalar_add(a[0], a[1]); + let b_sum = GoldilocksAvx512::scalar_add(b[0], b[1]); + let c1 = + GoldilocksAvx512::scalar_sub(GoldilocksAvx512::scalar_sub(mont_mul(a_sum, b_sum), v0), v1); + [c0, c1] +} + +/// Degree-3 Karatsuba: (a0 + a1·X + a2·X²)(b0 + b1·X + b2·X²) mod (X³ - w) +/// 6 base muls + 2 mul-by-w + adds. +#[inline(always)] +pub fn ext3_mul(a: [__m512i; 3], b: [__m512i; 3], w: __m512i) -> [__m512i; 3] { + let ad = GoldilocksAvx512::mul(a[0], b[0]); + let be = GoldilocksAvx512::mul(a[1], b[1]); + let cf = GoldilocksAvx512::mul(a[2], b[2]); + + let x = GoldilocksAvx512::sub( + GoldilocksAvx512::sub( + GoldilocksAvx512::mul( + GoldilocksAvx512::add(a[1], a[2]), + GoldilocksAvx512::add(b[1], b[2]), + ), + be, + ), + cf, + ); + let y = GoldilocksAvx512::sub( + GoldilocksAvx512::sub( + GoldilocksAvx512::mul( + GoldilocksAvx512::add(a[0], a[1]), + GoldilocksAvx512::add(b[0], b[1]), + ), + ad, + ), + be, + ); + let z = GoldilocksAvx512::add( + GoldilocksAvx512::sub( + GoldilocksAvx512::sub( + GoldilocksAvx512::mul( + GoldilocksAvx512::add(a[0], a[2]), + GoldilocksAvx512::add(b[0], b[2]), + ), + ad, + ), + cf, + ), + be, + ); + + [ + GoldilocksAvx512::add(ad, GoldilocksAvx512::mul(w, x)), + GoldilocksAvx512::add(y, GoldilocksAvx512::mul(w, cf)), + z, + ] +} + +/// Degree-3 Karatsuba (scalar version). +#[inline(always)] +pub fn ext3_scalar_mul(a: [u64; 3], b: [u64; 3], w: u64) -> [u64; 3] { + let ad = mont_mul(a[0], b[0]); + let be = mont_mul(a[1], b[1]); + let cf = mont_mul(a[2], b[2]); + + let x = GoldilocksAvx512::scalar_sub( + GoldilocksAvx512::scalar_sub( + mont_mul( + GoldilocksAvx512::scalar_add(a[1], a[2]), + GoldilocksAvx512::scalar_add(b[1], b[2]), + ), + be, + ), + cf, + ); + let y = GoldilocksAvx512::scalar_sub( + GoldilocksAvx512::scalar_sub( + mont_mul( + GoldilocksAvx512::scalar_add(a[0], a[1]), + GoldilocksAvx512::scalar_add(b[0], b[1]), + ), + ad, + ), + be, + ); + let z = GoldilocksAvx512::scalar_add( + GoldilocksAvx512::scalar_sub( + GoldilocksAvx512::scalar_sub( + mont_mul( + GoldilocksAvx512::scalar_add(a[0], a[2]), + GoldilocksAvx512::scalar_add(b[0], b[2]), + ), + ad, + ), + cf, + ), + be, + ); + + [ + GoldilocksAvx512::scalar_add(ad, mont_mul(w, x)), + GoldilocksAvx512::scalar_add(y, mont_mul(w, cf)), + z, + ] +} + +/// Vectorized ext2 reduce: processes 8 pairs of degree-2 extension elements. +/// +/// Input: 32 u64s in AoS layout: `[a0_c0, a0_c1, b0_c0, b0_c1, a1_c0, ...]` +/// Each group of 4 u64s is one pair `(a_i, b_i)` where a,b are ext2 elements. +/// Computes `result_i = a_i + challenge * (b_i - a_i)` for 8 pairs simultaneously. +/// Output: 16 u64s in AoS layout: `[r0_c0, r0_c1, r1_c0, r1_c1, ...]` +#[inline(always)] +pub unsafe fn ext2_reduce_8pairs( + src: *const u64, + dst: *mut u64, + challenge_c0: __m512i, + challenge_c1: __m512i, + w_vec: __m512i, +) { + // Load 32 u64s (4 cache lines worth) + let v0 = _mm512_loadu_si512(src.cast()); // pairs 0-1: [a0c0,a0c1,b0c0,b0c1, a1c0,a1c1,b1c0,b1c1] + let v1 = _mm512_loadu_si512(src.add(8).cast()); // pairs 2-3 + let v2 = _mm512_loadu_si512(src.add(16).cast()); // pairs 4-5 + let v3 = _mm512_loadu_si512(src.add(24).cast()); // pairs 6-7 + + // Deinterleave: extract a_c0, a_c1, b_c0, b_c1 each as 8-wide vectors. + // Within each 512-bit register, stride is 4: positions 0,4 are a_c0; 1,5 are a_c1; etc. + // Across 4 registers: we gather element [k] from register [k/2], lane [4*(k%2) + component]. + // + // a_c0: from (v0 lane 0), (v0 lane 4), (v1 lane 0), (v1 lane 4), (v2 lane 0), (v2 lane 4), (v3 lane 0), (v3 lane 4) + // This requires cross-register shuffles. Use permutex2var for pairs of registers, + // then a second round. + + // First round: extract even-pair and odd-pair components from adjacent register pairs + // From v0,v1: gather a_c0 at indices 0,4 from v0 (=lanes 0,4) and 0,4 from v1 (=lanes 8,12) + // permutex2var across v0,v1 gives us 8 values; we want the lower 4 from v0 and lower 4 from v1 + // permutex2var treats v0 as indices 0-7 and v1 as indices 8-15 + let a_c0_lo = _mm512_permutex2var_epi64(v0, _mm512_set_epi64(12, 8, 4, 0, 12, 8, 4, 0), v1); + let a_c1_lo = _mm512_permutex2var_epi64(v0, _mm512_set_epi64(13, 9, 5, 1, 13, 9, 5, 1), v1); + let b_c0_lo = _mm512_permutex2var_epi64(v0, _mm512_set_epi64(14, 10, 6, 2, 14, 10, 6, 2), v1); + let b_c1_lo = _mm512_permutex2var_epi64(v0, _mm512_set_epi64(15, 11, 7, 3, 15, 11, 7, 3), v1); + + let a_c0_hi = _mm512_permutex2var_epi64(v2, _mm512_set_epi64(12, 8, 4, 0, 12, 8, 4, 0), v3); + let a_c1_hi = _mm512_permutex2var_epi64(v2, _mm512_set_epi64(13, 9, 5, 1, 13, 9, 5, 1), v3); + let b_c0_hi = _mm512_permutex2var_epi64(v2, _mm512_set_epi64(14, 10, 6, 2, 14, 10, 6, 2), v3); + let b_c1_hi = _mm512_permutex2var_epi64(v2, _mm512_set_epi64(15, 11, 7, 3, 15, 11, 7, 3), v3); + + // Second round: merge lo (pairs 0-3 in lanes 0-3) and hi (pairs 4-7 in lanes 0-3) + // into final 8-wide vectors. + // lo has useful data in lanes 0-3, hi has useful data in lanes 0-3. + // Use permutex2var: take lanes 0-3 from lo (indices 0-3) and lanes 0-3 from hi (indices 8-11). + let idx_merge = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); + + let a_c0 = _mm512_permutex2var_epi64(a_c0_lo, idx_merge, a_c0_hi); + let a_c1 = _mm512_permutex2var_epi64(a_c1_lo, idx_merge, a_c1_hi); + let b_c0 = _mm512_permutex2var_epi64(b_c0_lo, idx_merge, b_c0_hi); + let b_c1 = _mm512_permutex2var_epi64(b_c1_lo, idx_merge, b_c1_hi); + + // Compute diff = b - a (component-wise) + let diff_c0 = GoldilocksAvx512::sub(b_c0, a_c0); + let diff_c1 = GoldilocksAvx512::sub(b_c1, a_c1); + + // prod = challenge * diff (ext2 Karatsuba) + let prod = ext2_mul([diff_c0, diff_c1], [challenge_c0, challenge_c1], w_vec); + + // result = a + prod + let r_c0 = GoldilocksAvx512::add(a_c0, prod[0]); + let r_c1 = GoldilocksAvx512::add(a_c1, prod[1]); + + // Interleave back to AoS: [r0_c0, r0_c1, r1_c0, r1_c1, ...] + // 8 results → 16 u64s in 2 registers + // r_c0 = [r0, r1, r2, r3, r4, r5, r6, r7] (component 0) + // r_c1 = [r0, r1, r2, r3, r4, r5, r6, r7] (component 1) + // Want: out0 = [r0c0,r0c1,r1c0,r1c1,r2c0,r2c1,r3c0,r3c1] + // out1 = [r4c0,r4c1,r5c0,r5c1,r6c0,r6c1,r7c0,r7c1] + let idx_interleave_lo = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0); + let idx_interleave_hi = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4); + let out0 = _mm512_permutex2var_epi64(r_c0, idx_interleave_lo, r_c1); + let out1 = _mm512_permutex2var_epi64(r_c0, idx_interleave_hi, r_c1); + + _mm512_storeu_si512(dst.cast(), out0); + _mm512_storeu_si512(dst.add(8).cast(), out1); +} + +/// Vectorized ext3 reduce: processes 8 pairs of degree-3 extension elements. +/// +/// Input: 48 u64s in AoS layout: `[a0_c0, a0_c1, a0_c2, b0_c0, b0_c1, b0_c2, a1_c0, ...]` +/// Each group of 6 u64s is one pair `(a_i, b_i)` where a,b are ext3 elements. +/// Computes `result_i = a_i + challenge * (b_i - a_i)` for 8 pairs simultaneously. +/// Output: 24 u64s in AoS layout: `[r0_c0, r0_c1, r0_c2, r1_c0, r1_c1, r1_c2, ...]` +/// +/// Uses AVX-512 gather/scatter for the stride-6 deinterleave/interleave. +#[inline(always)] +pub unsafe fn ext3_reduce_8pairs( + src: *const u64, + dst: *mut u64, + challenge: [__m512i; 3], + w_vec: __m512i, +) { + // Gather 6 components from AoS layout (stride 6 per pair) + // Pair i: a at offset 6i, b at offset 6i+3 + let idx_a_c0 = _mm512_set_epi64(42, 36, 30, 24, 18, 12, 6, 0); + let idx_a_c1 = _mm512_set_epi64(43, 37, 31, 25, 19, 13, 7, 1); + let idx_a_c2 = _mm512_set_epi64(44, 38, 32, 26, 20, 14, 8, 2); + let idx_b_c0 = _mm512_set_epi64(45, 39, 33, 27, 21, 15, 9, 3); + let idx_b_c1 = _mm512_set_epi64(46, 40, 34, 28, 22, 16, 10, 4); + let idx_b_c2 = _mm512_set_epi64(47, 41, 35, 29, 23, 17, 11, 5); + + let base = src as *const i64; + let a_c0 = _mm512_i64gather_epi64::<8>(idx_a_c0, base); + let a_c1 = _mm512_i64gather_epi64::<8>(idx_a_c1, base); + let a_c2 = _mm512_i64gather_epi64::<8>(idx_a_c2, base); + let b_c0 = _mm512_i64gather_epi64::<8>(idx_b_c0, base); + let b_c1 = _mm512_i64gather_epi64::<8>(idx_b_c1, base); + let b_c2 = _mm512_i64gather_epi64::<8>(idx_b_c2, base); + + // diff = b - a (component-wise) + let diff_c0 = GoldilocksAvx512::sub(b_c0, a_c0); + let diff_c1 = GoldilocksAvx512::sub(b_c1, a_c1); + let diff_c2 = GoldilocksAvx512::sub(b_c2, a_c2); + + // prod = challenge * diff (ext3 Karatsuba) + let prod = ext3_mul([diff_c0, diff_c1, diff_c2], challenge, w_vec); + + // result = a + prod + let r_c0 = GoldilocksAvx512::add(a_c0, prod[0]); + let r_c1 = GoldilocksAvx512::add(a_c1, prod[1]); + let r_c2 = GoldilocksAvx512::add(a_c2, prod[2]); + + // Scatter back to AoS (stride 3 per result element) + let idx_r_c0 = _mm512_set_epi64(21, 18, 15, 12, 9, 6, 3, 0); + let idx_r_c1 = _mm512_set_epi64(22, 19, 16, 13, 10, 7, 4, 1); + let idx_r_c2 = _mm512_set_epi64(23, 20, 17, 14, 11, 8, 5, 2); + + let base_out = dst as *mut i64; + _mm512_i64scatter_epi64::<8>(base_out, idx_r_c0, r_c0); + _mm512_i64scatter_epi64::<8>(base_out, idx_r_c1, r_c1); + _mm512_i64scatter_epi64::<8>(base_out, idx_r_c2, r_c2); +} + +/// Montgomery multiplication for single-limb Goldilocks (scalar). +/// +/// Computes `mont_mul(a, b) = a * b * R^{-1} mod P` where R = 2^64. +/// CIOS algorithm for N=1, identical to arkworks' `MontBackend`. +#[inline(always)] +fn mont_mul(a: u64, b: u64) -> u64 { + let full = (a as u128) * (b as u128); + let lo = full as u64; + let hi = (full >> 64) as u64; + + let k = lo.wrapping_mul(INV); + + let t = (k as u128) * (P as u128); + let t_lo = t as u64; + let t_hi = (t >> 64) as u64; + + let (_, carry) = lo.overflowing_add(t_lo); + let (mut result, carry2) = hi.overflowing_add(t_hi); + let (result2, carry3) = result.overflowing_add(carry as u64); + result = result2; + + if carry2 || carry3 || result >= P { + result = result.wrapping_sub(P); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::F64; + use ark_ff::{AdditiveGroup, UniformRand}; + use ark_std::test_rng; + + #[test] + fn test_mont_mul_matches_arkworks() { + let mut rng = test_rng(); + for _ in 0..100_000 { + let a = F64::rand(&mut rng); + let b = F64::rand(&mut rng); + let expected = a * b; + let result = F64::from_raw(mont_mul(a.value, b.value)); + assert_eq!( + expected, result, + "mont_mul mismatch for a={:?}, b={:?}", + a, b + ); + } + } + + #[test] + fn test_mont_add_matches_arkworks() { + let mut rng = test_rng(); + for _ in 0..100_000 { + let a = F64::rand(&mut rng); + let b = F64::rand(&mut rng); + let expected = a + b; + let result = F64::from_raw(GoldilocksAvx512::scalar_add(a.value, b.value)); + assert_eq!(expected, result); + } + } + + #[test] + fn test_mont_sub_matches_arkworks() { + let mut rng = test_rng(); + for _ in 0..100_000 { + let a = F64::rand(&mut rng); + let b = F64::rand(&mut rng); + let expected = a - b; + let result = F64::from_raw(GoldilocksAvx512::scalar_sub(a.value, b.value)); + assert_eq!(expected, result); + } + } + + #[test] + fn test_avx512_mont_mul() { + let mut rng = test_rng(); + for _ in 0..10_000 { + let a: [F64; 8] = core::array::from_fn(|_| F64::rand(&mut rng)); + let b: [F64; 8] = core::array::from_fn(|_| F64::rand(&mut rng)); + + let a_raw: [u64; 8] = core::array::from_fn(|i| a[i].value); + let b_raw: [u64; 8] = core::array::from_fn(|i| b[i].value); + + let a_v = unsafe { GoldilocksAvx512::load(a_raw.as_ptr()) }; + let b_v = unsafe { GoldilocksAvx512::load(b_raw.as_ptr()) }; + let r_v = GoldilocksAvx512::mul(a_v, b_v); + + let mut result = [0u64; 8]; + unsafe { GoldilocksAvx512::store(result.as_mut_ptr(), r_v) }; + + for i in 0..8 { + assert_eq!( + F64::from_raw(result[i]), + a[i] * b[i], + "lane {i} mul mismatch" + ); + } + } + } + + #[test] + fn test_avx512_add() { + let mut rng = test_rng(); + for _ in 0..10_000 { + let a: [F64; 8] = core::array::from_fn(|_| F64::rand(&mut rng)); + let b: [F64; 8] = core::array::from_fn(|_| F64::rand(&mut rng)); + + let a_raw: [u64; 8] = core::array::from_fn(|i| a[i].value); + let b_raw: [u64; 8] = core::array::from_fn(|i| b[i].value); + + let a_v = unsafe { GoldilocksAvx512::load(a_raw.as_ptr()) }; + let b_v = unsafe { GoldilocksAvx512::load(b_raw.as_ptr()) }; + let r_v = GoldilocksAvx512::add(a_v, b_v); + + let mut result = [0u64; 8]; + unsafe { GoldilocksAvx512::store(result.as_mut_ptr(), r_v) }; + + for i in 0..8 { + assert_eq!( + F64::from_raw(result[i]), + a[i] + b[i], + "lane {i} add mismatch" + ); + } + } + } + + #[test] + fn test_avx512_sub() { + let mut rng = test_rng(); + for _ in 0..10_000 { + let a: [F64; 8] = core::array::from_fn(|_| F64::rand(&mut rng)); + let b: [F64; 8] = core::array::from_fn(|_| F64::rand(&mut rng)); + + let a_raw: [u64; 8] = core::array::from_fn(|i| a[i].value); + let b_raw: [u64; 8] = core::array::from_fn(|i| b[i].value); + + let a_v = unsafe { GoldilocksAvx512::load(a_raw.as_ptr()) }; + let b_v = unsafe { GoldilocksAvx512::load(b_raw.as_ptr()) }; + let r_v = GoldilocksAvx512::sub(a_v, b_v); + + let mut result = [0u64; 8]; + unsafe { GoldilocksAvx512::store(result.as_mut_ptr(), r_v) }; + + for i in 0..8 { + assert_eq!( + F64::from_raw(result[i]), + a[i] - b[i], + "lane {i} sub mismatch" + ); + } + } + } + + #[test] + fn test_transmute_roundtrip() { + let mut rng = test_rng(); + for _ in 0..10_000 { + let f = F64::rand(&mut rng); + let mont = f.value; + let back = F64::from_raw(mont); + assert_eq!(f, back, "transmute roundtrip failed"); + } + } + + #[test] + fn test_edge_cases() { + use ark_ff::Field; + let zero = F64::ZERO; + let one = F64::ONE; + let neg_one = -F64::ONE; + + // 0 * anything = 0 + assert_eq!(F64::from_raw(mont_mul(zero.value, neg_one.value)), zero); + // 1 * x = x + assert_eq!(F64::from_raw(mont_mul(one.value, neg_one.value)), neg_one); + // (-1) * (-1) = 1 + assert_eq!(F64::from_raw(mont_mul(neg_one.value, neg_one.value)), one); + } + + #[test] + fn test_avx512_edge_cases_vectorized() { + use ark_ff::Field; + let zero = F64::ZERO; + let one = F64::ONE; + let neg_one = -F64::ONE; + + // Test with all-zero, all-one, all-neg_one, and mixed lanes + let a_vals = [zero, one, neg_one, one, zero, neg_one, one, neg_one]; + let b_vals = [neg_one, neg_one, neg_one, one, zero, one, zero, zero]; + let expected: [F64; 8] = core::array::from_fn(|i| a_vals[i] * b_vals[i]); + + let a_raw: [u64; 8] = core::array::from_fn(|i| a_vals[i].value); + let b_raw: [u64; 8] = core::array::from_fn(|i| b_vals[i].value); + + let a_v = unsafe { GoldilocksAvx512::load(a_raw.as_ptr()) }; + let b_v = unsafe { GoldilocksAvx512::load(b_raw.as_ptr()) }; + let r_v = GoldilocksAvx512::mul(a_v, b_v); + + let mut result = [0u64; 8]; + unsafe { GoldilocksAvx512::store(result.as_mut_ptr(), r_v) }; + + for i in 0..8 { + assert_eq!( + F64::from_raw(result[i]), + expected[i], + "edge case lane {i} mismatch" + ); + } + } + + #[test] + fn test_ext2_scalar_mul() { + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + + let a = [a0.value, a1.value]; + let b = [b0.value, b1.value]; + let result = ext2_scalar_mul(a, b, w_mont); + + // Naive: c0 = a0*b0 + 7*a1*b1, c1 = a0*b1 + a1*b0 + let expected_c0 = a0 * b0 + F64::from(7u64) * a1 * b1; + let expected_c1 = a0 * b1 + a1 * b0; + + assert_eq!(F64::from_raw(result[0]), expected_c0, "ext2 c0 mismatch"); + assert_eq!(F64::from_raw(result[1]), expected_c1, "ext2 c1 mismatch"); + } + } + + #[test] + fn test_ext3_scalar_mul() { + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + let w = F64::from(7u64); + + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let a2 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + let b2 = F64::rand(&mut rng); + + let a = [a0.value, a1.value, a2.value]; + let b = [b0.value, b1.value, b2.value]; + let result = ext3_scalar_mul(a, b, w_mont); + + // Naive schoolbook mod (X³ - w): + let expected_c0 = a0 * b0 + w * (a1 * b2 + a2 * b1); + let expected_c1 = a0 * b1 + a1 * b0 + w * a2 * b2; + let expected_c2 = a0 * b2 + a1 * b1 + a2 * b0; + + assert_eq!(F64::from_raw(result[0]), expected_c0, "ext3 c0 mismatch"); + assert_eq!(F64::from_raw(result[1]), expected_c1, "ext3 c1 mismatch"); + assert_eq!(F64::from_raw(result[2]), expected_c2, "ext3 c2 mismatch"); + } + } + + #[test] + fn test_ext2_avx512_matches_scalar() { + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + let w_vec = GoldilocksAvx512::splat(w_mont); + + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + + // Broadcast same values across all 8 lanes + let a_v = [ + GoldilocksAvx512::splat(a0.value), + GoldilocksAvx512::splat(a1.value), + ]; + let b_v = [ + GoldilocksAvx512::splat(b0.value), + GoldilocksAvx512::splat(b1.value), + ]; + + let r_v = ext2_mul(a_v, b_v, w_vec); + + let mut r_out = [[0u64; 8]; 2]; + unsafe { + GoldilocksAvx512::store(r_out[0].as_mut_ptr(), r_v[0]); + GoldilocksAvx512::store(r_out[1].as_mut_ptr(), r_v[1]); + } + + let scalar_result = ext2_scalar_mul([a0.value, a1.value], [b0.value, b1.value], w_mont); + + for lane in 0..8 { + assert_eq!( + r_out[0][lane], scalar_result[0], + "ext2 AVX-512 c0 lane {lane} mismatch" + ); + assert_eq!( + r_out[1][lane], scalar_result[1], + "ext2 AVX-512 c1 lane {lane} mismatch" + ); + } + } + } + + #[test] + fn test_ext3_avx512_matches_scalar() { + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + let w_vec = GoldilocksAvx512::splat(w_mont); + + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let a2 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + let b2 = F64::rand(&mut rng); + + let a_v = [ + GoldilocksAvx512::splat(a0.value), + GoldilocksAvx512::splat(a1.value), + GoldilocksAvx512::splat(a2.value), + ]; + let b_v = [ + GoldilocksAvx512::splat(b0.value), + GoldilocksAvx512::splat(b1.value), + GoldilocksAvx512::splat(b2.value), + ]; + + let r_v = ext3_mul(a_v, b_v, w_vec); + + let mut r_out = [[0u64; 8]; 3]; + unsafe { + GoldilocksAvx512::store(r_out[0].as_mut_ptr(), r_v[0]); + GoldilocksAvx512::store(r_out[1].as_mut_ptr(), r_v[1]); + GoldilocksAvx512::store(r_out[2].as_mut_ptr(), r_v[2]); + } + + let scalar_result = ext3_scalar_mul( + [a0.value, a1.value, a2.value], + [b0.value, b1.value, b2.value], + w_mont, + ); + + for lane in 0..8 { + assert_eq!( + r_out[0][lane], scalar_result[0], + "ext3 AVX-512 c0 lane {lane} mismatch" + ); + assert_eq!( + r_out[1][lane], scalar_result[1], + "ext3 AVX-512 c1 lane {lane} mismatch" + ); + assert_eq!( + r_out[2][lane], scalar_result[2], + "ext3 AVX-512 c2 lane {lane} mismatch" + ); + } + } + } + + #[test] + fn test_ext2_reduce_8pairs() { + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + + for _ in 0..1_000 { + // Generate 8 pairs of ext2 elements in AoS layout (32 u64s) + let src: Vec = (0..32).map(|_| F64::rand(&mut rng.value)).collect(); + let challenge = [F64::rand(&mut rng.value), F64::rand(&mut rng.value)]; + + // Reference: scalar reduce + let mut expected = vec![0u64; 16]; + for i in 0..8 { + let a = [src[4 * i], src[4 * i + 1]]; + let b = [src[4 * i + 2], src[4 * i + 3]]; + let diff = [ + GoldilocksAvx512::scalar_sub(b[0], a[0]), + GoldilocksAvx512::scalar_sub(b[1], a[1]), + ]; + let prod = ext2_scalar_mul(diff, challenge, w_mont); + expected[2 * i] = GoldilocksAvx512::scalar_add(a[0], prod[0]); + expected[2 * i + 1] = GoldilocksAvx512::scalar_add(a[1], prod[1]); + } + + // Vectorized + let mut actual = vec![0u64; 16]; + let challenge_c0 = GoldilocksAvx512::splat(challenge[0]); + let challenge_c1 = GoldilocksAvx512::splat(challenge[1]); + let w_vec = GoldilocksAvx512::splat(w_mont); + unsafe { + ext2_reduce_8pairs( + src.as_ptr(), + actual.as_mut_ptr(), + challenge_c0, + challenge_c1, + w_vec, + ); + } + + assert_eq!(expected, actual, "ext2_reduce_8pairs mismatch"); + } + } + + #[test] + fn test_ext3_reduce_8pairs() { + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + + for _ in 0..1_000 { + // Generate 8 pairs of ext3 elements in AoS layout (48 u64s) + let src: Vec = (0..48).map(|_| F64::rand(&mut rng.value)).collect(); + let challenge = [ + F64::rand(&mut rng.value), + F64::rand(&mut rng.value), + F64::rand(&mut rng.value), + ]; + + // Reference: scalar reduce + let mut expected = vec![0u64; 24]; + for i in 0..8 { + let a = [src[6 * i], src[6 * i + 1], src[6 * i + 2]]; + let b = [src[6 * i + 3], src[6 * i + 4], src[6 * i + 5]]; + let diff = [ + GoldilocksAvx512::scalar_sub(b[0], a[0]), + GoldilocksAvx512::scalar_sub(b[1], a[1]), + GoldilocksAvx512::scalar_sub(b[2], a[2]), + ]; + let prod = ext3_scalar_mul(diff, challenge, w_mont); + expected[3 * i] = GoldilocksAvx512::scalar_add(a[0], prod[0]); + expected[3 * i + 1] = GoldilocksAvx512::scalar_add(a[1], prod[1]); + expected[3 * i + 2] = GoldilocksAvx512::scalar_add(a[2], prod[2]); + } + + // Vectorized + let mut actual = vec![0u64; 24]; + let challenge_v = [ + GoldilocksAvx512::splat(challenge[0]), + GoldilocksAvx512::splat(challenge[1]), + GoldilocksAvx512::splat(challenge[2]), + ]; + let w_vec = GoldilocksAvx512::splat(w_mont); + unsafe { + ext3_reduce_8pairs(src.as_ptr(), actual.as_mut_ptr(), challenge_v, w_vec); + } + + assert_eq!(expected, actual, "ext3_reduce_8pairs mismatch"); + } + } +} diff --git a/src/simd_fields/goldilocks/mod.rs b/src/simd_fields/goldilocks/mod.rs new file mode 100644 index 00000000..73e45019 --- /dev/null +++ b/src/simd_fields/goldilocks/mod.rs @@ -0,0 +1,30 @@ +//! Goldilocks field (p = 2^64 - 2^32 + 1) SIMD backends. + +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +pub mod neon; + +#[cfg(all( + feature = "simd", + target_arch = "x86_64", + target_feature = "avx512ifma" +))] +pub mod avx512; + +/// Goldilocks NEON backend (aarch64). +/// +/// Operates on Montgomery-form values as stored by arkworks (`SmallFp.value` +/// or `Fp64.0.0[0]`) — zero-cost transmute from `&[Field]` to `&[u64]`. +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +#[allow(unused_imports)] +pub use neon::GoldilocksNeon; + +/// Goldilocks AVX-512 IFMA backend (x86_64). +/// +/// Same Montgomery-form transmute as the NEON backend, but with true 8-wide +/// vectorized multiplication via 52-bit IFMA decomposition. +#[cfg(all( + feature = "simd", + target_arch = "x86_64", + target_feature = "avx512ifma" +))] +pub use avx512::GoldilocksAvx512; diff --git a/src/simd_fields/goldilocks/neon.rs b/src/simd_fields/goldilocks/neon.rs new file mode 100644 index 00000000..881ba6a3 --- /dev/null +++ b/src/simd_fields/goldilocks/neon.rs @@ -0,0 +1,509 @@ +#![allow(dead_code)] +//! Montgomery-form Goldilocks NEON backend. +//! +//! Operates directly on Montgomery-form values (as stored by arkworks `Fp64`), +//! enabling zero-cost `transmute` from `&[F64]` to `&[u64]`. +//! +//! Implements the same CIOS Montgomery reduction as arkworks' `MontBackend` +//! for `N=1`, so results are bit-identical. + +use core::arch::aarch64::*; + +use super::super::SimdBaseField; + +/// Goldilocks modulus: P = 2^64 - 2^32 + 1. +const P: u64 = 0xFFFF_FFFF_0000_0001; + +/// Montgomery constant: INV = -P^{-1} mod 2^64. +const INV: u64 = 0xFFFF_FFFE_FFFF_FFFF; + +/// ε = 2^64 mod P = 2^32 - 1 (used for add/sub overflow correction). +const EPSILON: u64 = 0xFFFF_FFFF; + +/// Montgomery ONE = R mod P = 2^64 mod P = EPSILON. +const MONT_ONE: u64 = EPSILON; + +/// Montgomery ZERO = 0 (same in both domains). +const MONT_ZERO: u64 = 0; + +#[derive(Copy, Clone)] +pub struct GoldilocksNeon; + +impl SimdBaseField for GoldilocksNeon { + type Scalar = u64; + type Packed = uint64x2_t; + const LANES: usize = 2; + const MODULUS: u64 = P; + const ZERO: u64 = MONT_ZERO; + const ONE: u64 = MONT_ONE; + + #[inline(always)] + fn splat(val: u64) -> uint64x2_t { + unsafe { vdupq_n_u64(val) } + } + + #[inline(always)] + unsafe fn load(ptr: *const u64) -> uint64x2_t { + unsafe { vld1q_u64(ptr) } + } + + #[inline(always)] + unsafe fn store(ptr: *mut u64, v: uint64x2_t) { + unsafe { vst1q_u64(ptr, v) } + } + + #[inline(always)] + unsafe fn load_deinterleaved(ptr: *const u64) -> (uint64x2_t, uint64x2_t) { + let pair = unsafe { vld2q_u64(ptr) }; + (pair.0, pair.1) + } + + // Add/sub are identical in canonical and Montgomery domain. + // mont(a) + mont(b) = mont(a + b), same wrapping/reduction logic. + + #[inline(always)] + fn add(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { + unsafe { + let sum = vaddq_u64(a, b); + let p_vec = vdupq_n_u64(P); + let eps_vec = vdupq_n_u64(EPSILON); + let carry = vcltq_u64(sum, a); + let geq_p = vcgeq_u64(sum, p_vec); + let sub_p = vsubq_u64(sum, p_vec); + let no_carry_result = vbslq_u64(geq_p, sub_p, sum); + let carry_result = vaddq_u64(sum, eps_vec); + vbslq_u64(carry, carry_result, no_carry_result) + } + } + + #[inline(always)] + fn sub(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { + unsafe { + let diff = vsubq_u64(a, b); + let p_vec = vdupq_n_u64(P); + let borrow = vcltq_u64(a, b); + let corrected = vaddq_u64(diff, p_vec); + vbslq_u64(borrow, corrected, diff) + } + } + + #[inline(always)] + fn mul(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { + // Per-lane Montgomery multiplication (CIOS for N=1). + // + // NEON has no 64×64→128 multiply instruction. We tried vectorizing + // via four `vmull_u32` partial products (see `mont_mul_pair` below, + // kept for testing/reference), but it was ~1.5× SLOWER across all + // input sizes on Apple Silicon — the M-series scalar integer pipeline + // is fast enough that `(a as u128) * (b as u128)` (compiled to + // MUL+UMULH, 2 instructions) beats ~14+ NEON instructions for the + // vectorized equivalent. On other ARM cores with narrower scalar + // pipes (Graviton, Neoverse, older Cortex-A) the vectorized path + // may still win; swap in `mont_mul_pair` there if benched as such. + unsafe { + let a0 = vgetq_lane_u64(a, 0); + let a1 = vgetq_lane_u64(a, 1); + let b0 = vgetq_lane_u64(b, 0); + let b1 = vgetq_lane_u64(b, 1); + + let r0 = mont_mul(a0, b0); + let r1 = mont_mul(a1, b1); + + vcombine_u64(vcreate_u64(r0), vcreate_u64(r1)) + } + } + + #[inline(always)] + fn scalar_add(a: u64, b: u64) -> u64 { + let (sum, carry) = a.overflowing_add(b); + if carry { + sum + EPSILON + } else if sum >= P { + sum - P + } else { + sum + } + } + + #[inline(always)] + fn scalar_sub(a: u64, b: u64) -> u64 { + if a >= b { + a - b + } else { + a.wrapping_sub(b).wrapping_add(P) + } + } + + #[inline(always)] + fn scalar_mul(a: u64, b: u64) -> u64 { + mont_mul(a, b) + } +} + +/// Montgomery multiplication for single-limb Goldilocks. +/// +/// Computes `mont_mul(a, b) = a * b * R^{-1} mod P` where R = 2^64. +/// This is the CIOS algorithm for N=1, identical to arkworks' `MontBackend`. +/// +/// full = a * b (128-bit) +/// lo = full mod 2^64 +/// hi = full >> 64 +/// k = lo * INV mod 2^64 (INV = -P^{-1} mod 2^64) +/// t = k * P (128-bit) +/// result = (full + t) >> 64 (fits in 64 bits + carry) +/// if result >= P: result -= P +#[inline(always)] +fn mont_mul(a: u64, b: u64) -> u64 { + let full = (a as u128) * (b as u128); + let lo = full as u64; + let hi = (full >> 64) as u64; + + // k = lo * INV mod 2^64 + let k = lo.wrapping_mul(INV); + + // t = k * P (128-bit) + let t = (k as u128) * (P as u128); + let t_lo = t as u64; + let t_hi = (t >> 64) as u64; + + // (full + t) >> 64 = hi + t_hi + carry_from(lo + t_lo) + let (_, carry) = lo.overflowing_add(t_lo); + let (mut result, carry2) = hi.overflowing_add(t_hi); + let (result2, carry3) = result.overflowing_add(carry as u64); + result = result2; + + // Handle carry: carry2 || carry3 can happen since Goldilocks has no spare bit + if carry2 || carry3 || result >= P { + result = result.wrapping_sub(P); + } + + result +} + +// ── Extension field SIMD multiply functions ───────────────────────────────── +// +// These are free functions rather than trait impls because the nonresidue +// is a runtime value (extracted from the arkworks extension field config +// during dispatch). The SimdExtField trait on mod.rs defines the interface; +// these functions implement the Karatsuba formulas for degree 2 and 3. + +/// Degree-2 Karatsuba: (a0 + a1·X)(b0 + b1·X) mod (X² - w) +/// 3 base muls + 1 mul-by-w + adds. +#[inline(always)] +pub fn ext2_mul(a: [uint64x2_t; 2], b: [uint64x2_t; 2], w: uint64x2_t) -> [uint64x2_t; 2] { + let v0 = GoldilocksNeon::mul(a[0], b[0]); + let v1 = GoldilocksNeon::mul(a[1], b[1]); + let c0 = GoldilocksNeon::add(v0, GoldilocksNeon::mul(w, v1)); + let a_sum = GoldilocksNeon::add(a[0], a[1]); + let b_sum = GoldilocksNeon::add(b[0], b[1]); + let c1 = GoldilocksNeon::sub( + GoldilocksNeon::sub(GoldilocksNeon::mul(a_sum, b_sum), v0), + v1, + ); + [c0, c1] +} + +/// Degree-2 Karatsuba (scalar version for tail processing). +#[inline(always)] +pub fn ext2_scalar_mul(a: [u64; 2], b: [u64; 2], w: u64) -> [u64; 2] { + let v0 = mont_mul(a[0], b[0]); + let v1 = mont_mul(a[1], b[1]); + let c0 = GoldilocksNeon::scalar_add(v0, mont_mul(w, v1)); + let a_sum = GoldilocksNeon::scalar_add(a[0], a[1]); + let b_sum = GoldilocksNeon::scalar_add(b[0], b[1]); + let c1 = GoldilocksNeon::scalar_sub(GoldilocksNeon::scalar_sub(mont_mul(a_sum, b_sum), v0), v1); + [c0, c1] +} + +/// Degree-3 Karatsuba: (a0 + a1·X + a2·X²)(b0 + b1·X + b2·X²) mod (X³ - w) +/// 6 base muls + 2 mul-by-w + adds. +#[inline(always)] +pub fn ext3_mul(a: [uint64x2_t; 3], b: [uint64x2_t; 3], w: uint64x2_t) -> [uint64x2_t; 3] { + let ad = GoldilocksNeon::mul(a[0], b[0]); + let be = GoldilocksNeon::mul(a[1], b[1]); + let cf = GoldilocksNeon::mul(a[2], b[2]); + + let x = GoldilocksNeon::sub( + GoldilocksNeon::sub( + GoldilocksNeon::mul( + GoldilocksNeon::add(a[1], a[2]), + GoldilocksNeon::add(b[1], b[2]), + ), + be, + ), + cf, + ); + let y = GoldilocksNeon::sub( + GoldilocksNeon::sub( + GoldilocksNeon::mul( + GoldilocksNeon::add(a[0], a[1]), + GoldilocksNeon::add(b[0], b[1]), + ), + ad, + ), + be, + ); + let z = GoldilocksNeon::add( + GoldilocksNeon::sub( + GoldilocksNeon::sub( + GoldilocksNeon::mul( + GoldilocksNeon::add(a[0], a[2]), + GoldilocksNeon::add(b[0], b[2]), + ), + ad, + ), + cf, + ), + be, + ); + + [ + GoldilocksNeon::add(ad, GoldilocksNeon::mul(w, x)), + GoldilocksNeon::add(y, GoldilocksNeon::mul(w, cf)), + z, + ] +} + +/// Degree-3 Karatsuba (scalar version). +#[inline(always)] +pub fn ext3_scalar_mul(a: [u64; 3], b: [u64; 3], w: u64) -> [u64; 3] { + let ad = mont_mul(a[0], b[0]); + let be = mont_mul(a[1], b[1]); + let cf = mont_mul(a[2], b[2]); + + let x = GoldilocksNeon::scalar_sub( + GoldilocksNeon::scalar_sub( + mont_mul( + GoldilocksNeon::scalar_add(a[1], a[2]), + GoldilocksNeon::scalar_add(b[1], b[2]), + ), + be, + ), + cf, + ); + let y = GoldilocksNeon::scalar_sub( + GoldilocksNeon::scalar_sub( + mont_mul( + GoldilocksNeon::scalar_add(a[0], a[1]), + GoldilocksNeon::scalar_add(b[0], b[1]), + ), + ad, + ), + be, + ); + let z = GoldilocksNeon::scalar_add( + GoldilocksNeon::scalar_sub( + GoldilocksNeon::scalar_sub( + mont_mul( + GoldilocksNeon::scalar_add(a[0], a[2]), + GoldilocksNeon::scalar_add(b[0], b[2]), + ), + ad, + ), + cf, + ), + be, + ); + + [ + GoldilocksNeon::scalar_add(ad, mont_mul(w, x)), + GoldilocksNeon::scalar_add(y, mont_mul(w, cf)), + z, + ] +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::F64; + use ark_ff::{AdditiveGroup, UniformRand}; + use ark_std::test_rng; + + #[test] + fn test_mont_mul_matches_arkworks() { + let mut rng = test_rng(); + for _ in 0..100_000 { + let a = F64::rand(&mut rng); + let b = F64::rand(&mut rng); + let expected = a * b; + let result = F64::from_raw(mont_mul(a.value, b.value)); + assert_eq!( + expected, result, + "mont_mul mismatch for a={:?}, b={:?}", + a, b + ); + } + } + + #[test] + fn test_mont_add_matches_arkworks() { + let mut rng = test_rng(); + for _ in 0..100_000 { + let a = F64::rand(&mut rng); + let b = F64::rand(&mut rng); + let expected = a + b; + let result = F64::from_raw(GoldilocksNeon::scalar_add(a.value, b.value)); + assert_eq!(expected, result); + } + } + + #[test] + fn test_mont_sub_matches_arkworks() { + let mut rng = test_rng(); + for _ in 0..100_000 { + let a = F64::rand(&mut rng); + let b = F64::rand(&mut rng); + let expected = a - b; + let result = F64::from_raw(GoldilocksNeon::scalar_sub(a.value, b.value)); + assert_eq!(expected, result); + } + } + + #[test] + fn test_neon_mont_mul() { + let mut rng = test_rng(); + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + + let a_raw = [a0.value, a1.value]; + let b_raw = [b0.value, b1.value]; + + let a_v = unsafe { GoldilocksNeon::load(a_raw.as_ptr()) }; + let b_v = unsafe { GoldilocksNeon::load(b_raw.as_ptr()) }; + let r_v = GoldilocksNeon::mul(a_v, b_v); + + let mut result = [0u64; 2]; + unsafe { GoldilocksNeon::store(result.as_mut_ptr(), r_v) }; + + assert_eq!(F64::from_raw(result[0]), a0 * b0); + assert_eq!(F64::from_raw(result[1]), a1 * b1); + } + } + + #[test] + fn test_transmute_roundtrip() { + let mut rng = test_rng(); + for _ in 0..10_000 { + let f = F64::rand(&mut rng); + let mont = f.value; + let back = F64::from_raw(mont); + assert_eq!(f, back, "transmute roundtrip failed"); + } + } + + #[test] + fn test_edge_cases() { + use ark_ff::Field; + let zero = F64::ZERO; + let one = F64::ONE; + let neg_one = -F64::ONE; + + // 0 * anything = 0 + assert_eq!(F64::from_raw(mont_mul(zero.value, neg_one.value)), zero); + // 1 * x = x + assert_eq!(F64::from_raw(mont_mul(one.value, neg_one.value)), neg_one); + // (-1) * (-1) = 1 + assert_eq!(F64::from_raw(mont_mul(neg_one.value, neg_one.value)), one); + } + + #[test] + fn test_ext2_scalar_mul() { + // Test degree-2 extension multiply against naive computation. + // Using nonresidue w = 7 (in Montgomery form). + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + + let a = [a0.value, a1.value]; + let b = [b0.value, b1.value]; + let result = ext2_scalar_mul(a, b, w_mont); + + // Naive: c0 = a0*b0 + 7*a1*b1, c1 = a0*b1 + a1*b0 + let expected_c0 = a0 * b0 + F64::from(7u64) * a1 * b1; + let expected_c1 = a0 * b1 + a1 * b0; + + assert_eq!(F64::from_raw(result[0]), expected_c0, "ext2 c0 mismatch"); + assert_eq!(F64::from_raw(result[1]), expected_c1, "ext2 c1 mismatch"); + } + } + + #[test] + fn test_ext3_scalar_mul() { + // Test degree-3 extension multiply against naive schoolbook. + // Using nonresidue w = 7 (in Montgomery form). + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + let w = F64::from(7u64); + + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let a2 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + let b2 = F64::rand(&mut rng); + + let a = [a0.value, a1.value, a2.value]; + let b = [b0.value, b1.value, b2.value]; + let result = ext3_scalar_mul(a, b, w_mont); + + // Naive schoolbook mod (X³ - w): + // c0 = a0*b0 + w*(a1*b2 + a2*b1) + // c1 = a0*b1 + a1*b0 + w*a2*b2 + // c2 = a0*b2 + a1*b1 + a2*b0 + let expected_c0 = a0 * b0 + w * (a1 * b2 + a2 * b1); + let expected_c1 = a0 * b1 + a1 * b0 + w * a2 * b2; + let expected_c2 = a0 * b2 + a1 * b1 + a2 * b0; + + assert_eq!(F64::from_raw(result[0]), expected_c0, "ext3 c0 mismatch"); + assert_eq!(F64::from_raw(result[1]), expected_c1, "ext3 c1 mismatch"); + assert_eq!(F64::from_raw(result[2]), expected_c2, "ext3 c2 mismatch"); + } + } + + #[test] + fn test_ext2_neon_matches_scalar() { + // Verify NEON ext2_mul matches ext2_scalar_mul. + let mut rng = test_rng(); + let w_mont = F64::from(7u64).value; + let w_vec = GoldilocksNeon::splat(w_mont); + + for _ in 0..10_000 { + let a0 = F64::rand(&mut rng); + let a1 = F64::rand(&mut rng); + let b0 = F64::rand(&mut rng); + let b1 = F64::rand(&mut rng); + + let a_raw = [[a0.value, a0.value], [a1.value, a1.value]]; + let b_raw = [[b0.value, b0.value], [b1.value, b1.value]]; + + let a_v = [unsafe { GoldilocksNeon::load(a_raw[0].as_ptr()) }, unsafe { + GoldilocksNeon::load(a_raw[1].as_ptr()) + }]; + let b_v = [unsafe { GoldilocksNeon::load(b_raw[0].as_ptr()) }, unsafe { + GoldilocksNeon::load(b_raw[1].as_ptr()) + }]; + + let r_v = ext2_mul(a_v, b_v, w_vec); + + let mut r_out = [[0u64; 2]; 2]; + unsafe { + GoldilocksNeon::store(r_out[0].as_mut_ptr(), r_v[0]); + GoldilocksNeon::store(r_out[1].as_mut_ptr(), r_v[1]); + } + + let scalar_result = ext2_scalar_mul([a0.value, a1.value], [b0.value, b1.value], w_mont); + + assert_eq!(r_out[0][0], scalar_result[0], "ext2 NEON c0 mismatch"); + assert_eq!(r_out[1][0], scalar_result[1], "ext2 NEON c1 mismatch"); + } + } +} diff --git a/src/simd_fields/mod.rs b/src/simd_fields/mod.rs new file mode 100644 index 00000000..a05374c9 --- /dev/null +++ b/src/simd_fields/mod.rs @@ -0,0 +1,142 @@ +#![allow(dead_code)] +//! SIMD-vectorized field arithmetic using native intrinsics. +//! +//! Each base field provides platform-specific implementations of add, sub, mul +//! operating on packed SIMD vectors. Currently supports: +//! +//! - **Goldilocks** (p = 2^64 − 2^32 + 1) via NEON on aarch64, AVX-512 IFMA on x86_64. +//! +//! # Extension fields +//! +//! The [`SimdExtField`] trait extends [`SimdBaseField`] with multiplication +//! formulas for algebraic extensions (degree 2, 3, 4, ...). Extension field +//! elements are represented as `d` consecutive base field scalars in memory. +//! Addition is component-wise (uses base field SIMD directly). Multiplication +//! uses Karatsuba or schoolbook formulas with base field SIMD operations. + +pub mod goldilocks; + +/// Platform-agnostic packed field operations. +/// +/// Each ISA backend (NEON, AVX2, AVX-512) provides its own implementation +/// with the appropriate packed vector type. +/// +/// # Safety +/// +/// All values stored in `Packed` vectors must be valid field elements +/// (i.e., in `0..P`). The arithmetic functions maintain this invariant +/// when given valid inputs. +pub trait SimdBaseField: Copy + Send + Sync + Sized + 'static { + /// Scalar representation (u32 for 31-bit fields, u64 for Goldilocks). + type Scalar: Copy + Send + Sync + Default + PartialEq + core::fmt::Debug + 'static; + + /// The packed SIMD vector type (e.g., `uint64x2_t`, `__m256i`). + type Packed: Copy; + + /// Number of scalar lanes in one `Packed` vector. + const LANES: usize; + + /// The field modulus as a scalar. + const MODULUS: Self::Scalar; + + /// Zero element. + const ZERO: Self::Scalar; + + /// One element. + const ONE: Self::Scalar; + + /// Broadcast a scalar to all lanes. + fn splat(val: Self::Scalar) -> Self::Packed; + + /// Load a packed vector from a pointer (must be aligned to `Packed`). + /// + /// # Safety + /// + /// `ptr` must point to at least `LANES` valid `Scalar` values. + unsafe fn load(ptr: *const Self::Scalar) -> Self::Packed; + + /// Store a packed vector to a pointer. + /// + /// # Safety + /// + /// `ptr` must point to writable memory for at least `LANES` `Scalar` values. + unsafe fn store(ptr: *mut Self::Scalar, v: Self::Packed); + + /// Packed modular addition: `(a + b) mod P`. + fn add(a: Self::Packed, b: Self::Packed) -> Self::Packed; + + /// Packed modular subtraction: `(a - b) mod P`. + fn sub(a: Self::Packed, b: Self::Packed) -> Self::Packed; + + /// Packed modular multiplication: `(a * b) mod P`. + fn mul(a: Self::Packed, b: Self::Packed) -> Self::Packed; + + /// Scalar modular addition (non-vectorized, for reductions). + fn scalar_add(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar; + + /// Scalar modular subtraction (non-vectorized, for reductions). + fn scalar_sub(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar; + + /// Scalar modular multiplication (non-vectorized, for reductions). + fn scalar_mul(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar; + + /// Wrapping add without modular reduction — just raw integer addition. + /// + /// Callers must track carries separately and finalize with `reduce_carry`. + /// Backends should override for performance; default falls back to `add`. + #[inline(always)] + fn add_wrapping(a: Self::Packed, b: Self::Packed) -> Self::Packed { + Self::add(a, b) + } + + /// Detect carries from a wrapping add: returns a packed vector with `1` in + /// lanes where `sum < a` (unsigned overflow) and `0` elsewhere. + /// + /// Default returns zero (no carries tracked — consistent with `add` default). + #[inline(always)] + fn carry_mask(_sum: Self::Packed, _a_before: Self::Packed) -> Self::Packed { + Self::splat(Self::ZERO) + } + + /// Correct a wrapping accumulator given the carry count per lane. + /// + /// For Goldilocks: each carry represents 2^64 ≡ EPSILON (mod P), + /// so result = sum + carry_count * EPSILON (mod P). + /// + /// Default is identity (assumes `add_wrapping` already reduced). + #[inline(always)] + fn reduce_carry(sum: Self::Packed, _carry_count: Self::Packed) -> Self::Packed { + sum + } + + /// Load `2 * LANES` scalars from interleaved pairs and deinterleave: + /// `[a0, b0, a1, b1, ..., a_{L-1}, b_{L-1}]` → `(evens, odds)`. + /// + /// Default: scalar deinterleave through stack buffers. + /// Backends with native shuffle (e.g. AVX-512 `vpermutex2var`) should override. + /// + /// # Safety + /// + /// `ptr` must point to at least `2 * LANES` valid `Scalar` values. + #[inline(always)] + unsafe fn load_deinterleaved(ptr: *const Self::Scalar) -> (Self::Packed, Self::Packed) { + assert!( + Self::LANES <= 32, + "LANES={} exceeds max supported (32)", + Self::LANES + ); + let mut evens = [Self::ZERO; 32]; + let mut odds = [Self::ZERO; 32]; + for j in 0..Self::LANES { + evens[j] = *ptr.add(2 * j); + odds[j] = *ptr.add(2 * j + 1); + } + (Self::load(evens.as_ptr()), Self::load(odds.as_ptr())) + } +} + +// Extension field SIMD multiplication is implemented as free functions +// in each backend module (e.g., `goldilocks::neon::ext2_mul`) rather than +// as a trait, because the nonresidue is a runtime value extracted from the +// arkworks extension field config during dispatch. See `ext2_mul`, `ext3_mul` +// in the backend modules. diff --git a/src/simd_sumcheck/dispatch.rs b/src/simd_sumcheck/dispatch.rs new file mode 100644 index 00000000..ea7345ca --- /dev/null +++ b/src/simd_sumcheck/dispatch.rs @@ -0,0 +1,287 @@ +#![allow(dead_code)] +//! SIMD auto-dispatch for the multilinear sumcheck protocol. +//! +//! When `F` is a Goldilocks field (p = 2^64 − 2^32 + 1) stored as a single +//! `u64` in Montgomery form, the sumcheck is transparently routed to a +//! SIMD-accelerated backend: +//! +//! - **aarch64**: NEON backend (2-wide, scalar mul fallback) +//! - **x86_64 + AVX-512 IFMA**: AVX-512 backend (8-wide, true IFMA mul) +//! +//! Detection uses [`SumcheckField::_simd_field_config()`] — the arkworks +//! blanket impl returns the actual modulus, non-arkworks fields return +//! `None` by default (no SIMD). After monomorphization the check is +//! constant-folded by LLVM, so the dead branch is eliminated entirely. +//! +//! # Safety +//! +//! This module contains **no `unsafe` code**. All field ↔ `u64` +//! reinterpretation is delegated to the safe `SumcheckField` trait methods +//! (`_to_raw_u64`, `_from_raw_u64`, `_as_u64_slice`, `_as_u64_slice_mut`, +//! `_from_u64_components`), whose implementations centralize the necessary +//! `unsafe` in the arkworks blanket impl with full SAFETY documentation. + +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +use crate::field::SumcheckField; + +/// Goldilocks modulus: p = 2^64 − 2^32 + 1. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +const GOLDILOCKS_P: u64 = 0xFFFF_FFFF_0000_0001; + +/// Returns `true` when `F` is a Goldilocks prime field stored as a +/// single `u64` in Montgomery form. +/// +/// Uses [`SumcheckField::_simd_field_config()`] for detection. +/// After monomorphization every operand is a compile-time constant, +/// so LLVM folds the entire function to `true` or `false`. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +#[inline(always)] +fn is_goldilocks() -> bool { + if F::extension_degree() != 1 { + return false; + } + match F::_simd_field_config() { + Some(cfg) => cfg.modulus == GOLDILOCKS_P && cfg.element_bytes == 8, + None => false, + } +} + +/// Returns `true` when `F` has Goldilocks as its base prime field, +/// regardless of extension degree. For degree-1 this is the same as +/// `is_goldilocks`. For degree 2, 3, etc., the element is `d` consecutive +/// `u64` values in Montgomery form. +/// +/// After monomorphization, fully constant-folded by LLVM. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +#[inline(always)] +fn is_goldilocks_based() -> bool { + match F::_simd_field_config() { + Some(cfg) => { + if cfg.modulus != GOLDILOCKS_P || cfg.element_bytes != 8 { + return false; + } + let d = F::extension_degree() as usize; + core::mem::size_of::() == d * 8 + } + None => false, + } +} + +// ─── Standalone SIMD reduce (Field-level API) ────────────────────────────── + +/// SIMD-accelerated pairwise reduce on a `Vec`. +/// +/// If `F` is a recognised Goldilocks field, runs the SIMD reduce in-place +/// and truncates the vector. Otherwise returns `false` and the caller +/// should fall back to the generic path. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +pub(crate) fn try_simd_reduce(evals: &mut Vec, challenge: F) -> bool { + if !is_goldilocks::() { + return false; + } + + #[cfg(target_arch = "aarch64")] + type Backend = crate::simd_fields::goldilocks::neon::GoldilocksNeon; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + type Backend = crate::simd_fields::goldilocks::avx512::GoldilocksAvx512; + + use crate::simd_sumcheck::reduce::reduce_in_place; + + let buf: &mut [u64] = F::_as_u64_slice_mut(evals.as_mut_slice()); + let chg: u64 = challenge._to_raw_u64(); + let new_len = reduce_in_place::(buf, chg); + evals.truncate(new_len); + true +} + +/// SIMD-accelerated MSB (half-split) reduce on a `Vec`. +/// +/// Like [`try_simd_reduce`] but uses the half-split layout: +/// `new[k] = v[k] + challenge * (v[k + L/2] − v[k])`. +/// Returns `false` for non-Goldilocks fields. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +pub(crate) fn try_simd_reduce_msb(evals: &mut Vec, challenge: F) -> bool { + if !is_goldilocks::() { + return false; + } + + #[cfg(target_arch = "aarch64")] + type Backend = crate::simd_fields::goldilocks::neon::GoldilocksNeon; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + type Backend = crate::simd_fields::goldilocks::avx512::GoldilocksAvx512; + + use crate::simd_sumcheck::reduce::reduce_msb_in_place; + + let buf: &mut [u64] = F::_as_u64_slice_mut(evals.as_mut_slice()); + let chg: u64 = challenge._to_raw_u64(); + let new_len = reduce_msb_in_place::(buf, chg); + evals.truncate(new_len); + true +} + +// ─── SIMD degree-1 evaluate for coefficient sumcheck ──────────────────────── + +/// Fused SIMD reduce + degree-1 evaluate. +/// +/// Reduces `pw` in-place and returns `[s0, s1 - s0]` for the next round, +/// computed in a single data pass via `reduce_and_evaluate`. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +pub(crate) fn try_simd_fused_reduce_evaluate_degree1( + pw: &mut Vec, + challenge: F, +) -> Option> { + if !is_goldilocks::() { + return None; + } + + #[cfg(target_arch = "aarch64")] + type Backend = crate::simd_fields::goldilocks::neon::GoldilocksNeon; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + type Backend = crate::simd_fields::goldilocks::avx512::GoldilocksAvx512; + + use crate::simd_sumcheck::reduce::reduce_and_evaluate; + + let buf: &mut [u64] = F::_as_u64_slice_mut(pw.as_mut_slice()); + let chg: u64 = challenge._to_raw_u64(); + let (s0_raw, s1_raw, new_len) = reduce_and_evaluate::(buf, chg); + pw.truncate(new_len); + + let s0: F = F::_from_raw_u64(s0_raw); + let s1: F = F::_from_raw_u64(s1_raw); + Some(vec![s0, s1 - s0]) +} + +// ─── Extension field evaluate dispatch ────────────────────────────────────── + +/// SIMD-accelerated pairwise evaluate for extension field elements. +/// +/// Returns `Some((sum_even, sum_odd))` as extension field elements if +/// `EF` is a Goldilocks extension. Returns `None` otherwise. +/// +/// The evaluate is pure addition (component-wise), so SIMD wins regardless +/// of extension degree. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +pub(crate) fn try_simd_ext_evaluate(evals: &[EF]) -> Option<(EF, EF)> { + if !is_goldilocks_based::() { + return None; + } + + #[cfg(target_arch = "aarch64")] + type Backend = crate::simd_fields::goldilocks::neon::GoldilocksNeon; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + type Backend = crate::simd_fields::goldilocks::avx512::GoldilocksAvx512; + + let d = EF::extension_degree() as usize; + + if d == 1 { + // Base field — use the optimized base evaluate + let buf: &[u64] = EF::_as_u64_slice(evals); + let (s0, s1) = crate::simd_sumcheck::evaluate::evaluate_parallel::(buf); + return Some((EF::_from_raw_u64(s0), EF::_from_raw_u64(s1))); + } + + // Extension field: view as flat u64 buffer and run ext_evaluate + let buf: &[u64] = EF::_as_u64_slice(evals); + + let (even_comps, odd_comps) = + crate::simd_sumcheck::evaluate::ext_evaluate_parallel::(buf, d); + + let even: EF = EF::_from_u64_components(&even_comps); + let odd: EF = EF::_from_u64_components(&odd_comps); + + Some((even, odd)) +} + +/// SIMD-accelerated degree-1 pairwise evaluate: returns `[s0, s1 - s0]`. +/// +/// This is the coefficient sumcheck fast path for `degree() == 1` with a single +/// pairwise table and no tablewise tables — equivalent to the multilinear +/// `evaluate_parallel` kernel. +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +pub(crate) fn try_simd_evaluate_degree1(pw: &[F]) -> Option> { + if !is_goldilocks::() { + return None; + } + + #[cfg(target_arch = "aarch64")] + type Backend = crate::simd_fields::goldilocks::neon::GoldilocksNeon; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + type Backend = crate::simd_fields::goldilocks::avx512::GoldilocksAvx512; + + use crate::simd_sumcheck::evaluate::evaluate_parallel; + + let buf: &[u64] = F::_as_u64_slice(pw); + let (s0_raw, s1_raw) = evaluate_parallel::(buf); + let s0: F = F::_from_raw_u64(s0_raw); + let s1: F = F::_from_raw_u64(s1_raw); + Some(vec![s0, s1 - s0]) +} + +// ─── Public helpers ──────────────────────────────────────────────────────── + +/// Check if `F` is a Goldilocks prime field (degree 1, size 8, matching modulus). +#[cfg(all( + feature = "simd", + any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + ) +))] +#[inline(always)] +pub fn is_goldilocks_pub() -> bool { + is_goldilocks::() +} diff --git a/src/simd_sumcheck/evaluate.rs b/src/simd_sumcheck/evaluate.rs new file mode 100644 index 00000000..af98109b --- /dev/null +++ b/src/simd_sumcheck/evaluate.rs @@ -0,0 +1,691 @@ +#![allow(dead_code)] +//! SIMD-vectorized pairwise evaluation: computes (sum_even, sum_odd). +//! +//! Uses an 8-accumulator unroll for instruction-level parallelism, +//! which is the sweet spot on NEON (saturates the register file without +//! spilling — see "Proof Systems Engineering" for benchmarking methodology). + +use crate::simd_fields::SimdBaseField; + +/// SIMD-vectorized pairwise evaluate. +/// +/// Given `src` = `[f(0), f(1), f(2), f(3), ...]`, computes: +/// sum_even = f(0) + f(2) + f(4) + ... +/// sum_odd = f(1) + f(3) + f(5) + ... +/// +/// Returns `(sum_even, sum_odd)`. +/// +/// # Panics +/// +/// Panics if `src.len()` is not a multiple of `8 * F::LANES` (the unroll factor). +pub fn evaluate(src: &[F::Scalar]) -> (F::Scalar, F::Scalar) { + let lanes = F::LANES; + let step = 8 * lanes; + assert!( + src.len() % step == 0 || src.is_empty(), + "src.len() ({}) must be a multiple of {} (8 * LANES)", + src.len(), + step + ); + + let zero = F::splat(F::ZERO); + let mut acc0 = zero; + let mut acc1 = zero; + let mut acc2 = zero; + let mut acc3 = zero; + let mut acc4 = zero; + let mut acc5 = zero; + let mut acc6 = zero; + let mut acc7 = zero; + + let ptr = src.as_ptr(); + let mut i = 0; + + while i < src.len() { + unsafe { + acc0 = F::add(acc0, F::load(ptr.add(i))); + acc1 = F::add(acc1, F::load(ptr.add(i + lanes))); + acc2 = F::add(acc2, F::load(ptr.add(i + 2 * lanes))); + acc3 = F::add(acc3, F::load(ptr.add(i + 3 * lanes))); + acc4 = F::add(acc4, F::load(ptr.add(i + 4 * lanes))); + acc5 = F::add(acc5, F::load(ptr.add(i + 5 * lanes))); + acc6 = F::add(acc6, F::load(ptr.add(i + 6 * lanes))); + acc7 = F::add(acc7, F::load(ptr.add(i + 7 * lanes))); + } + i += step; + } + + // Combine accumulators in a tree to keep ILP. + let total = F::add( + F::add(F::add(acc0, acc1), F::add(acc2, acc3)), + F::add(F::add(acc4, acc5), F::add(acc6, acc7)), + ); + + // Extract lanes and sum even/odd groups. + let mut lanes_buf = [F::ZERO; 32]; + debug_assert!(F::LANES <= 32); + unsafe { F::store(lanes_buf.as_mut_ptr(), total) }; + + let mut even_sum = F::ZERO; + let mut odd_sum = F::ZERO; + for (j, &val) in lanes_buf.iter().enumerate().take(F::LANES) { + if j % 2 == 0 { + even_sum = F::scalar_add(even_sum, val); + } else { + odd_sum = F::scalar_add(odd_sum, val); + } + } + + (even_sum, odd_sum) +} + +/// Parallel SIMD evaluate with chunking for large arrays. +/// +/// Splits `src` into chunks, evaluates each in parallel (when the `parallel` +/// feature is enabled), then combines. +#[cfg(feature = "parallel")] +pub fn evaluate_parallel(src: &[F::Scalar]) -> (F::Scalar, F::Scalar) { + use rayon::prelude::*; + + let chunk_size: usize = 32_768; + let lanes = F::LANES; + let step = 8 * lanes; + let chunk_size = chunk_size.div_ceil(step) * step; + + if src.len() <= chunk_size { + let aligned_len = (src.len() / step) * step; + let (mut even, mut odd) = if aligned_len > 0 { + evaluate::(&src[..aligned_len]) + } else { + (F::ZERO, F::ZERO) + }; + for (i, &val) in src.iter().enumerate().skip(aligned_len) { + if i % 2 == 0 { + even = F::scalar_add(even, val); + } else { + odd = F::scalar_add(odd, val); + } + } + return (even, odd); + } + + src.par_chunks(chunk_size) + .map(|chunk| { + let aligned_len = (chunk.len() / step) * step; + if aligned_len == 0 { + let mut even = F::ZERO; + let mut odd = F::ZERO; + for (i, &val) in chunk.iter().enumerate() { + if i % 2 == 0 { + even = F::scalar_add(even, val); + } else { + odd = F::scalar_add(odd, val); + } + } + (even, odd) + } else { + let (e, o) = evaluate::(&chunk[..aligned_len]); + let mut even = e; + let mut odd = o; + for (i, &val) in chunk.iter().enumerate().skip(aligned_len) { + if i % 2 == 0 { + even = F::scalar_add(even, val); + } else { + odd = F::scalar_add(odd, val); + } + } + (even, odd) + } + }) + .reduce( + || (F::ZERO, F::ZERO), + |(e1, o1), (e2, o2)| (F::scalar_add(e1, e2), F::scalar_add(o1, o2)), + ) +} + +/// Non-parallel version of evaluate that handles arbitrary lengths. +#[cfg(not(feature = "parallel"))] +pub fn evaluate_parallel(src: &[F::Scalar]) -> (F::Scalar, F::Scalar) { + let lanes = F::LANES; + let step = 8 * lanes; + let aligned_len = (src.len() / step) * step; + + let (mut even, mut odd) = if aligned_len > 0 { + evaluate::(&src[..aligned_len]) + } else { + (F::ZERO, F::ZERO) + }; + + for i in aligned_len..src.len() { + if i % 2 == 0 { + even = F::scalar_add(even, src[i]); + } else { + odd = F::scalar_add(odd, src[i]); + } + } + + (even, odd) +} + +// ── Product evaluate ──────────────────────────────────────────────────────── + +/// SIMD-vectorized inner product evaluate. +/// +/// Given `f` = `[f(0), f(1), f(2), ...]` and `g` = `[g(0), g(1), g(2), ...]`, +/// computes the coefficients `(a, b)` of the degree-2 round polynomial: +/// a = Σ f[2i] * g[2i] (even-even products) +/// b = Σ (f[2i] * g[2i+1] + f[2i+1] * g[2i]) (cross-term) +/// +/// Uses `load_deinterleaved` + SIMD mul with 4× unrolling. +/// +/// `f` and `g` must have the same length, which must be a multiple of +/// `8 * F::LANES` (4× unroll, each loading 2×LANES from each of f and g). +pub fn product_evaluate( + f: &[F::Scalar], + g: &[F::Scalar], +) -> (F::Scalar, F::Scalar) { + debug_assert_eq!(f.len(), g.len()); + let n = f.len(); + let lanes = F::LANES; + // Each iteration processes 2*LANES elements from each array (one deinterleaved load). + // With 4× unrolling: step = 4 * 2 * LANES = 8 * LANES. + let step = 8 * lanes; + let aligned = (n / step) * step; + + let zero = F::splat(F::ZERO); + let mut acc_a0 = zero; + let mut acc_a1 = zero; + let mut acc_a2 = zero; + let mut acc_a3 = zero; + let mut acc_b0 = zero; + let mut acc_b1 = zero; + let mut acc_b2 = zero; + let mut acc_b3 = zero; + + let f_ptr = f.as_ptr(); + let g_ptr = g.as_ptr(); + + let mut i = 0; + while i < aligned { + unsafe { + // Group 0 + let (fe0, fo0) = F::load_deinterleaved(f_ptr.add(i)); + let (ge0, go0) = F::load_deinterleaved(g_ptr.add(i)); + acc_a0 = F::add(acc_a0, F::mul(fe0, ge0)); + acc_b0 = F::add(acc_b0, F::add(F::mul(fe0, go0), F::mul(fo0, ge0))); + + // Group 1 + let off1 = 2 * lanes; + let (fe1, fo1) = F::load_deinterleaved(f_ptr.add(i + off1)); + let (ge1, go1) = F::load_deinterleaved(g_ptr.add(i + off1)); + acc_a1 = F::add(acc_a1, F::mul(fe1, ge1)); + acc_b1 = F::add(acc_b1, F::add(F::mul(fe1, go1), F::mul(fo1, ge1))); + + // Group 2 + let off2 = 4 * lanes; + let (fe2, fo2) = F::load_deinterleaved(f_ptr.add(i + off2)); + let (ge2, go2) = F::load_deinterleaved(g_ptr.add(i + off2)); + acc_a2 = F::add(acc_a2, F::mul(fe2, ge2)); + acc_b2 = F::add(acc_b2, F::add(F::mul(fe2, go2), F::mul(fo2, ge2))); + + // Group 3 + let off3 = 6 * lanes; + let (fe3, fo3) = F::load_deinterleaved(f_ptr.add(i + off3)); + let (ge3, go3) = F::load_deinterleaved(g_ptr.add(i + off3)); + acc_a3 = F::add(acc_a3, F::mul(fe3, ge3)); + acc_b3 = F::add(acc_b3, F::add(F::mul(fe3, go3), F::mul(fo3, ge3))); + } + i += step; + } + + // Combine accumulators in tree + let total_a = F::add(F::add(acc_a0, acc_a1), F::add(acc_a2, acc_a3)); + let total_b = F::add(F::add(acc_b0, acc_b1), F::add(acc_b2, acc_b3)); + + // Horizontal reduce: sum all lanes into a scalar + let mut buf = [F::ZERO; 32]; + debug_assert!(lanes <= 32); + let mut a_sum = F::ZERO; + let mut b_sum = F::ZERO; + unsafe { F::store(buf.as_mut_ptr(), total_a) }; + for &val in buf.iter().take(lanes) { + a_sum = F::scalar_add(a_sum, val); + } + unsafe { F::store(buf.as_mut_ptr(), total_b) }; + for &val in buf.iter().take(lanes) { + b_sum = F::scalar_add(b_sum, val); + } + + // Scalar tail + let mut i = aligned; + while i + 1 < n { + let fe = f[i]; + let fo = f[i + 1]; + let ge = g[i]; + let go = g[i + 1]; + a_sum = F::scalar_add(a_sum, F::scalar_mul(fe, ge)); + b_sum = F::scalar_add( + b_sum, + F::scalar_add(F::scalar_mul(fe, go), F::scalar_mul(fo, ge)), + ); + i += 2; + } + + (a_sum, b_sum) +} + +/// Parallel SIMD product evaluate with chunking for large arrays. +#[cfg(feature = "parallel")] +pub fn product_evaluate_parallel( + f: &[F::Scalar], + g: &[F::Scalar], +) -> (F::Scalar, F::Scalar) { + use rayon::prelude::*; + + debug_assert_eq!(f.len(), g.len()); + let n = f.len(); + let lanes = F::LANES; + let step = 8 * lanes; + let chunk_size = 32_768_usize.div_ceil(step) * step; + + if n <= chunk_size { + return product_evaluate::(f, g); + } + + // Chunk both f and g in lockstep + f.par_chunks(chunk_size) + .zip(g.par_chunks(chunk_size)) + .map(|(fc, gc)| product_evaluate::(fc, gc)) + .reduce( + || (F::ZERO, F::ZERO), + |(a1, b1), (a2, b2)| (F::scalar_add(a1, a2), F::scalar_add(b1, b2)), + ) +} + +/// Non-parallel fallback. +#[cfg(not(feature = "parallel"))] +pub fn product_evaluate_parallel( + f: &[F::Scalar], + g: &[F::Scalar], +) -> (F::Scalar, F::Scalar) { + product_evaluate::(f, g) +} + +// ── Extension field evaluate ──────────────────────────────────────────────── + +/// SIMD-vectorized pairwise evaluate for extension field elements. +/// +/// Given `src` containing `n` extension elements of degree `d` (total +/// `n * d` base field scalars in AoS layout: `[e0_c0, e0_c1, ..., e1_c0, ...]`), +/// computes: +/// sum_even = e0 + e2 + e4 + ... (component-wise) +/// sum_odd = e1 + e3 + e5 + ... (component-wise) +/// +/// Returns `(even_components, odd_components)` each of length `d`. +/// +/// For degree-1 (base field), use [`evaluate`] instead — it's more optimized. +pub fn ext_evaluate( + src: &[F::Scalar], + ext_degree: usize, +) -> (Vec, Vec) { + let n_elems = src.len() / ext_degree; + debug_assert_eq!(src.len(), n_elems * ext_degree); + + let lanes = F::LANES; + let n_pairs = n_elems / 2; + // Stride in u64s between adjacent extension elements + let elem_stride = ext_degree; + // Stride in u64s between even and odd element in a pair + let pair_stride = 2 * ext_degree; + + let mut even_sums = vec![F::ZERO; ext_degree]; + let mut odd_sums = vec![F::ZERO; ext_degree]; + + // Number of SIMD vectors needed to load one extension element + let vecs_per_elem = ext_degree.div_ceil(lanes); + + if ext_degree >= lanes { + // Optimized path: each extension element is ≥ 1 SIMD vector. + // Use 4× unrolling for ILP (processes 4 pairs per outer iteration). + let unroll = 4; + let aligned_pairs = (n_pairs / unroll) * unroll; + + let simd_components = ext_degree / lanes; + // 4 even + 4 odd accumulators, each with simd_components vectors + let zero = F::splat(F::ZERO); + let mut even_accs = [[zero; 8]; 4]; // [unroll][max_simd_components] + let mut odd_accs = [[zero; 8]; 4]; + debug_assert!(simd_components <= 8); + + let ptr = src.as_ptr(); + let mut pair = 0; + while pair < aligned_pairs { + for u in 0..unroll { + let p = pair + u; + let even_off = p * pair_stride; + let odd_off = even_off + elem_stride; + for c in 0..simd_components { + unsafe { + even_accs[u][c] = + F::add(even_accs[u][c], F::load(ptr.add(even_off + c * lanes))); + odd_accs[u][c] = + F::add(odd_accs[u][c], F::load(ptr.add(odd_off + c * lanes))); + } + } + } + pair += unroll; + } + + // Combine unrolled accumulators + for u in 1..unroll { + for c in 0..simd_components { + even_accs[0][c] = F::add(even_accs[0][c], even_accs[u][c]); + odd_accs[0][c] = F::add(odd_accs[0][c], odd_accs[u][c]); + } + } + + // Tail: remaining pairs (< unroll) + while pair < n_pairs { + let even_off = pair * pair_stride; + let odd_off = even_off + elem_stride; + for c in 0..simd_components { + unsafe { + even_accs[0][c] = + F::add(even_accs[0][c], F::load(ptr.add(even_off + c * lanes))); + odd_accs[0][c] = F::add(odd_accs[0][c], F::load(ptr.add(odd_off + c * lanes))); + } + } + pair += 1; + } + + // Extract SIMD lanes into scalar sums + let mut buf = [F::ZERO; 32]; + for c in 0..simd_components { + unsafe { F::store(buf.as_mut_ptr(), even_accs[0][c]) }; + for l in 0..lanes { + even_sums[c * lanes + l] = F::scalar_add(even_sums[c * lanes + l], buf[l]); + } + unsafe { F::store(buf.as_mut_ptr(), odd_accs[0][c]) }; + for l in 0..lanes { + odd_sums[c * lanes + l] = F::scalar_add(odd_sums[c * lanes + l], buf[l]); + } + } + + // Tail components (ext_degree not divisible by lanes) + let tail_start = simd_components * lanes; + for p in 0..n_pairs { + let even_off = p * pair_stride; + let odd_off = even_off + elem_stride; + for c in tail_start..ext_degree { + even_sums[c] = F::scalar_add(even_sums[c], src[even_off + c]); + odd_sums[c] = F::scalar_add(odd_sums[c], src[odd_off + c]); + } + } + } else { + // ext_degree < LANES (e.g., degree-2 with AVX-512 LANES=8): + // Multiple extension elements fit in one SIMD vector. + // Scalar accumulation — still fast for small n. + let _ = vecs_per_elem; + for p in 0..n_pairs { + let even_off = p * pair_stride; + let odd_off = even_off + elem_stride; + for c in 0..ext_degree { + even_sums[c] = F::scalar_add(even_sums[c], src[even_off + c]); + odd_sums[c] = F::scalar_add(odd_sums[c], src[odd_off + c]); + } + } + } + + (even_sums, odd_sums) +} + +/// Specialized ext3 evaluate using NEON `vld3q_u64` structured loads. +/// +/// Loads 2 extension elements (6 u64s) at a time, deinterleaving by stride 3 +/// into per-component vectors. No scalar tail needed — every u64 is SIMD-processed. +/// +/// 4× unrolled: processes 8 ext3 elements (4 even + 4 odd) per outer iteration. +#[cfg(target_arch = "aarch64")] +pub fn ext3_evaluate_neon(src: &[u64]) -> (Vec, Vec) { + use crate::simd_fields::goldilocks::neon::GoldilocksNeon; + use crate::simd_fields::SimdBaseField; + use core::arch::aarch64::*; + + let ext_deg = 3; + let n_elems = src.len() / ext_deg; + let n_pairs = n_elems / 2; + + // Each pair = 2 ext3 elements = 6 u64s. + // vld3q_u64 loads 6 u64s and deinterleaves into 3 uint64x2_t: + // v[0] = [c0_even, c0_odd], v[1] = [c1_even, c1_odd], v[2] = [c2_even, c2_odd] + // Then lane 0 = even element's component, lane 1 = odd element's component. + + let zero = GoldilocksNeon::splat(0); + // 4× unrolled accumulators for even (lane 0) and odd (lane 1) + let mut acc_c0_0 = zero; + let mut acc_c1_0 = zero; + let mut acc_c2_0 = zero; + let mut acc_c0_1 = zero; + let mut acc_c1_1 = zero; + let mut acc_c2_1 = zero; + let mut acc_c0_2 = zero; + let mut acc_c1_2 = zero; + let mut acc_c2_2 = zero; + let mut acc_c0_3 = zero; + let mut acc_c1_3 = zero; + let mut acc_c2_3 = zero; + + let unroll = 4; + let aligned_pairs = (n_pairs / unroll) * unroll; + let ptr = src.as_ptr(); + + let mut pair = 0; + while pair < aligned_pairs { + unsafe { + // Group 0 + let v0 = vld3q_u64(ptr.add((pair) * 6)); + acc_c0_0 = GoldilocksNeon::add(acc_c0_0, v0.0); + acc_c1_0 = GoldilocksNeon::add(acc_c1_0, v0.1); + acc_c2_0 = GoldilocksNeon::add(acc_c2_0, v0.2); + + // Group 1 + let v1 = vld3q_u64(ptr.add((pair + 1) * 6)); + acc_c0_1 = GoldilocksNeon::add(acc_c0_1, v1.0); + acc_c1_1 = GoldilocksNeon::add(acc_c1_1, v1.1); + acc_c2_1 = GoldilocksNeon::add(acc_c2_1, v1.2); + + // Group 2 + let v2 = vld3q_u64(ptr.add((pair + 2) * 6)); + acc_c0_2 = GoldilocksNeon::add(acc_c0_2, v2.0); + acc_c1_2 = GoldilocksNeon::add(acc_c1_2, v2.1); + acc_c2_2 = GoldilocksNeon::add(acc_c2_2, v2.2); + + // Group 3 + let v3 = vld3q_u64(ptr.add((pair + 3) * 6)); + acc_c0_3 = GoldilocksNeon::add(acc_c0_3, v3.0); + acc_c1_3 = GoldilocksNeon::add(acc_c1_3, v3.1); + acc_c2_3 = GoldilocksNeon::add(acc_c2_3, v3.2); + } + pair += unroll; + } + + // Combine unrolled accumulators + let mut total_c0 = GoldilocksNeon::add( + GoldilocksNeon::add(acc_c0_0, acc_c0_1), + GoldilocksNeon::add(acc_c0_2, acc_c0_3), + ); + let mut total_c1 = GoldilocksNeon::add( + GoldilocksNeon::add(acc_c1_0, acc_c1_1), + GoldilocksNeon::add(acc_c1_2, acc_c1_3), + ); + let mut total_c2 = GoldilocksNeon::add( + GoldilocksNeon::add(acc_c2_0, acc_c2_1), + GoldilocksNeon::add(acc_c2_2, acc_c2_3), + ); + + // Tail pairs + while pair < n_pairs { + unsafe { + let v = vld3q_u64(ptr.add(pair * 6)); + total_c0 = GoldilocksNeon::add(total_c0, v.0); + total_c1 = GoldilocksNeon::add(total_c1, v.1); + total_c2 = GoldilocksNeon::add(total_c2, v.2); + } + pair += 1; + } + + // Extract: lane 0 = even sum, lane 1 = odd sum for each component + let mut buf = [0u64; 2]; + let mut even = vec![0u64; 3]; + let mut odd = vec![0u64; 3]; + + unsafe { + GoldilocksNeon::store(buf.as_mut_ptr(), total_c0); + even[0] = buf[0]; + odd[0] = buf[1]; + GoldilocksNeon::store(buf.as_mut_ptr(), total_c1); + even[1] = buf[0]; + odd[1] = buf[1]; + GoldilocksNeon::store(buf.as_mut_ptr(), total_c2); + even[2] = buf[0]; + odd[2] = buf[1]; + } + + (even, odd) +} + +/// Parallel extension evaluate with chunking for large arrays. +#[cfg(feature = "parallel")] +pub fn ext_evaluate_parallel( + src: &[F::Scalar], + ext_degree: usize, +) -> (Vec, Vec) { + use rayon::prelude::*; + + let n_elems = src.len() / ext_degree; + let pair_stride = 2 * ext_degree; + let chunk_pairs = 8192_usize; + let chunk_u64s = chunk_pairs * pair_stride; + let n_pairs = n_elems / 2; + + if n_pairs <= chunk_pairs { + return ext_evaluate::(src, ext_degree); + } + + src.par_chunks(chunk_u64s) + .map(|chunk| ext_evaluate::(chunk, ext_degree)) + .reduce( + || (vec![F::ZERO; ext_degree], vec![F::ZERO; ext_degree]), + |(mut e1, mut o1), (e2, o2)| { + for i in 0..ext_degree { + e1[i] = F::scalar_add(e1[i], e2[i]); + o1[i] = F::scalar_add(o1[i], o2[i]); + } + (e1, o1) + }, + ) +} + +/// Non-parallel fallback. +#[cfg(not(feature = "parallel"))] +pub fn ext_evaluate_parallel( + src: &[F::Scalar], + ext_degree: usize, +) -> (Vec, Vec) { + ext_evaluate::(src, ext_degree) +} + +// Note: Extension field REDUCE (multiply by challenge) stays in the generic +// arkworks path. The extension multiply is complex (Karatsuba with base muls) +// and on NEON the base mul is scalar anyway. The SIMD win for extensions is +// in the EVALUATE (addition only). For AVX-512 where base mul is truly +// vectorized, a SIMD extension reduce would help — future work. + +#[cfg(test)] +#[cfg(any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") +))] +mod tests { + use super::*; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + use crate::simd_fields::goldilocks::avx512::GoldilocksAvx512 as Backend; + #[cfg(target_arch = "aarch64")] + use crate::simd_fields::goldilocks::neon::GoldilocksNeon as Backend; + use crate::tests::F64; + use ark_ff::UniformRand; + use ark_std::test_rng; + + #[test] + fn test_evaluate_matches_pairwise() { + use crate::reductions::pairwise; + + let mut rng = test_rng(); + let n = 1 << 16; + let evals_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let evals_raw: Vec = evals_ff.iter().map(|f| (*f).value).collect(); + + // Reference: arkworks pairwise evaluate + let (expected_even, expected_odd) = pairwise::evaluate(&evals_ff); + + // SIMD evaluate (Montgomery domain) + let (simd_even, simd_odd) = evaluate::(&evals_raw); + + assert_eq!(expected_even.value, simd_even, "even sum mismatch"); + assert_eq!(expected_odd.value, simd_odd, "odd sum mismatch"); + } + + #[test] + fn test_evaluate_parallel_matches_pairwise() { + use crate::reductions::pairwise; + + let mut rng = test_rng(); + let n = 1 << 20; + let evals_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let evals_raw: Vec = evals_ff.iter().map(|f| (*f).value).collect(); + + let (expected_even, expected_odd) = pairwise::evaluate(&evals_ff); + let (simd_even, simd_odd) = evaluate_parallel::(&evals_raw); + + assert_eq!(expected_even.value, simd_even, "parallel even sum mismatch"); + assert_eq!(expected_odd.value, simd_odd, "parallel odd sum mismatch"); + } + + #[test] + fn test_product_evaluate_matches_generic() { + use crate::reductions::pairwise::pairwise_product_evaluate; + + let mut rng = test_rng(); + let n = 1 << 16; + let f_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let g_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let f_raw: Vec = f_ff.iter().map(|f| (*f).value).collect(); + let g_raw: Vec = g_ff.iter().map(|g| (*g).value).collect(); + + let (expected_a, expected_b) = pairwise_product_evaluate(&[f_ff.clone(), g_ff.clone()]); + + let (simd_a, simd_b) = product_evaluate::(&f_raw, &g_raw); + + assert_eq!(expected_a.value, simd_a, "product a mismatch"); + assert_eq!(expected_b.value, simd_b, "product b mismatch"); + } + + #[test] + fn test_product_evaluate_parallel_matches_generic() { + use crate::reductions::pairwise::pairwise_product_evaluate; + + let mut rng = test_rng(); + let n = 1 << 20; + let f_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let g_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let f_raw: Vec = f_ff.iter().map(|f| (*f).value).collect(); + let g_raw: Vec = g_ff.iter().map(|g| (*g).value).collect(); + + let (expected_a, expected_b) = pairwise_product_evaluate(&[f_ff.clone(), g_ff.clone()]); + + let (simd_a, simd_b) = product_evaluate_parallel::(&f_raw, &g_raw); + + assert_eq!(expected_a.value, simd_a, "parallel product a mismatch"); + assert_eq!(expected_b.value, simd_b, "parallel product b mismatch"); + } +} diff --git a/src/simd_sumcheck/mod.rs b/src/simd_sumcheck/mod.rs new file mode 100644 index 00000000..9b859689 --- /dev/null +++ b/src/simd_sumcheck/mod.rs @@ -0,0 +1,7 @@ +//! SIMD-vectorized sumcheck algorithm layer. +//! +//! Generic over [`SimdBaseField`](super::simd_fields::SimdBaseField). + +pub(crate) mod dispatch; +pub mod evaluate; +pub mod reduce; diff --git a/src/simd_sumcheck/reduce.rs b/src/simd_sumcheck/reduce.rs new file mode 100644 index 00000000..9ded3502 --- /dev/null +++ b/src/simd_sumcheck/reduce.rs @@ -0,0 +1,956 @@ +#![allow(dead_code)] +//! SIMD-vectorized reduce kernels: fold evaluations with a challenge. +//! +//! Two layout variants: +//! - **Half-split (MSB)**: pairs `data[k]` with `data[k + L/2]`. This is +//! the layout used by the public sumcheck entry points and by WHIR. +//! - **Pair-split (LSB)**: pairs `data[2k]` with `data[2k+1]`. Used by the +//! legacy `Prover` trait and `coefficient_sumcheck`. +//! +//! The MSB kernel (`reduce_msb_in_place`) uses plain contiguous `F::load` +//! from each half — simpler and faster than the LSB `load_deinterleaved`. + +use crate::simd_fields::SimdBaseField; + +// ═══════════════════════════════════════════════════════════════════════════ +// Half-split (MSB) reduce +// ═══════════════════════════════════════════════════════════════════════════ + +/// SIMD-vectorized MSB (half-split) reduce, in-place. +/// +/// `new[k] = src[k] + challenge * (src[k + half] − src[k])` for `k` in +/// `0..half`, where `half = next_power_of_two(n) / 2`. Elements in the low +/// half beyond `n − half` (the "tail") have no partner in the high half and +/// are folded as `src[k] * (1 − challenge)`. +/// +/// Returns the output length `half`. +pub fn reduce_msb_in_place(src: &mut [F::Scalar], challenge: F::Scalar) -> usize { + let n = src.len(); + if n <= 1 { + return n; + } + + let half = n.next_power_of_two() >> 1; + let paired = n - half; // elements that have a partner in the high half + let lanes = F::LANES; + let challenge_v = F::splat(challenge); + + // ── SIMD main loop over paired portion ── + let step = 4 * lanes; + let aligned = (paired / step) * step; + + let lo_ptr = src.as_ptr(); + let hi_ptr = unsafe { src.as_ptr().add(half) }; + let out_ptr = src.as_mut_ptr(); + + let mut i = 0; + while i < aligned { + unsafe { + for g in 0..4 { + let off = i + g * lanes; + let a = F::load(lo_ptr.add(off)); + let b = F::load(hi_ptr.add(off)); + let r = F::add(a, F::mul(challenge_v, F::sub(b, a))); + F::store(out_ptr.add(off), r); + } + } + i += step; + } + + // ── Scalar tail of paired portion ── + while i < paired { + let a = src[i]; + let b = src[i + half]; + src[i] = F::scalar_add(a, F::scalar_mul(challenge, F::scalar_sub(b, a))); + i += 1; + } + + // ── Unpaired tail: data[k] *= (1 − challenge) for k in paired..half ── + let one_minus = F::scalar_sub(F::ONE, challenge); + for v in src.iter_mut().take(half).skip(paired) { + *v = F::scalar_mul(*v, one_minus); + } + + half +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Pair-split (LSB) reduce — legacy, used by coefficient_sumcheck +// ═══════════════════════════════════════════════════════════════════════════ + +/// SIMD-vectorized pairwise reduce, producing a new Vec. +/// +/// Uses 4× loop unrolling for instruction-level parallelism. +/// (8× was benchmarked but regressed due to register pressure from mul.) +/// Stack-allocated deinterleave buffers avoid per-iteration heap allocation. +pub fn reduce_both_in_place( + f: &mut [F::Scalar], + g: &mut [F::Scalar], + challenge: F::Scalar, +) -> usize { + let n = f.len() / 2; + debug_assert_eq!(f.len(), g.len()); + let lanes = F::LANES; + let challenge_v = F::splat(challenge); + let step = 4 * lanes; + let aligned = (n / step) * step; + + let f_ptr = f.as_ptr(); + let g_ptr = g.as_ptr(); + let f_out = f.as_mut_ptr(); + let g_out = g.as_mut_ptr(); + + let mut i = 0; + while i < aligned { + unsafe { + for u in 0..4 { + let off = i + u * lanes; + + let (fv_a, fv_b) = F::load_deinterleaved(f_ptr.add(2 * off)); + let f_red = F::add(fv_a, F::mul(challenge_v, F::sub(fv_b, fv_a))); + F::store(f_out.add(off), f_red); + + let (gv_a, gv_b) = F::load_deinterleaved(g_ptr.add(2 * off)); + let g_red = F::add(gv_a, F::mul(challenge_v, F::sub(gv_b, gv_a))); + F::store(g_out.add(off), g_red); + } + } + i += step; + } + + while i < n { + let fa = f[2 * i]; + let fb = f[2 * i + 1]; + f[i] = F::scalar_add(fa, F::scalar_mul(challenge, F::scalar_sub(fb, fa))); + + let ga = g[2 * i]; + let gb = g[2 * i + 1]; + g[i] = F::scalar_add(ga, F::scalar_mul(challenge, F::scalar_sub(gb, ga))); + + i += 1; + } + + n +} + +/// SIMD-vectorized pairwise reduce, in-place. +/// +/// Reads pairs from the first `2*n` positions, writes results to `src[0..n]`. +/// Returns the output length `n`. +pub fn reduce_in_place(src: &mut [F::Scalar], challenge: F::Scalar) -> usize { + let n = src.len() / 2; + let lanes = F::LANES; + let challenge_v = F::splat(challenge); + let step = 4 * lanes; // 4× unroll: 4 groups of LANES outputs per iteration + let aligned = (n / step) * step; + + let src_ptr = src.as_ptr(); + let out_ptr = src.as_mut_ptr(); + + let mut i = 0; + while i < aligned { + unsafe { + for g in 0..4 { + let (av, bv) = F::load_deinterleaved(src_ptr.add(2 * (i + g * lanes))); + let r = F::add(av, F::mul(challenge_v, F::sub(bv, av))); + F::store(out_ptr.add(i + g * lanes), r); + } + } + i += step; + } + + while i + lanes <= n { + unsafe { + let (av, bv) = F::load_deinterleaved(src_ptr.add(2 * i)); + let r = F::add(av, F::mul(challenge_v, F::sub(bv, av))); + F::store(src[i..].as_mut_ptr(), r); + } + i += lanes; + } + + while i < n { + let a = src[2 * i]; + let b = src[2 * i + 1]; + let diff = F::scalar_sub(b, a); + let scaled = F::scalar_mul(challenge, diff); + src[i] = F::scalar_add(a, scaled); + i += 1; + } + + n +} + +/// Fused reduce + evaluate for the next round. +/// +/// Performs in-place pairwise reduce (same as `reduce_in_place`) and simultaneously +/// accumulates the even/odd sums that `evaluate` would compute on the reduced output. +/// This eliminates one full data pass per round (the separate evaluate read). +/// +/// Returns `(next_even_sum, next_odd_sum, output_length)`. +pub fn reduce_and_evaluate( + src: &mut [F::Scalar], + challenge: F::Scalar, +) -> (F::Scalar, F::Scalar, usize) { + let n = src.len() / 2; + let lanes = F::LANES; + let challenge_v = F::splat(challenge); + + // We need 2 groups of accumulators: one for reduced values at even output + // positions and one for odd. Within a contiguous vector of LANES elements + // written at output position i, lanes 0,2,4,6 are "even" and 1,3,5,7 are + // "odd" when considered as part of the flat output array (since i is always + // aligned to LANES). So we just accumulate all reduced vectors and separate + // even/odd lanes at the end — exactly like evaluate does. + // + // Use lazy accumulation: wrapping add + carry count, finalize at the end. + // This halves the accumulation overhead (3 instructions vs 6 for full mod add). + let zero = F::splat(F::ZERO); + let mut acc0 = zero; + let mut acc1 = zero; + let mut acc2 = zero; + let mut acc3 = zero; + let mut carry0 = zero; + let mut carry1 = zero; + let mut carry2 = zero; + let mut carry3 = zero; + + let step = 4 * lanes; + let aligned = (n / step) * step; + + let src_ptr = src.as_ptr(); + let out_ptr = src.as_mut_ptr(); + + let mut i = 0; + while i < aligned { + unsafe { + let (av0, bv0) = F::load_deinterleaved(src_ptr.add(2 * i)); + let r0 = F::add(av0, F::mul(challenge_v, F::sub(bv0, av0))); + F::store(out_ptr.add(i), r0); + let sum0 = F::add_wrapping(acc0, r0); + carry0 = F::add_wrapping(carry0, F::carry_mask(sum0, acc0)); + acc0 = sum0; + + let (av1, bv1) = F::load_deinterleaved(src_ptr.add(2 * (i + lanes))); + let r1 = F::add(av1, F::mul(challenge_v, F::sub(bv1, av1))); + F::store(out_ptr.add(i + lanes), r1); + let sum1 = F::add_wrapping(acc1, r1); + carry1 = F::add_wrapping(carry1, F::carry_mask(sum1, acc1)); + acc1 = sum1; + + let (av2, bv2) = F::load_deinterleaved(src_ptr.add(2 * (i + 2 * lanes))); + let r2 = F::add(av2, F::mul(challenge_v, F::sub(bv2, av2))); + F::store(out_ptr.add(i + 2 * lanes), r2); + let sum2 = F::add_wrapping(acc2, r2); + carry2 = F::add_wrapping(carry2, F::carry_mask(sum2, acc2)); + acc2 = sum2; + + let (av3, bv3) = F::load_deinterleaved(src_ptr.add(2 * (i + 3 * lanes))); + let r3 = F::add(av3, F::mul(challenge_v, F::sub(bv3, av3))); + F::store(out_ptr.add(i + 3 * lanes), r3); + let sum3 = F::add_wrapping(acc3, r3); + carry3 = F::add_wrapping(carry3, F::carry_mask(sum3, acc3)); + acc3 = sum3; + } + i += step; + } + + // Cleanup: single vector at a time (use full modular add — few iterations) + while i + lanes <= n { + unsafe { + let (av, bv) = F::load_deinterleaved(src_ptr.add(2 * i)); + let r = F::add(av, F::mul(challenge_v, F::sub(bv, av))); + F::store(out_ptr.add(i), r); + acc0 = F::add(acc0, r); + } + i += lanes; + } + + // Finalize lazy accumulators: correct for carries + let red0 = F::reduce_carry(acc0, carry0); + let red1 = F::reduce_carry(acc1, carry1); + let red2 = F::reduce_carry(acc2, carry2); + let red3 = F::reduce_carry(acc3, carry3); + + // Combine in a tree for ILP + let total = F::add(F::add(red0, red1), F::add(red2, red3)); + + // Extract lanes and sum even/odd groups + let mut lanes_buf = [F::ZERO; 32]; + debug_assert!(F::LANES <= 32); + unsafe { F::store(lanes_buf.as_mut_ptr(), total) }; + + let mut even_sum = F::ZERO; + let mut odd_sum = F::ZERO; + for (j, &val) in lanes_buf.iter().enumerate().take(F::LANES) { + if j % 2 == 0 { + even_sum = F::scalar_add(even_sum, val); + } else { + odd_sum = F::scalar_add(odd_sum, val); + } + } + + // Scalar tail (both reduce and accumulate) + while i < n { + let a = src[2 * i]; + let b = src[2 * i + 1]; + let diff = F::scalar_sub(b, a); + let scaled = F::scalar_mul(challenge, diff); + let r = F::scalar_add(a, scaled); + src[i] = r; + if i % 2 == 0 { + even_sum = F::scalar_add(even_sum, r); + } else { + odd_sum = F::scalar_add(odd_sum, r); + } + i += 1; + } + + (even_sum, odd_sum, n) +} +#[allow(dead_code)] +#[cfg_attr( + not(any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + )), + allow(unused_variables) +)] +pub fn ext2_reduce_in_place>( + src: &mut [u64], + challenge: [u64; 2], + w: u64, +) -> usize { + let ext_deg = 2; + let n_elems = src.len() / ext_deg; + let n_pairs = n_elems / 2; + + #[cfg(target_arch = "aarch64")] + { + use crate::simd_fields::goldilocks::neon::{ext2_scalar_mul, GoldilocksNeon}; + + let _w_vec = GoldilocksNeon::splat(w); + let _chg_v = [ + GoldilocksNeon::splat(challenge[0]), + GoldilocksNeon::splat(challenge[1]), + ]; + + // With NEON LANES=2 and degree-2: one SIMD load = one extension element. + // Process pairs: load even (2 u64s), load odd (2 u64s), compute result. + let ptr = src.as_mut_ptr(); + for i in 0..n_pairs { + let a_off = (2 * i) * ext_deg; + let b_off = (2 * i + 1) * ext_deg; + let out_off = i * ext_deg; + + unsafe { + // Load even and odd extension elements + let a_v = GoldilocksNeon::load(ptr.add(a_off) as *const u64); + let b_v = GoldilocksNeon::load(ptr.add(b_off) as *const u64); + + // diff = b - a (component-wise, both components in one SIMD op) + let diff_v = GoldilocksNeon::sub(b_v, a_v); + + // For ext2 multiply, we need SoA: separate c0 and c1 components. + // With LANES=2, the vector holds [c0, c1] — need to broadcast + // each component to both lanes for the multiply. + // Actually, ext2_mul expects [Packed; 2] where each Packed has + // the same component from multiple elements. With only 1 element + // per SIMD vector, we just extract and use scalar. + let diff0 = core::arch::aarch64::vgetq_lane_u64(diff_v, 0); + let diff1 = core::arch::aarch64::vgetq_lane_u64(diff_v, 1); + let prod = ext2_scalar_mul([diff0, diff1], challenge, w); + + // result = a + prod (component-wise) + let a0 = core::arch::aarch64::vgetq_lane_u64(a_v, 0); + let a1 = core::arch::aarch64::vgetq_lane_u64(a_v, 1); + let r0 = GoldilocksNeon::scalar_add(a0, prod[0]); + let r1 = GoldilocksNeon::scalar_add(a1, prod[1]); + + *ptr.add(out_off) = r0; + *ptr.add(out_off + 1) = r1; + } + } + } + + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + { + use crate::simd_fields::goldilocks::avx512::{ + ext2_reduce_8pairs, ext2_scalar_mul, GoldilocksAvx512, + }; + + let challenge_c0 = GoldilocksAvx512::splat(challenge[0]); + let challenge_c1 = GoldilocksAvx512::splat(challenge[1]); + let w_vec = GoldilocksAvx512::splat(w); + + let ptr = src.as_mut_ptr(); + let simd_pairs = (n_pairs / 8) * 8; + let mut i = 0; + + // Safe in-place: ext2_reduce_8pairs loads all 32 u64s into registers + // before writing 16 u64s, and output region is always <= input region. + while i < simd_pairs { + let src_off = (2 * i) * ext_deg; + let out_off = i * ext_deg; + unsafe { + ext2_reduce_8pairs( + ptr.add(src_off) as *const u64, + ptr.add(out_off), + challenge_c0, + challenge_c1, + w_vec, + ); + } + i += 8; + } + + while i < n_pairs { + let a_off = (2 * i) * ext_deg; + let b_off = (2 * i + 1) * ext_deg; + let out_off = i * ext_deg; + + let diff = [ + GoldilocksAvx512::scalar_sub(src[b_off], src[a_off]), + GoldilocksAvx512::scalar_sub(src[b_off + 1], src[a_off + 1]), + ]; + let prod = ext2_scalar_mul(diff, challenge, w); + + src[out_off] = GoldilocksAvx512::scalar_add(src[a_off], prod[0]); + src[out_off + 1] = GoldilocksAvx512::scalar_add(src[a_off + 1], prod[1]); + i += 1; + } + } + + n_pairs * ext_deg +} +#[cfg_attr( + not(any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") + )), + allow(unused_variables) +)] +pub fn ext3_reduce_in_place>( + src: &mut [u64], + challenge: [u64; 3], + w: u64, +) -> usize { + let ext_deg = 3; + let n_elems = src.len() / ext_deg; + let n_pairs = n_elems / 2; + + #[cfg(target_arch = "aarch64")] + { + use crate::simd_fields::goldilocks::neon::{ext3_scalar_mul, GoldilocksNeon}; + + for i in 0..n_pairs { + let a_off = (2 * i) * ext_deg; + let b_off = (2 * i + 1) * ext_deg; + let out_off = i * ext_deg; + + let diff = [ + GoldilocksNeon::scalar_sub(src[b_off], src[a_off]), + GoldilocksNeon::scalar_sub(src[b_off + 1], src[a_off + 1]), + GoldilocksNeon::scalar_sub(src[b_off + 2], src[a_off + 2]), + ]; + let prod = ext3_scalar_mul(diff, challenge, w); + src[out_off] = GoldilocksNeon::scalar_add(src[a_off], prod[0]); + src[out_off + 1] = GoldilocksNeon::scalar_add(src[a_off + 1], prod[1]); + src[out_off + 2] = GoldilocksNeon::scalar_add(src[a_off + 2], prod[2]); + } + } + + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + { + use crate::simd_fields::goldilocks::avx512::{ + ext3_reduce_8pairs, ext3_scalar_mul, GoldilocksAvx512, + }; + + let challenge_v = [ + GoldilocksAvx512::splat(challenge[0]), + GoldilocksAvx512::splat(challenge[1]), + GoldilocksAvx512::splat(challenge[2]), + ]; + let w_vec = GoldilocksAvx512::splat(w); + + let ptr = src.as_mut_ptr(); + let simd_pairs = (n_pairs / 8) * 8; + let mut i = 0; + + // Safe in-place: ext3_reduce_8pairs gathers all 48 u64s into registers + // before scattering 24 u64s, and output region is always <= input region. + while i < simd_pairs { + let src_off = (2 * i) * ext_deg; + let out_off = i * ext_deg; + unsafe { + ext3_reduce_8pairs( + ptr.add(src_off) as *const u64, + ptr.add(out_off), + challenge_v, + w_vec, + ); + } + i += 8; + } + + while i < n_pairs { + let a_off = (2 * i) * ext_deg; + let b_off = (2 * i + 1) * ext_deg; + let out_off = i * ext_deg; + + let diff = [ + GoldilocksAvx512::scalar_sub(src[b_off], src[a_off]), + GoldilocksAvx512::scalar_sub(src[b_off + 1], src[a_off + 1]), + GoldilocksAvx512::scalar_sub(src[b_off + 2], src[a_off + 2]), + ]; + let prod = ext3_scalar_mul(diff, challenge, w); + src[out_off] = GoldilocksAvx512::scalar_add(src[a_off], prod[0]); + src[out_off + 1] = GoldilocksAvx512::scalar_add(src[a_off + 1], prod[1]); + src[out_off + 2] = GoldilocksAvx512::scalar_add(src[a_off + 2], prod[2]); + i += 1; + } + } + + n_pairs * ext_deg +} +/// SoA ext2 inner product evaluate. +/// +/// Given `f` and `g` as ext2 elements in SoA layout (f_c0, f_c1, g_c0, g_c1), +/// computes the degree-2 round polynomial coefficients `(a, b)`: +/// a = Σ f[2i] * g[2i] (ext2 products) +/// b = Σ (f[2i] * g[2i+1] + f[2i+1] * g[2i]) (ext2 cross-terms) +/// +/// Returns `(a_c0, a_c1, b_c0, b_c1)` as raw u64 components. +pub fn ext2_soa_product_evaluate>( + f_c0: &[u64], + f_c1: &[u64], + g_c0: &[u64], + g_c1: &[u64], + w: u64, +) -> ([u64; 2], [u64; 2]) { + let n = f_c0.len(); + debug_assert_eq!(n, f_c1.len()); + debug_assert_eq!(n, g_c0.len()); + debug_assert_eq!(n, g_c1.len()); + + let lanes = F::LANES; + // Each load_deinterleaved consumes 2*lanes u64s of input (covering `lanes` pairs). + // 2× unroll → each iteration consumes 4*lanes u64s. + let load_width = 2 * lanes; + let step = 2 * load_width; // 4 * lanes + let aligned = (n / step) * step; + let w_vec = F::splat(w); + + let zero = F::splat(F::ZERO); + let mut acc_a0 = zero; + let mut acc_a1 = zero; + let mut acc_b0 = zero; + let mut acc_b1 = zero; + + let mut i = 0; + while i < aligned { + unsafe { + for u in 0..2 { + let off = i + u * load_width; + let (fe0, fo0) = F::load_deinterleaved(f_c0.as_ptr().add(off)); + let (fe1, fo1) = F::load_deinterleaved(f_c1.as_ptr().add(off)); + let (ge0, go0) = F::load_deinterleaved(g_c0.as_ptr().add(off)); + let (ge1, go1) = F::load_deinterleaved(g_c1.as_ptr().add(off)); + + // a += f_even * g_even (ext2 Karatsuba) + let v0 = F::mul(fe0, ge0); + let v1 = F::mul(fe1, ge1); + acc_a0 = F::add(acc_a0, F::add(v0, F::mul(w_vec, v1))); + let m = F::mul(F::add(fe0, fe1), F::add(ge0, ge1)); + acc_a1 = F::add(acc_a1, F::sub(F::sub(m, v0), v1)); + + // b += f_even * g_odd (ext2 Karatsuba) + let u0 = F::mul(fe0, go0); + let u1 = F::mul(fe1, go1); + let m1 = F::mul(F::add(fe0, fe1), F::add(go0, go1)); + // b += f_odd * g_even (ext2 Karatsuba) + let p0 = F::mul(fo0, ge0); + let p1 = F::mul(fo1, ge1); + let m2 = F::mul(F::add(fo0, fo1), F::add(ge0, ge1)); + + acc_b0 = F::add( + acc_b0, + F::add(F::add(u0, F::mul(w_vec, u1)), F::add(p0, F::mul(w_vec, p1))), + ); + acc_b1 = F::add( + acc_b1, + F::add(F::sub(F::sub(m1, u0), u1), F::sub(F::sub(m2, p0), p1)), + ); + } + } + i += step; + } + + // Remaining SIMD vectors (one load_width at a time) + while i + load_width <= n { + unsafe { + let (fe0, fo0) = F::load_deinterleaved(f_c0.as_ptr().add(i)); + let (fe1, fo1) = F::load_deinterleaved(f_c1.as_ptr().add(i)); + let (ge0, go0) = F::load_deinterleaved(g_c0.as_ptr().add(i)); + let (ge1, go1) = F::load_deinterleaved(g_c1.as_ptr().add(i)); + + let v0 = F::mul(fe0, ge0); + let v1 = F::mul(fe1, ge1); + acc_a0 = F::add(acc_a0, F::add(v0, F::mul(w_vec, v1))); + let m = F::mul(F::add(fe0, fe1), F::add(ge0, ge1)); + acc_a1 = F::add(acc_a1, F::sub(F::sub(m, v0), v1)); + + let u0 = F::mul(fe0, go0); + let u1 = F::mul(fe1, go1); + let m1 = F::mul(F::add(fe0, fe1), F::add(go0, go1)); + let p0 = F::mul(fo0, ge0); + let p1 = F::mul(fo1, ge1); + let m2 = F::mul(F::add(fo0, fo1), F::add(ge0, ge1)); + + acc_b0 = F::add( + acc_b0, + F::add(F::add(u0, F::mul(w_vec, u1)), F::add(p0, F::mul(w_vec, p1))), + ); + acc_b1 = F::add( + acc_b1, + F::add(F::sub(F::sub(m1, u0), u1), F::sub(F::sub(m2, p0), p1)), + ); + } + i += load_width; + } + + // Horizontal reduce + let mut buf = [F::ZERO; 32]; + let mut a = [F::ZERO; 2]; + let mut b = [F::ZERO; 2]; + + unsafe { F::store(buf.as_mut_ptr(), acc_a0) }; + for &v in buf.iter().take(lanes) { + a[0] = F::scalar_add(a[0], v); + } + unsafe { F::store(buf.as_mut_ptr(), acc_a1) }; + for &v in buf.iter().take(lanes) { + a[1] = F::scalar_add(a[1], v); + } + unsafe { F::store(buf.as_mut_ptr(), acc_b0) }; + for &v in buf.iter().take(lanes) { + b[0] = F::scalar_add(b[0], v); + } + unsafe { F::store(buf.as_mut_ptr(), acc_b1) }; + for &v in buf.iter().take(lanes) { + b[1] = F::scalar_add(b[1], v); + } + + // Scalar tail + while i + 1 < n { + let fe = [f_c0[i], f_c1[i]]; + let fo = [f_c0[i + 1], f_c1[i + 1]]; + let ge = [g_c0[i], g_c1[i]]; + let go_ = [g_c0[i + 1], g_c1[i + 1]]; + + // a += fe * ge + let v0 = F::scalar_mul(fe[0], ge[0]); + let v1 = F::scalar_mul(fe[1], ge[1]); + a[0] = F::scalar_add(a[0], F::scalar_add(v0, F::scalar_mul(w, v1))); + let m = F::scalar_mul(F::scalar_add(fe[0], fe[1]), F::scalar_add(ge[0], ge[1])); + a[1] = F::scalar_add(a[1], F::scalar_sub(F::scalar_sub(m, v0), v1)); + + // b += fe * go + fo * ge + let u0 = F::scalar_mul(fe[0], go_[0]); + let u1 = F::scalar_mul(fe[1], go_[1]); + let m1 = F::scalar_mul(F::scalar_add(fe[0], fe[1]), F::scalar_add(go_[0], go_[1])); + let p0 = F::scalar_mul(fo[0], ge[0]); + let p1 = F::scalar_mul(fo[1], ge[1]); + let m2 = F::scalar_mul(F::scalar_add(fo[0], fo[1]), F::scalar_add(ge[0], ge[1])); + + b[0] = F::scalar_add( + b[0], + F::scalar_add( + F::scalar_add(u0, F::scalar_mul(w, u1)), + F::scalar_add(p0, F::scalar_mul(w, p1)), + ), + ); + b[1] = F::scalar_add( + b[1], + F::scalar_add( + F::scalar_sub(F::scalar_sub(m1, u0), u1), + F::scalar_sub(F::scalar_sub(m2, p0), p1), + ), + ); + i += 2; + } + + (a, b) +} + +/// SoA ext3 inner product evaluate. +/// +/// Given `f` and `g` as ext3 elements in SoA layout (f_c0, f_c1, f_c2, g_c0, g_c1, g_c2), +/// computes the degree-2 round polynomial coefficients `(a, b)`: +/// a = Σ f[2i] * g[2i] (ext3 products) +/// b = Σ (f[2i] * g[2i+1] + f[2i+1] * g[2i]) (ext3 cross-terms) +/// +/// Returns `(a_components, b_components)` as `[u64; 3]` raw Montgomery values. +pub fn ext3_soa_product_evaluate>( + f_c0: &[u64], + f_c1: &[u64], + f_c2: &[u64], + g_c0: &[u64], + g_c1: &[u64], + g_c2: &[u64], + w: u64, +) -> ([u64; 3], [u64; 3]) { + let n = f_c0.len(); + debug_assert_eq!(n, f_c1.len()); + debug_assert_eq!(n, f_c2.len()); + debug_assert_eq!(n, g_c0.len()); + debug_assert_eq!(n, g_c1.len()); + debug_assert_eq!(n, g_c2.len()); + + let lanes = F::LANES; + // Each load_deinterleaved consumes 2*lanes u64s (one load_width). 2× unroll. + let load_width = 2 * lanes; + let step = 2 * load_width; // 4 * lanes + let aligned = (n / step) * step; + let w_vec = F::splat(w); + + let zero = F::splat(F::ZERO); + // Accumulators for a (3 components) and b (3 components) + let mut acc_a = [zero; 3]; + let mut acc_b = [zero; 3]; + + let mut i = 0; + while i < aligned { + unsafe { + for u in 0..2 { + let off = i + u * load_width; + let (fe0, fo0) = F::load_deinterleaved(f_c0.as_ptr().add(off)); + let (fe1, fo1) = F::load_deinterleaved(f_c1.as_ptr().add(off)); + let (fe2, fo2) = F::load_deinterleaved(f_c2.as_ptr().add(off)); + let (ge0, go0) = F::load_deinterleaved(g_c0.as_ptr().add(off)); + let (ge1, go1) = F::load_deinterleaved(g_c1.as_ptr().add(off)); + let (ge2, go2) = F::load_deinterleaved(g_c2.as_ptr().add(off)); + + // a += f_even * g_even (ext3 Karatsuba) + let prod_a = soa_ext3_mul::([fe0, fe1, fe2], [ge0, ge1, ge2], w_vec); + acc_a[0] = F::add(acc_a[0], prod_a[0]); + acc_a[1] = F::add(acc_a[1], prod_a[1]); + acc_a[2] = F::add(acc_a[2], prod_a[2]); + + // b += f_even * g_odd + f_odd * g_even + let prod_eg = soa_ext3_mul::([fe0, fe1, fe2], [go0, go1, go2], w_vec); + let prod_oe = soa_ext3_mul::([fo0, fo1, fo2], [ge0, ge1, ge2], w_vec); + acc_b[0] = F::add(acc_b[0], F::add(prod_eg[0], prod_oe[0])); + acc_b[1] = F::add(acc_b[1], F::add(prod_eg[1], prod_oe[1])); + acc_b[2] = F::add(acc_b[2], F::add(prod_eg[2], prod_oe[2])); + } + } + i += step; + } + + // Remaining SIMD vectors (one load_width at a time) + while i + load_width <= n { + unsafe { + let (fe0, fo0) = F::load_deinterleaved(f_c0.as_ptr().add(i)); + let (fe1, fo1) = F::load_deinterleaved(f_c1.as_ptr().add(i)); + let (fe2, fo2) = F::load_deinterleaved(f_c2.as_ptr().add(i)); + let (ge0, go0) = F::load_deinterleaved(g_c0.as_ptr().add(i)); + let (ge1, go1) = F::load_deinterleaved(g_c1.as_ptr().add(i)); + let (ge2, go2) = F::load_deinterleaved(g_c2.as_ptr().add(i)); + + let prod_a = soa_ext3_mul::([fe0, fe1, fe2], [ge0, ge1, ge2], w_vec); + acc_a[0] = F::add(acc_a[0], prod_a[0]); + acc_a[1] = F::add(acc_a[1], prod_a[1]); + acc_a[2] = F::add(acc_a[2], prod_a[2]); + + let prod_eg = soa_ext3_mul::([fe0, fe1, fe2], [go0, go1, go2], w_vec); + let prod_oe = soa_ext3_mul::([fo0, fo1, fo2], [ge0, ge1, ge2], w_vec); + acc_b[0] = F::add(acc_b[0], F::add(prod_eg[0], prod_oe[0])); + acc_b[1] = F::add(acc_b[1], F::add(prod_eg[1], prod_oe[1])); + acc_b[2] = F::add(acc_b[2], F::add(prod_eg[2], prod_oe[2])); + } + i += load_width; + } + + // Horizontal reduce + let mut buf = [F::ZERO; 32]; + let mut a = [F::ZERO; 3]; + let mut b = [F::ZERO; 3]; + + for c in 0..3 { + unsafe { F::store(buf.as_mut_ptr(), acc_a[c]) }; + for &v in buf.iter().take(lanes) { + a[c] = F::scalar_add(a[c], v); + } + unsafe { F::store(buf.as_mut_ptr(), acc_b[c]) }; + for &v in buf.iter().take(lanes) { + b[c] = F::scalar_add(b[c], v); + } + } + + // Scalar tail + while i + 1 < n { + let fe = [f_c0[i], f_c1[i], f_c2[i]]; + let fo = [f_c0[i + 1], f_c1[i + 1], f_c2[i + 1]]; + let ge = [g_c0[i], g_c1[i], g_c2[i]]; + let go_ = [g_c0[i + 1], g_c1[i + 1], g_c2[i + 1]]; + + let pa = scalar_ext3_mul::(fe, ge, w); + for c in 0..3 { + a[c] = F::scalar_add(a[c], pa[c]); + } + + let peg = scalar_ext3_mul::(fe, go_, w); + let poe = scalar_ext3_mul::(fo, ge, w); + for c in 0..3 { + b[c] = F::scalar_add(b[c], F::scalar_add(peg[c], poe[c])); + } + + i += 2; + } + + (a, b) +} + +/// Ext3 Karatsuba multiply for SIMD vectors in SoA layout. +/// 6 base muls + 2 w-muls + adds. +#[inline(always)] +fn soa_ext3_mul>( + a: [F::Packed; 3], + b: [F::Packed; 3], + w: F::Packed, +) -> [F::Packed; 3] { + let ad = F::mul(a[0], b[0]); + let be = F::mul(a[1], b[1]); + let cf = F::mul(a[2], b[2]); + + let x = F::sub( + F::sub(F::mul(F::add(a[1], a[2]), F::add(b[1], b[2])), be), + cf, + ); + let y = F::sub( + F::sub(F::mul(F::add(a[0], a[1]), F::add(b[0], b[1])), ad), + be, + ); + let z = F::add( + F::sub( + F::sub(F::mul(F::add(a[0], a[2]), F::add(b[0], b[2])), ad), + cf, + ), + be, + ); + + [F::add(ad, F::mul(w, x)), F::add(y, F::mul(w, cf)), z] +} + +/// Scalar ext3 Karatsuba multiply helper. +#[inline(always)] +fn scalar_ext3_mul>(a: [u64; 3], b: [u64; 3], w: u64) -> [u64; 3] { + let ad = F::scalar_mul(a[0], b[0]); + let be = F::scalar_mul(a[1], b[1]); + let cf = F::scalar_mul(a[2], b[2]); + + let x = F::scalar_sub( + F::scalar_sub( + F::scalar_mul(F::scalar_add(a[1], a[2]), F::scalar_add(b[1], b[2])), + be, + ), + cf, + ); + let y = F::scalar_sub( + F::scalar_sub( + F::scalar_mul(F::scalar_add(a[0], a[1]), F::scalar_add(b[0], b[1])), + ad, + ), + be, + ); + let z = F::scalar_add( + F::scalar_sub( + F::scalar_sub( + F::scalar_mul(F::scalar_add(a[0], a[2]), F::scalar_add(b[0], b[2])), + ad, + ), + cf, + ), + be, + ); + + [ + F::scalar_add(ad, F::scalar_mul(w, x)), + F::scalar_add(y, F::scalar_mul(w, cf)), + z, + ] +} + +#[cfg(test)] +#[cfg(any( + target_arch = "aarch64", + all(target_arch = "x86_64", target_feature = "avx512ifma") +))] +mod tests { + use super::*; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512ifma"))] + use crate::simd_fields::goldilocks::avx512::GoldilocksAvx512 as Backend; + #[cfg(target_arch = "aarch64")] + use crate::simd_fields::goldilocks::neon::GoldilocksNeon as Backend; + use crate::tests::F64; + use ark_ff::UniformRand; + use ark_std::test_rng; + + #[test] + fn test_reduce_and_evaluate_matches() { + use crate::reductions::pairwise; + + let mut rng = test_rng(); + let n = 1 << 16; + let evals_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let mut evals_raw: Vec = evals_ff.iter().map(|f| (*f).value).collect(); + + let challenge_ff = F64::rand(&mut rng); + let challenge_raw = challenge_ff.value; + + // Reference: reduce then evaluate + let mut expected_ff = evals_ff; + pairwise::reduce_evaluations(&mut expected_ff, challenge_ff); + let (expected_even, expected_odd) = pairwise::evaluate(&expected_ff); + + // Fused + let (fused_even, fused_odd, new_len) = + reduce_and_evaluate::(&mut evals_raw, challenge_raw); + + assert_eq!(new_len, n / 2); + assert_eq!(expected_even.value, fused_even, "fused even mismatch"); + assert_eq!(expected_odd.value, fused_odd, "fused odd mismatch"); + + // Also verify the reduce output matches + for i in 0..new_len { + assert_eq!( + expected_ff[i].value, evals_raw[i], + "reduce mismatch at index {}", + i + ); + } + } + + #[test] + fn test_reduce_and_evaluate_large() { + use crate::reductions::pairwise; + + let mut rng = test_rng(); + let n = 1 << 20; + let evals_ff: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let mut evals_raw: Vec = evals_ff.iter().map(|f| (*f).value).collect(); + + let challenge_ff = F64::rand(&mut rng); + let challenge_raw = challenge_ff.value; + + let mut expected_ff = evals_ff; + pairwise::reduce_evaluations(&mut expected_ff, challenge_ff); + let (expected_even, expected_odd) = pairwise::evaluate(&expected_ff); + + let (fused_even, fused_odd, _) = + reduce_and_evaluate::(&mut evals_raw, challenge_raw); + + assert_eq!(expected_even.value, fused_even, "large fused even mismatch"); + assert_eq!(expected_odd.value, fused_odd, "large fused odd mismatch"); + } +} diff --git a/src/streams/memory/core.rs b/src/streams/memory/core.rs index 19afad66..455a6036 100644 --- a/src/streams/memory/core.rs +++ b/src/streams/memory/core.rs @@ -1,39 +1,60 @@ -use crate::{order_strategy::OrderStrategy, streams::Stream}; +use crate::streams::Stream; use ark_ff::Field; -/* - * It's totally reasonable to use this when the evaluations table - * fits in memory (and yes, it's not so much a stream in this case) - */ +#[cfg(feature = "parallel")] +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +/// In-memory evaluation stream. #[derive(Debug, Clone)] pub struct MemoryStream { pub evaluations: Vec, } -pub fn reorder_vec(evaluations: Vec) -> Vec { - // abort if length not a power of two +/// Bit-reversal permutation: reorders evaluations from ascending (lexicographic) +/// to MSB (half-split) layout. +/// +/// `out[i] = src[bit_reverse(i, num_vars)]`. +/// +/// Uses `usize::reverse_bits` (hardware instruction on most targets). +/// Parallel via rayon above the threshold. +pub fn reorder_vec_msb(evaluations: Vec) -> Vec { assert!(!evaluations.is_empty() && evaluations.len().count_ones() == 1); let num_vars = evaluations.len().trailing_zeros() as usize; - let mut order = O::new(num_vars); - let mut evaluations_ordered = Vec::with_capacity(evaluations.len()); - for index in &mut order { - evaluations_ordered.push(evaluations[index]); + bit_reverse_reorder(evaluations, num_vars) +} + +const BIT_REVERSE_PARALLEL_THRESHOLD: usize = 1 << 17; + +#[inline] +fn bit_reverse_reorder(src: Vec, num_vars: usize) -> Vec { + let n = src.len(); + if num_vars == 0 { + return src; } - evaluations_ordered + let shift = usize::BITS - num_vars as u32; + + #[cfg(feature = "parallel")] + { + if n > BIT_REVERSE_PARALLEL_THRESHOLD { + return (0..n) + .into_par_iter() + .map(|i| src[i.reverse_bits() >> shift]) + .collect(); + } + } + + (0..n).map(|i| src[i.reverse_bits() >> shift]).collect() } impl MemoryStream { pub fn new(evaluations: Vec) -> Self { - // abort if length not a power of two assert!(!evaluations.is_empty() && evaluations.len().count_ones() == 1); - // return the MemoryStream instance Self { evaluations } } - pub fn new_from_lex(evaluations: Vec) -> Self { - // abort if length not a power of two - assert!(!evaluations.is_empty() && evaluations.len().count_ones() == 1); - Self::new(reorder_vec::(evaluations)) + + /// Construct from ascending (lex) order evaluations, reordering to MSB. + pub fn new_from_lex_msb(evaluations: Vec) -> Self { + Self::new(reorder_vec_msb(evaluations)) } } @@ -45,3 +66,31 @@ impl Stream for MemoryStream { self.evaluations.len().ilog2() as usize } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::F64; + use ark_ff::UniformRand; + use ark_std::test_rng; + + #[test] + fn msb_reorder_roundtrip() { + let mut rng = test_rng(); + for num_vars in [1usize, 2, 4, 8, 12] { + let n = 1usize << num_vars; + let input: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + // Double bit-reverse should be identity. + let once = reorder_vec_msb(input.clone()); + let twice = reorder_vec_msb(once); + assert_eq!(twice, input, "mismatch at num_vars={}", num_vars); + } + } + + #[test] + fn msb_num_vars_zero_edge_case() { + let input = vec![F64::from(42u64)]; + let got = reorder_vec_msb(input.clone()); + assert_eq!(got, input); + } +} diff --git a/src/streams/memory/mod.rs b/src/streams/memory/mod.rs index 67d38b97..cd230db2 100644 --- a/src/streams/memory/mod.rs +++ b/src/streams/memory/mod.rs @@ -1,3 +1,3 @@ mod core; -pub use core::{reorder_vec, MemoryStream}; +pub use core::{reorder_vec_msb, MemoryStream}; diff --git a/src/streams/mod.rs b/src/streams/mod.rs index 386204f2..475e335a 100644 --- a/src/streams/mod.rs +++ b/src/streams/mod.rs @@ -1,9 +1,7 @@ mod file; mod memory; mod stream; -mod stream_iterator; pub use file::FileStream; -pub use memory::{reorder_vec, MemoryStream}; +pub use memory::{reorder_vec_msb, MemoryStream}; pub use stream::{multivariate_claim, multivariate_product_claim, Stream}; -pub use stream_iterator::StreamIterator; diff --git a/src/streams/stream_iterator.rs b/src/streams/stream_iterator.rs deleted file mode 100644 index e819aeff..00000000 --- a/src/streams/stream_iterator.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::marker::PhantomData; - -use crate::{order_strategy::OrderStrategy, streams::Stream}; -use ark_ff::Field; - -pub struct StreamIterator, O: OrderStrategy> { - stream: S, - order: O, - _marker: PhantomData, -} - -impl, O: OrderStrategy> StreamIterator { - pub fn new(stream: S) -> Self { - let order = O::new(stream.num_variables()); - Self { - stream, - order, - _marker: PhantomData, - } - } - pub fn reset(&mut self) { - self.order = O::new(self.stream.num_variables()); - } -} - -impl, O: OrderStrategy> Iterator for StreamIterator { - type Item = F; - - fn next(&mut self) -> Option { - match self.order.next_index() { - Some(index) => Some(self.stream.evaluation(index)), - None => None, - } - } -} diff --git a/src/sumcheck_prover.rs b/src/sumcheck_prover.rs new file mode 100644 index 00000000..f4cc5a19 --- /dev/null +++ b/src/sumcheck_prover.rs @@ -0,0 +1,59 @@ +//! The [`SumcheckProver`] trait — the extension point for all sumcheck +//! prover strategies and polynomial shapes. +//! +//! Implementors define *how* the round polynomial is computed from the +//! prover's internal state. The protocol runner ([`crate::runner::sumcheck`]) +//! calls [`round()`](SumcheckProver::round) once per round; the caller +//! retains ownership of the prover and can inspect post-state after sumcheck +//! completes. + +extern crate alloc; +use crate::field::SumcheckField; +use alloc::vec::Vec; + +/// Prover side of the sum-check protocol (Thaler Proposition 4.1). +/// +/// # Lifecycle +/// +/// ```text +/// evals_0 = round(None) // round 0: compute g_0 from initial state +/// evals_1 = round(Some(r_0)) // round 1: fold with r_0, compute g_1 +/// ... +/// evals_{v-1} = round(Some(r_{v-2})) // round v-1: fold with r_{v-2}, compute g_{v-1} +/// finalize(r_{v-1}) // apply the last challenge +/// value = final_value() // g(r_0, ..., r_{v-1}) +/// ``` +/// +/// # Post-state +/// +/// Since the prover is passed as `&mut P`, the caller retains ownership +/// after sumcheck completes and can query prover-specific accessors: +/// +/// ```ignore +/// let proof = sumcheck(&mut prover, n, &mut t, |_, _| {}); +/// let (f_eval, g_eval) = prover.final_evaluations(); // prover-specific +/// ``` +pub trait SumcheckProver { + /// Maximum degree of the round polynomial in the current variable. + fn degree(&self) -> usize; + + /// Compute the round polynomial and advance state. + /// + /// Returns evaluations of g_j at `{0, 1, ..., degree()}`. + /// + /// - `challenge = None`: round 0 — compute from initial state. + /// - `challenge = Some(r)`: fold/update state with the previous round's + /// challenge `r`, then compute the next round polynomial. + fn round(&mut self, challenge: Option) -> Vec; + + /// Apply the final verifier challenge. + /// + /// Called once after the last round, before [`final_value()`](Self::final_value). + /// The prover folds its internal state with `last_challenge` so that + /// `final_value()` can return `g(r_1, ..., r_v)`. + fn finalize(&mut self, last_challenge: F); + + /// After all rounds and [`finalize()`](Self::finalize): the claimed + /// value `g(r_1, ..., r_v)`. + fn final_value(&self) -> F; +} diff --git a/src/tests/fields.rs b/src/tests/fields.rs index 503026c6..f9fb61f7 100644 --- a/src/tests/fields.rs +++ b/src/tests/fields.rs @@ -1,3 +1,6 @@ +use ark_ff::define_field; +use ark_ff::fields::models::cubic_extension::{CubicExtConfig, CubicExtField}; +use ark_ff::fields::models::quadratic_extension::{QuadExtConfig, QuadExtField}; use ark_ff::fields::{Fp128, Fp64, MontBackend, MontConfig}; #[derive(MontConfig)] @@ -18,14 +21,86 @@ pub type M31 = Fp64>; pub struct BabyBearConfig; pub type BabyBear = Fp64>; +// Goldilocks: q = 2^64 - 2^32 + 1 +// Primary type: SmallFp (optimal single-u64 Montgomery representation). +define_field!( + modulus = "18446744069414584321", + generator = "7", + name = F64, +); + +// Secondary type: Fp64 (for compatibility with code using MontConfig). +// Both F64 and FpF64 store a single u64 in Montgomery form — the SIMD backend +// works identically for either. #[derive(MontConfig)] -#[modulus = "18446744069414584321"] // q = 2^64 - 2^32 + 1 -#[generator = "2"] -pub struct F64Config; -pub type F64 = Fp64>; +#[modulus = "18446744069414584321"] +#[generator = "7"] +pub struct FpF64Config; +pub type FpF64 = Fp64>; + +// Degree-2 extension of Goldilocks: F64[X] / (X² - 7) +// NONRESIDUE = 7 (must be a non-square in F64). +pub struct F64Ext2Config; +impl QuadExtConfig for F64Ext2Config { + type BasePrimeField = F64; + type BaseField = F64; + type FrobCoeff = F64; + const DEGREE_OVER_BASE_PRIME_FIELD: usize = 2; + const NONRESIDUE: F64 = F64::from_raw(7); + // Frobenius coefficient: NONRESIDUE^((p-1)/2). For testing, -1 works + // for any non-square nonresidue (Euler criterion). + // Frobenius coefficients: [1, -1]. + // -1 mod P = P - 1 = 0xFFFF_FFFF_0000_0000 in Montgomery form. + // Actually, -1 in Montgomery form is mont(-1) = mont(P-1) = (P-1)*R mod P. + // For Goldilocks, R mod P = EPSILON = 0xFFFFFFFF. + // mont(P-1) = (P-1) * R mod P. Let's just use from_raw(P - 1)... no. + // from_raw takes a value already in Montgomery form. + // -1 in Montgomery form = R * (P-1) mod P = (-R) mod P = P - EPSILON = P - (2^32-1) + // = 0xFFFF_FFFF_0000_0001 - 0xFFFF_FFFF = 0xFFFF_FFFE_0000_0002 + // Actually easier: just use the constant P - EPSILON. + const FROBENIUS_COEFF_C1: &'static [F64] = &[ + F64::from_raw(0xFFFF_FFFF), // mont(1) = R mod P = EPSILON + F64::from_raw(0xFFFF_FFFE_0000_0002), // mont(-1) = P - EPSILON + ]; + + fn mul_base_field_by_frob_coeff(fe: &mut Self::BaseField, power: usize) { + *fe *= &Self::FROBENIUS_COEFF_C1[power % 2]; + } +} +pub type F64Ext2 = QuadExtField; + +// Degree-3 extension of Goldilocks: F64[X] / (X³ - 7) +pub struct F64Ext3Config; +impl CubicExtConfig for F64Ext3Config { + type BasePrimeField = F64; + type BaseField = F64; + type FrobCoeff = F64; + const SQRT_PRECOMP: Option>> = None; + const DEGREE_OVER_BASE_PRIME_FIELD: usize = 3; + const NONRESIDUE: F64 = F64::from_raw(7); + // Frobenius coefficients for cubic extension. + // FROBENIUS_COEFF_C1[i] = NONRESIDUE^((p^i - 1) / 3) + // FROBENIUS_COEFF_C2[i] = NONRESIDUE^((2*(p^i - 1)) / 3) + // For testing purposes, we use identity (power 0) and compute the rest. + // Since p ≡ 1 mod 3 for Goldilocks, these exist. + // For simplicity, use [1, w^((p-1)/3), w^(2(p-1)/3)] but computing these + // requires modular exponentiation. For test-only usage, just provide placeholders + // that satisfy the trait — the sumcheck doesn't use Frobenius. + const FROBENIUS_COEFF_C1: &'static [F64] = &[F64::from_raw(0xFFFF_FFFF)]; // [1] + const FROBENIUS_COEFF_C2: &'static [F64] = &[F64::from_raw(0xFFFF_FFFF)]; // [1] + + fn mul_base_field_by_frob_coeff( + _c1: &mut Self::BaseField, + _c2: &mut Self::BaseField, + _power: usize, + ) { + // Frobenius not used in sumcheck — no-op for testing + } +} +pub type F64Ext3 = CubicExtField; #[derive(MontConfig)] -#[modulus = "143244528689204659050391023439224324689"] // q = 143244528689204659050391023439224324689 +#[modulus = "143244528689204659050391023439224324689"] #[generator = "2"] pub struct F128Config; pub type F128 = Fp128>; diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 2a0907d0..dd49bc7d 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,8 +1,7 @@ +#[allow(clippy::assign_op_pattern)] mod fields; mod streams; -pub mod multilinear; -pub mod multilinear_product; pub mod polynomials; -pub use fields::{BabyBear, F128, F19, F64, M31}; +pub use fields::{BabyBear, F64Ext2, F64Ext3, FpF64, F128, F19, F64, M31}; pub use streams::BenchStream; diff --git a/src/tests/multilinear/mod.rs b/src/tests/multilinear/mod.rs deleted file mode 100644 index ff53180b..00000000 --- a/src/tests/multilinear/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod provers; -mod sanity; - -pub use provers::{BasicProver, BasicProverConfig}; -pub use sanity::{pairwise_sanity_test, sanity_test, sanity_test_driver}; diff --git a/src/tests/multilinear/provers/basic/config.rs b/src/tests/multilinear/provers/basic/config.rs deleted file mode 100644 index 72ea4520..00000000 --- a/src/tests/multilinear/provers/basic/config.rs +++ /dev/null @@ -1,17 +0,0 @@ -use ark_ff::Field; -use ark_poly::multivariate::{SparsePolynomial, SparseTerm}; -pub struct BasicProverConfig { - pub claim: F, - pub num_variables: usize, - pub p: SparsePolynomial, -} - -impl BasicProverConfig { - pub fn new(claim: F, num_variables: usize, p: SparsePolynomial) -> Self { - Self { - claim, - num_variables, - p, - } - } -} diff --git a/src/tests/multilinear/provers/basic/core.rs b/src/tests/multilinear/provers/basic/core.rs deleted file mode 100644 index 7bd24067..00000000 --- a/src/tests/multilinear/provers/basic/core.rs +++ /dev/null @@ -1,64 +0,0 @@ -use ark_ff::Field; -use ark_poly::multivariate::{SparsePolynomial, SparseTerm}; - -use crate::{ - hypercube::Hypercube, messages::VerifierMessages, order_strategy::AscendingOrder, - tests::polynomials::Polynomial, -}; -pub struct BasicProver { - pub current_round: usize, - pub num_variables: usize, - pub p: SparsePolynomial, - pub verifier_messages: VerifierMessages, -} - -impl BasicProver { - pub fn compute_round(&self) -> (F, F) { - let mut m: (F, F) = (F::ZERO, F::ZERO); - for (_, b) in Hypercube::::new(self.num_variables - self.current_round - 1) - { - let partial_point: Vec = b - .to_vec_bool() - .into_iter() - .map(|bit: bool| -> F { - if bit { - F::ONE - } else { - F::ZERO - } - }) - .collect(); - let partial_point_zero: Vec = std::iter::once(F::ZERO) - .chain(partial_point.iter().cloned()) - .collect(); - let partial_point_one: Vec = std::iter::once(F::ONE) - .chain(partial_point.iter().cloned()) - .collect(); - let point_zero: Vec = self - .verifier_messages - .messages - .iter() - .cloned() - .chain(partial_point_zero.iter().cloned()) - .collect(); - let point_one: Vec = self - .verifier_messages - .messages - .iter() - .cloned() - .chain(partial_point_one.iter().cloned()) - .collect(); - let p_zero = self.p.evaluate(point_zero.clone()).unwrap(); - let p_one = self.p.evaluate(point_one.clone()).unwrap(); - m.0 += p_zero; - m.1 += p_one; - } - m - } - pub fn is_initial_round(&self) -> bool { - self.current_round == 0 - } - pub fn total_rounds(&self) -> usize { - self.num_variables - } -} diff --git a/src/tests/multilinear/provers/basic/mod.rs b/src/tests/multilinear/provers/basic/mod.rs deleted file mode 100644 index 49374819..00000000 --- a/src/tests/multilinear/provers/basic/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod config; -mod core; -mod prover; - -pub use config::BasicProverConfig; -pub use core::BasicProver; diff --git a/src/tests/multilinear/provers/basic/prover.rs b/src/tests/multilinear/provers/basic/prover.rs deleted file mode 100644 index a5ceea2a..00000000 --- a/src/tests/multilinear/provers/basic/prover.rs +++ /dev/null @@ -1,138 +0,0 @@ -use ark_ff::Field; - -use crate::{ - messages::VerifierMessages, - prover::Prover, - tests::multilinear::{BasicProver, BasicProverConfig}, -}; - -impl Prover for BasicProver { - type ProverConfig = BasicProverConfig; - type ProverMessage = Option<(F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - Self { - current_round: 0, - num_variables: prover_config.num_variables, - p: prover_config.p, - verifier_messages: VerifierMessages::new(&vec![]), - } - } - - fn next_message(&mut self, verifier_message: Self::VerifierMessage) -> Self::ProverMessage { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - if !self.is_initial_round() { - self.verifier_messages - .receive_message(verifier_message.unwrap()); - } - - let sums: (F, F) = self.compute_round(); - - // Increment the round counter - self.current_round += 1; - - // Return the computed polynomial sums - Some(sums) - } -} - -#[cfg(test)] -mod tests { - use crate::prover::Prover; - use crate::tests::polynomials::Polynomial; - use crate::Sumcheck; - use crate::{ - streams::{multivariate_claim, MemoryStream}, - tests::{ - multilinear::{sanity_test_driver, BasicProver, BasicProverConfig}, - polynomials::{three_variable_polynomial, three_variable_polynomial_evaluations}, - F19, - }, - }; - use ark_poly::{ - multivariate::{self, SparsePolynomial, SparseTerm, Term}, - DenseMVPolynomial, - }; - - #[test] - fn sanity() { - let p: SparsePolynomial = three_variable_polynomial(); - let s = MemoryStream::::new(three_variable_polynomial_evaluations()); - let config = BasicProverConfig::new(multivariate_claim(s.clone()), 3, p); - let mut prover = BasicProver::new(config); - sanity_test_driver(&mut prover); - } - - #[test] - fn sumcheck_1_variable() { - // 3 * x_0 + 1 - let x_squared_plus_1 = multivariate::SparsePolynomial::from_coefficients_slice( - 1, - &[ - (F19::from(3u32), multivariate::SparseTerm::new(vec![(0, 1)])), - (F19::from(1u32), multivariate::SparseTerm::new(vec![])), - ], - ); - let s = MemoryStream::::new(x_squared_plus_1.to_evaluations()); - let config = BasicProverConfig::new(multivariate_claim(s.clone()), 1, x_squared_plus_1); - let mut prover = BasicProver::new(config); - let _transcript = Sumcheck::::prove::, BasicProver>( - &mut prover, - &mut ark_std::test_rng(), - ); - // println!("transcript: {:?}", transcript); - // round 0 - // point: [0] -> 1 - // point: [1] -> 4 - // g0 = 3*x + 1 (it's the original polynomial so this is not useful for anything) - // Sumcheck { prover_messages: [(1, 4)], verifier_messages: [], is_accepted: true } - } - - #[test] - fn sumcheck_2_variables() { - // 3*x_0*x_1 + 5*x_0 + 1 - let x_zero_squared_x_one_plus_3_x_one_plus_1 = - multivariate::SparsePolynomial::from_coefficients_slice( - 2, - &[ - ( - F19::from(3u32), - multivariate::SparseTerm::new(vec![(0, 1), (1, 1)]), - ), - (F19::from(5u32), multivariate::SparseTerm::new(vec![(0, 1)])), - (F19::from(1u32), multivariate::SparseTerm::new(vec![])), - ], - ); - let s = MemoryStream::::new(x_zero_squared_x_one_plus_3_x_one_plus_1.to_evaluations()); - let config = BasicProverConfig::new( - multivariate_claim(s.clone()), - 2, - x_zero_squared_x_one_plus_3_x_one_plus_1, - ); - let mut prover = BasicProver::new(config); - let _transcript = Sumcheck::::prove::, BasicProver>( - &mut prover, - &mut ark_std::test_rng(), - ); - // println!("transcript: {:?}", transcript); - - // round 0 - // point: [0, 0] -> 1 - // point: [1, 0] -> 6 - // point: [0, 1] -> 1 - // point: [1, 1] -> 9 - // g0: 13*x + 2 - - // round 1 - // point: [2, 0] -> 11 - // point: [2, 1] -> 17 - // g1: 6*x + 11 - - // Sumcheck { prover_messages: [(2, 15), (11, 17)], verifier_messages: [2], is_accepted: true } - } -} diff --git a/src/tests/multilinear/provers/mod.rs b/src/tests/multilinear/provers/mod.rs deleted file mode 100644 index b60ac1a3..00000000 --- a/src/tests/multilinear/provers/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod basic; - -pub use basic::{BasicProver, BasicProverConfig}; diff --git a/src/tests/multilinear/sanity.rs b/src/tests/multilinear/sanity.rs deleted file mode 100644 index b073935b..00000000 --- a/src/tests/multilinear/sanity.rs +++ /dev/null @@ -1,196 +0,0 @@ -use ark_ff::Field; - -use crate::{ - prover::{Prover, ProverConfig}, - streams::{MemoryStream, Stream}, - tests::polynomials::three_variable_polynomial_evaluations, -}; - -pub fn multilinear_round_sanity(p: &mut P, message: Option, eval_0: F, eval_1: F) -where - F: Field, - P: Prover, ProverMessage = Option<(F, F)>>, -{ - let round = p.next_message(message).unwrap(); - assert_eq!(round.0, eval_0, "g0 should evaluate correctly",); - assert_eq!(round.1, eval_1, "g1 should evaluate correctly",); -} - -pub fn sanity_test() -where - F: Field, - S: Stream + From>, - P: Prover, ProverMessage = Option<(F, F)>>, - P::ProverConfig: ProverConfig, -{ - let s: S = MemoryStream::new(three_variable_polynomial_evaluations()).into(); - let mut p = P::new(ProverConfig::default(3, s)); - /* - * Zeroth Round: All variables are free - * - * Evaluations at different input points: - * (0,0,0) → 0 - * (0,0,1) → 0 - * (0,1,0) → 13 - * (0,1,1) → 1 - * ----------------- - * Sum g₀(0) ≡ 14 - * - * (1,0,0) → 2 - * (1,0,1) → 2 - * (1,1,0) → 0 - * (1,1,1) → 7 - * ----------------- - * Sum g₀(1) ≡ 11 - */ - multilinear_round_sanity::(&mut p, None, F::from(14_u32), F::from(11_u32)); - /* - * First Round: x₀ fixed to 3 - * - * Evaluations at different input points: - * (3,0,1) → 6 - * (3,0,0) → 6 - * ----------------- - * Sum g₁(0) ≡ 12 - * - * (3,1,1) → 38 ≡ 0 (mod 19) - * (3,1,0) → 31 ≡ 12 (mod 19) - * ----------------- - * Sum g₁(1) ≡ 12 - */ - multilinear_round_sanity::( - &mut p, - Some(F::from(3_u32)), - F::from(12_u32), - F::from(12_u32), - ); - /* - * Last Round: x₁ fixed to 4 - * - * Evaluations at different input points: - * (3,4,0) → 108 ≡ 11 (mod 19) - * ----------------- - * Sum g(0) ≡ 11 - * - * (3,4,1) → 134 ≡ 1 (mod 19) - * ----------------- - * Sum g(1) ≡ 1 - */ - multilinear_round_sanity::( - &mut p, - Some(F::from(4_u32)), - F::from(11_u32), - F::from(1_u32), - ); -} - -pub fn pairwise_sanity_test() -where - F: Field, - S: Stream + From>, - P: Prover, ProverMessage = Option<(F, F)>>, - P::ProverConfig: ProverConfig, -{ - let s: S = MemoryStream::new(three_variable_polynomial_evaluations()).into(); - let mut p = P::new(ProverConfig::default(3, s)); - /* - * Zeroth Round: All variables are free - * - * Evaluations at different input points: - * (0,0,0) → 0 - * (0,0,1) → 0 - * (0,1,0) → 13 - * (0,1,1) → 1 - * (1,0,0) → 2 - * (1,0,1) → 2 - * (1,1,0) → 0 - * (1,1,1) → 7 - * - * Sum evens: g₀(0) ≡ 15 - * Sum odds: g₀(1) ≡ 10 - */ - multilinear_round_sanity::(&mut p, None, F::from(15_u32), F::from(10_u32)); - /* - * First Round: x₀ fixed to 3 - * - * Evaluations at different input points (adjacent points compressed): - * (3,0,0) → 0 - * (3,0,1) → 15 - * (3,1,0) → 2 - * (3,1,1) → 2 - * - * Sum evens: g₀(0) ≡ 2 - * Sum odds: g₀(1) ≡ 17 - */ - multilinear_round_sanity::( - &mut p, - Some(F::from(3_u32)), - F::from(2_u32), - F::from(17_u32), - ); - /* - * Last Round: x₁ fixed to 4 - * - * Evaluations at different input points: - * (3,4,0) → 3 - * (3,4,1) → 2 - * - * Sum evens: g₀(0) ≡ 3 - * Sum odds: g₀(1) ≡ 2 - */ - multilinear_round_sanity::(&mut p, Some(F::from(4_u32)), F::from(3_u32), F::from(2_u32)); -} - -pub fn sanity_test_driver(p: &mut P) -where - F: Field, - P: Prover, ProverMessage = Option<(F, F)>>, -{ - /* - * Zeroth Round: All variables are free - * - * Evaluations at different input points: - * (0,0,0) → 0 - * (0,0,1) → 0 - * (0,1,0) → 13 - * (0,1,1) → 1 - * ----------------- - * Sum g₀(0) ≡ 14 - * - * (1,0,0) → 2 - * (1,0,1) → 2 - * (1,1,0) → 0 - * (1,1,1) → 7 - * ----------------- - * Sum g₀(1) ≡ 11 - */ - multilinear_round_sanity::(p, None, F::from(14_u32), F::from(11_u32)); - /* - * First Round: x₀ fixed to 3 - * - * Evaluations at different input points: - * (3,0,1) → 6 - * (3,0,0) → 6 - * ----------------- - * Sum g₁(0) ≡ 12 - * - * (3,1,1) → 38 ≡ 0 (mod 19) - * (3,1,0) → 31 ≡ 12 (mod 19) - * ----------------- - * Sum g₁(1) ≡ 12 - */ - multilinear_round_sanity::(p, Some(F::from(3_u32)), F::from(12_u32), F::from(12_u32)); - /* - * Last Round: x₁ fixed to 4 - * - * Evaluations at different input points: - * (3,4,0) → 108 ≡ 11 (mod 19) - * ----------------- - * Sum g(0) ≡ 11 - * - * (3,4,1) → 134 ≡ 1 (mod 19) - * ----------------- - * Sum g(1) ≡ 1 - */ - multilinear_round_sanity::(p, Some(F::from(4_u32)), F::from(11_u32), F::from(1_u32)); -} diff --git a/src/tests/multilinear_product/consistency.rs b/src/tests/multilinear_product/consistency.rs deleted file mode 100644 index 52900f02..00000000 --- a/src/tests/multilinear_product/consistency.rs +++ /dev/null @@ -1,56 +0,0 @@ -use ark_ff::Field; -use ark_poly::multivariate::{SparsePolynomial, SparseTerm}; - -use crate::{ - prover::{ProductProverConfig, Prover}, - streams::{multivariate_product_claim, Stream}, - tests::{ - multilinear_product::provers::basic::{BasicProductProver, BasicProductProverConfig}, - polynomials::Polynomial, - BenchStream, - }, - ProductSumcheck, -}; - -pub fn consistency_test() -where - F: Field, - S: Stream + From> + Clone, - P: Prover, ProverMessage = Option<(F, F, F)>>, - P::ProverConfig: ProductProverConfig, -{ - // get a stream - let num_variables = 16; - let s: S = BenchStream::new(num_variables).into(); - let claim = multivariate_product_claim(vec![s.clone(), s.clone()]); - - // get the sanity prover - let s_evaluations: Vec = (0..1 << num_variables).map(|i| s.evaluation(i)).collect(); - let p: SparsePolynomial = - as Polynomial>::from_hypercube_evaluations( - s_evaluations.clone(), - ); - let mut sanity_prover = BasicProductProver::::new(BasicProductProverConfig::new( - claim, - num_variables, - p.clone(), - p, - )); - - // prove - let prover_transcript = ProductSumcheck::::prove::, P>( - &mut P::new(ProductProverConfig::default( - num_variables, - vec![s.clone(), s], - )), - &mut ark_std::test_rng(), - ); - - let sanity_prover_transcript = ProductSumcheck::::prove::< - BenchStream, - BasicProductProver, - >(&mut sanity_prover, &mut ark_std::test_rng()); - - // ensure the transcript is identical - assert_eq!(prover_transcript, sanity_prover_transcript); -} diff --git a/src/tests/multilinear_product/mod.rs b/src/tests/multilinear_product/mod.rs deleted file mode 100644 index d1aa77d1..00000000 --- a/src/tests/multilinear_product/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod consistency; -mod provers; -mod sanity; - -pub use consistency::consistency_test; -pub use provers::basic::{ - BasicProductProver, BasicProductProverConfig, ProductProverPolynomialConfig, -}; -pub use sanity::{sanity_test, sanity_test_driver}; diff --git a/src/tests/multilinear_product/provers/basic/config.rs b/src/tests/multilinear_product/provers/basic/config.rs deleted file mode 100644 index 30ec47cd..00000000 --- a/src/tests/multilinear_product/provers/basic/config.rs +++ /dev/null @@ -1,50 +0,0 @@ -use ark_ff::Field; -use ark_poly::multivariate::{SparsePolynomial, SparseTerm}; - -pub trait ProductProverPolynomialConfig { - fn default( - claim: F, - num_variables: usize, - p: SparsePolynomial, - q: SparsePolynomial, - ) -> Self; -} - -pub struct BasicProductProverConfig { - pub claim: F, - pub num_variables: usize, - pub p: SparsePolynomial, - pub q: SparsePolynomial, -} - -impl BasicProductProverConfig { - pub fn new( - claim: F, - num_variables: usize, - p: SparsePolynomial, - q: SparsePolynomial, - ) -> Self { - Self { - claim, - num_variables, - p, - q, - } - } -} - -impl ProductProverPolynomialConfig for BasicProductProverConfig { - fn default( - claim: F, - num_variables: usize, - p: SparsePolynomial, - q: SparsePolynomial, - ) -> Self { - Self { - claim, - num_variables, - p, - q, - } - } -} diff --git a/src/tests/multilinear_product/provers/basic/core.rs b/src/tests/multilinear_product/provers/basic/core.rs deleted file mode 100644 index 3b78d3fa..00000000 --- a/src/tests/multilinear_product/provers/basic/core.rs +++ /dev/null @@ -1,73 +0,0 @@ -use ark_ff::Field; -use ark_poly::multivariate::{SparsePolynomial, SparseTerm}; - -use crate::{ - hypercube::Hypercube, messages::VerifierMessages, order_strategy::GraycodeOrder, - tests::polynomials::Polynomial, -}; -pub struct BasicProductProver { - pub current_round: usize, - pub inverse_four: F, - pub num_variables: usize, - pub p: SparsePolynomial, - pub q: SparsePolynomial, - pub verifier_messages: VerifierMessages, -} - -impl BasicProductProver { - pub fn compute_round(&self) -> (F, F, F) { - let mut m: ((F, F), (F, F)) = ((F::ZERO, F::ZERO), (F::ZERO, F::ZERO)); - for (_, b) in Hypercube::::new(self.num_variables - self.current_round - 1) { - let partial_point: Vec = b - .to_vec_bool() - .into_iter() - .map(|bit: bool| -> F { - if bit { - F::ONE - } else { - F::ZERO - } - }) - .collect(); - let partial_point_zero: Vec = std::iter::once(F::ZERO) - .chain(partial_point.iter().cloned()) - .collect(); - let partial_point_one: Vec = std::iter::once(F::ONE) - .chain(partial_point.iter().cloned()) - .collect(); - let point_zero: Vec = self - .verifier_messages - .messages - .iter() - .cloned() - .chain(partial_point_zero.iter().cloned()) - .collect(); - let point_one: Vec = self - .verifier_messages - .messages - .iter() - .cloned() - .chain(partial_point_one.iter().cloned()) - .collect(); - let p_zero = self.p.evaluate(point_zero.clone()).unwrap(); - let p_one = self.p.evaluate(point_one.clone()).unwrap(); - let q_zero = self.q.evaluate(point_zero.clone()).unwrap(); - let q_one = self.q.evaluate(point_one.clone()).unwrap(); - m.0 .0 += p_zero * q_zero; - m.1 .1 += p_one * q_one; - m.0 .1 += p_zero * q_one; - m.1 .0 += p_one * q_zero; - } - ( - m.0 .0, - m.1 .1, - (F::ONE / F::from(4_u32)) * (m.0 .0 + m.1 .1 + m.0 .1 + m.1 .0), - ) - } - pub fn is_initial_round(&self) -> bool { - self.current_round == 0 - } - pub fn total_rounds(&self) -> usize { - self.num_variables - } -} diff --git a/src/tests/multilinear_product/provers/basic/mod.rs b/src/tests/multilinear_product/provers/basic/mod.rs deleted file mode 100644 index 644e807c..00000000 --- a/src/tests/multilinear_product/provers/basic/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod config; -mod core; -mod prover; - -pub use config::BasicProductProverConfig; -pub use config::ProductProverPolynomialConfig; -pub use core::BasicProductProver; diff --git a/src/tests/multilinear_product/provers/basic/prover.rs b/src/tests/multilinear_product/provers/basic/prover.rs deleted file mode 100644 index 95d255d2..00000000 --- a/src/tests/multilinear_product/provers/basic/prover.rs +++ /dev/null @@ -1,69 +0,0 @@ -use ark_ff::Field; - -use crate::{ - messages::VerifierMessages, - prover::Prover, - tests::multilinear_product::{BasicProductProver, BasicProductProverConfig}, -}; - -impl Prover for BasicProductProver { - type ProverConfig = BasicProductProverConfig; - type ProverMessage = Option<(F, F, F)>; - type VerifierMessage = Option; - - fn new(prover_config: Self::ProverConfig) -> Self { - Self { - current_round: 0, - inverse_four: F::from(4_u32).inverse().unwrap(), - num_variables: prover_config.num_variables, - p: prover_config.p, - q: prover_config.q, - verifier_messages: VerifierMessages::new(&vec![]), - } - } - - fn next_message(&mut self, verifier_message: Self::VerifierMessage) -> Self::ProverMessage { - // Ensure the current round is within bounds - if self.current_round >= self.total_rounds() { - return None; - } - - if !self.is_initial_round() { - self.verifier_messages - .receive_message(verifier_message.unwrap()); - } - - let sums: (F, F, F) = self.compute_round(); - - // Increment the round counter - self.current_round += 1; - - // Return the computed polynomial sums - Some(sums) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - prover::Prover, - streams::{multivariate_product_claim, MemoryStream}, - tests::{ - multilinear_product::{ - sanity_test_driver, BasicProductProver, BasicProductProverConfig, - }, - polynomials::{four_variable_polynomial, four_variable_polynomial_evaluations}, - F19, - }, - }; - #[test] - fn sumcheck() { - let s = MemoryStream::::new(four_variable_polynomial_evaluations()); - sanity_test_driver(&mut BasicProductProver::new(BasicProductProverConfig::new( - multivariate_product_claim(vec![s.clone(), s]), - 4, - four_variable_polynomial(), - four_variable_polynomial(), - ))); - } -} diff --git a/src/tests/multilinear_product/provers/mod.rs b/src/tests/multilinear_product/provers/mod.rs deleted file mode 100644 index 38883ee0..00000000 --- a/src/tests/multilinear_product/provers/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod basic; diff --git a/src/tests/multilinear_product/sanity.rs b/src/tests/multilinear_product/sanity.rs deleted file mode 100644 index d9f4e62a..00000000 --- a/src/tests/multilinear_product/sanity.rs +++ /dev/null @@ -1,145 +0,0 @@ -use ark_ff::Field; - -use crate::{ - prover::{ProductProverConfig, Prover}, - streams::{MemoryStream, Stream}, - tests::polynomials::four_variable_polynomial_evaluations, -}; - -fn multilinear_product_round_sanity( - round_num: usize, - p: &mut P, - message: Option, - eval_0: F, - eval_1: F, -) where - F: Field, - P: Prover, ProverMessage = Option<(F, F, F)>>, -{ - let round = p.next_message(message).unwrap(); - assert_eq!( - round.0, eval_0, - "g0 should evaluate correctly round {}", - round_num - ); - assert_eq!( - round.1, eval_1, - "g1 should evaluate correctly round {}", - round_num - ); -} - -pub fn sanity_test_driver(p: &mut P) -where - F: Field, - P: Prover, ProverMessage = Option<(F, F, F)>>, -{ - /* - * Zeroth Round: - * - * Evaluations: - * 0000 → 0 * 0 = 0 - * 0001 → 1 * 1 = 1 - * 0010 → 0 * 0 = 0 - * 0011 → 1 * 1 = 1 - * 0100 → 13 * 13 = 17 - * 0101 → 14 * 14 = 6 - * 0110 → 1 * 1 = 1 - * 0111 → 2 * 2 = 4 - * ---------------------- - * Sum g₀(0) = 11 - * - * 1000 → 2 * 2 = 4 - * 1001 → 3 * 3 = 9 - * 1010 → 2 * 2 = 4 - * 1011 → 3 * 3 = 9 - * 1100 → 0 * 0 = 0 - * 1101 → 1 * 1 = 1 - * 1110 → 7 * 7 = 11 - * 1111 → 8 * 8 = 7 - * ---------------------- - * Sum g₀(1) = 7 - */ - multilinear_product_round_sanity::(0, p, None, F::from(11_u32), F::from(7_u32)); - /* - * First Round: x₀ fixed to 3 - * - * Evaluations for g₁(0): - * 3000 → (0 * 17) + (2 * 3) = 6 * 6 = 17 - * 3001 → (1 * 17) + (3 * 3) = 7 * 7 = 11 - * 3010 → (0 * 17) + (2 * 3) = 6 * 6 = 17 - * 3011 → (1 * 17) + (3 * 3) = 7 * 7 = 11 - * ------------------------------- - * Sum g₁(0) = 18 - * - * Evaluations for g₁(1): - * 3100 → (13 * 17) + (0 * 3) = 12 * 12 = 11 - * 3101 → (14 * 17) + (1 * 3) = 13 * 13 = 17 - * 3110 → (1 * 17) + (7 * 3) = 0 * 0 = 0 - * 3111 → (2 * 17) + (8 * 3) = 1 * 1 = 1 - * ------------------------------- - * Sum g₁(1) = 10 - */ - multilinear_product_round_sanity::( - 1, - p, - Some(F::from(3_u32)), - F::from(18_u32), - F::from(10_u32), - ); - /* - * Second Round: x₁ fixed to 4 - * - * Evaluations for g₂(0): - * 3400 → (6 * 16) + (12 * 4) = 11 * 11 = 7 - * 3401 → (7 * 16) + (13 * 4) = 12 * 12 = 11 - * ------------------------------- - * Sum g₂(0) = 18 - * - * Evaluations for g₂(1): - * 3410 → (6 * 16) + (0 * 4) = 1 * 1 = 1 - * 3411 → (7 * 16) + (1 * 4) = 2 * 2 = 4 - * ------------------------------- - * Sum g₂(1) = 5 - */ - multilinear_product_round_sanity::( - 2, - p, - Some(F::from(4_u32)), - F::from(18_u32), - F::from(5_u32), - ); - /* - * Last Round: x₂ fixed to 7 - * - * Evaluations for g₃(0): - * 3470 → (11 * 13) + (1 * 7) = 17 * 17 = 4 - * ------------------------------- - * Sum g₃(0) = 4 - * - * Evaluations for g₃(1): - * 3471 → (12 * 13) + (2 * 7) = 18 * 18 = 1 - * ------------------------------- - * Sum g₃(1) = 1 - */ - multilinear_product_round_sanity::( - 3, - p, - Some(F::from(7_u32)), - F::from(4_u32), - F::from(1_u32), - ); -} - -pub fn sanity_test() -where - F: Field, - S: Stream + From>, - P: Prover, ProverMessage = Option<(F, F, F)>>, - P::ProverConfig: ProductProverConfig, -{ - let s_p: S = MemoryStream::new(four_variable_polynomial_evaluations()).into(); - let s_q: S = MemoryStream::new(four_variable_polynomial_evaluations()).into(); - let mut p = P::new(ProductProverConfig::default(4, vec![s_p, s_q])); - sanity_test_driver(&mut p); -} diff --git a/src/tests/polynomials.rs b/src/tests/polynomials.rs index 762df354..59c29e21 100644 --- a/src/tests/polynomials.rs +++ b/src/tests/polynomials.rs @@ -1,7 +1,3 @@ -use crate::{ - hypercube::{Hypercube, HypercubeMember}, - order_strategy::GraycodeOrder, -}; use ark_ff::Field; use ark_poly::{ multivariate::{self, SparsePolynomial, SparseTerm, Term}, @@ -9,7 +5,7 @@ use ark_poly::{ }; /* - * These are two small polynomials to use for sanity checking. + * Small polynomials for sanity checking. */ pub fn three_variable_polynomial() -> SparsePolynomial { @@ -32,7 +28,6 @@ pub fn three_variable_polynomial() -> SparsePolynomial } pub fn three_variable_polynomial_evaluations() -> Vec { - // 4*x_1*x_2 + 7*x_2*x_3 + 2*x_1 + 13*x_2 three_variable_polynomial().to_evaluations() } @@ -57,37 +52,23 @@ pub fn four_variable_polynomial() -> SparsePolynomial { } pub fn four_variable_polynomial_evaluations() -> Vec { - // 4*x_1*x_2 + 7*x_2*x_3 + 2*x_1 + 13*x_2 + 1x_4 four_variable_polynomial().to_evaluations() } /* - * Below here, we "extend" multivariate::SparsePolynomial so that we can - * get evaluations over the boolean hypercube (it's not so important it's just handy for testing) - * - * The idea comes from here: https://github.com/montekki/thaler-study/blob/master/sum-check-protocol/src/lib.rs + * Extension trait to evaluate multivariate sparse polynomials on the + * Boolean hypercube and convert between evaluation and coefficient form. */ pub trait Polynomial { - // Evaluates the polynomial at the provided point (expressed as a vector of field elements) fn evaluate(&self, point: Vec) -> Option; - - // Evaluates the polynomial at the provided point (expressed as a hypercube member) - // using the given number of variables. - fn evaluate_from_hypercube(&self, num_vars: usize, point: HypercubeMember) -> Option; - - // Converts the polynomial into a vector containing evaluations at every - // point of the hypercube. fn to_evaluations(&self) -> Vec; - - // take the evaluations table and give back a sparsepolynomial fn from_hypercube_evaluations(evaluations: Vec) -> SparsePolynomial; } impl Polynomial for SparsePolynomial { fn evaluate(&self, point: Vec) -> Option { assert_eq!(DenseMVPolynomial::::num_vars(self), point.len()); - // Compute the evaluation by summing the contributions of each term. let mut result = F::ZERO; for (coefficient, term) in self.terms().iter() { result += term.evaluate(&point) * coefficient; @@ -95,40 +76,27 @@ impl Polynomial for SparsePolynomial { Some(result) } - fn evaluate_from_hypercube(&self, num_vars: usize, point: HypercubeMember) -> Option { - // Convert the boolean representation into field elements. - let mut field_values: Vec = Vec::with_capacity(num_vars); - for bit in point { - field_values.push(if bit { F::ONE } else { F::ZERO }); - } - - // Compute the evaluation by summing the contributions of each term. - let mut result = F::ZERO; - for (coefficient, term) in self.terms().iter() { - result += term.evaluate(&field_values) * coefficient; - } - - Some(result) - } - fn to_evaluations(&self) -> Vec { let num_vars = DenseMVPolynomial::::num_vars(self); - let total_points = Hypercube::::stop_value(num_vars); + let total_points = 1usize << num_vars; let mut evaluations = Vec::with_capacity(total_points); - // Iterate through each index of the hypercube. for index in 0..total_points { - let point = HypercubeMember::new(num_vars, index); - let value = Self::evaluate_from_hypercube(self, num_vars, point).unwrap(); - evaluations.push(value); + // Convert index bits to field elements. + let point: Vec = (0..num_vars) + .map(|j| if index >> j & 1 == 1 { F::ONE } else { F::ZERO }) + .collect(); + let mut val = F::ZERO; + for (coefficient, term) in self.terms().iter() { + val += term.evaluate(&point) * coefficient; + } + evaluations.push(val); } evaluations } - // TODO (z-tech): this works but it's super slow fn from_hypercube_evaluations(mut evaluations: Vec) -> SparsePolynomial { - // Ensure that the evaluations vector length is a power of two. assert!( evaluations.len().is_power_of_two(), "evaluations len must be a power of two" @@ -136,18 +104,8 @@ impl Polynomial for SparsePolynomial { let num_vars: usize = evaluations.len().ilog2() as usize; let n = evaluations.len(); - // In-place bit reversal permutation: - // If the evaluations were produced with the highest-index variable corresponding to the LSB, - // we need to swap elements so that the i-th bit corresponds to variable x_i. - for i in 0_usize..n { - // Reverse the lower `num_vars` bits of i. - let j = i.reverse_bits() >> (usize::BITS - num_vars as u32); - if i < j { - evaluations.swap(i, j); - } - } - - // Perform in-place Möbius inversion on `evaluations` (now in standard binary order). + // Evaluations are in ascending (standard binary) order — no reorder needed. + // In-place Mobius inversion. for i in 0..num_vars { for mask in 0..n { if mask & (1 << i) != 0 { @@ -156,7 +114,7 @@ impl Polynomial for SparsePolynomial { } } - // Build the sparse polynomial representation from the nonzero coefficients. + // Build sparse polynomial from nonzero coefficients. let mut terms = Vec::new(); for (mask, evaluation) in evaluations.iter().enumerate() { if evaluations[mask] != F::zero() { @@ -188,7 +146,6 @@ mod tests { #[test] fn to_evaluations_from_evaluations_sanity() { - // we should get back the same polynomial let p1: SparsePolynomial = four_variable_polynomial::(); let p1_evaluations: Vec = p1.to_evaluations(); assert_eq!( @@ -198,7 +155,6 @@ mod tests { ) ); - // we should get back the same evaluations let num_variables: usize = 16; let s: BenchStream = BenchStream::new(num_variables); let hypercube_len: usize = 2usize.pow(num_variables as u32); diff --git a/src/transcript/mod.rs b/src/transcript/mod.rs index b434c4e5..c473cddc 100644 --- a/src/transcript/mod.rs +++ b/src/transcript/mod.rs @@ -1,8 +1,12 @@ +#[cfg(feature = "arkworks")] mod sanity; +#[cfg(feature = "arkworks")] mod spongefish; #[allow(clippy::module_inception)] mod transcript; -pub use sanity::SanityTranscript; +#[cfg(feature = "arkworks")] +pub use sanity::{SanityTranscript, TestTranscript}; +#[cfg(feature = "arkworks")] pub use spongefish::SpongefishTranscript; -pub use transcript::Transcript; +pub use transcript::{ProverTranscript, VerifierTranscript}; diff --git a/src/transcript/sanity.rs b/src/transcript/sanity.rs index 7cff8b6f..2e2b70db 100644 --- a/src/transcript/sanity.rs +++ b/src/transcript/sanity.rs @@ -1,29 +1,52 @@ use ark_ff::Field; use ark_std::rand::Rng; -use crate::transcript::Transcript; +use crate::transcript::{ProverTranscript, VerifierTranscript}; +/// Test transcript: sends are no-ops, receives return `Ok(random)`, +/// challenges return random values from the RNG. +/// +/// Implements both [`ProverTranscript`] and [`VerifierTranscript`] for +/// convenience in tests that run prover + verifier with the same RNG. #[derive(Debug)] -pub struct SanityTranscript<'a, R> { +pub struct TestTranscript<'a, R> { pub rng: &'a mut R, } -impl<'a, R> SanityTranscript<'a, R> { +impl<'a, R> TestTranscript<'a, R> { pub fn new(rng: &'a mut R) -> Self { Self { rng } } } -impl<'a, F, R> Transcript for SanityTranscript<'a, R> +impl<'a, F, R> ProverTranscript for TestTranscript<'a, R> where F: Field, R: Rng, { - fn write(&mut self, _value: F) { + fn send(&mut self, _value: F) { // no-op } - fn read(&mut self) -> F { + fn challenge(&mut self) -> F { F::rand(&mut self.rng) } } + +impl<'a, F, R> VerifierTranscript for TestTranscript<'a, R> +where + F: Field, + R: Rng, +{ + type Error = core::convert::Infallible; + + fn receive(&mut self) -> Result { + Ok(F::rand(&mut self.rng)) + } + + fn challenge(&mut self) -> F { + F::rand(&mut self.rng) + } +} + +pub type SanityTranscript<'a, R> = TestTranscript<'a, R>; diff --git a/src/transcript/spongefish.rs b/src/transcript/spongefish.rs index cbd0a797..a2ea64f6 100644 --- a/src/transcript/spongefish.rs +++ b/src/transcript/spongefish.rs @@ -2,46 +2,46 @@ use ark_ff::Field; use ark_std::rand::{CryptoRng, RngCore}; use spongefish::{Decoding, Encoding, ProverState, StdHash}; -use crate::transcript::Transcript; +use crate::transcript::ProverTranscript; -/// Newtype wrapper around spongefish's [`ProverState`] so we can implement [`Transcript`]. +/// Spongefish prover transcript. /// -/// Uses the codec-level API (`prover_message` / `verifier_message`) which is compatible -/// with the new spongefish `domain_separator!` macro. +/// Implements [`ProverTranscript`] only — the verifier side should wrap +/// spongefish's `VerifierState` and implement [`VerifierTranscript`](super::VerifierTranscript). pub struct SpongefishTranscript( pub ProverState, ); -impl Transcript for SpongefishTranscript +impl ProverTranscript for SpongefishTranscript where F: Field + Encoding<[u8]> + Decoding<[u8]>, R: RngCore + CryptoRng, { - fn read(&mut self) -> F { - self.0.verifier_message::() + fn send(&mut self, value: F) { + self.0.prover_message(&value); } - fn write(&mut self, value: F) { - self.0.prover_message(&value); + fn challenge(&mut self) -> F { + self.0.verifier_message::() } } -/// Blanket impl so raw `ProverState` can be used as a `Transcript` directly. -impl Transcript for spongefish::ProverState +/// Blanket impl so raw `ProverState` can be used as a `ProverTranscript` directly. +impl ProverTranscript for spongefish::ProverState where F: Field + Encoding<[H::U]> + Decoding<[H::U]> + spongefish::NargSerialize, H: spongefish::DuplexSpongeInterface, R: RngCore + CryptoRng, { - fn read(&mut self) -> F { - self.verifier_message::() - } - fn write(&mut self, value: F) { + fn send(&mut self, value: F) { self.prover_message(&value); } + + fn challenge(&mut self) -> F { + self.verifier_message::() + } } -// Optional helpers so it's easy to get the prover state back out. impl SpongefishTranscript where R: RngCore + CryptoRng, diff --git a/src/transcript/transcript.rs b/src/transcript/transcript.rs index 819f36ac..7fc639df 100644 --- a/src/transcript/transcript.rs +++ b/src/transcript/transcript.rs @@ -1,4 +1,43 @@ -pub trait Transcript { - fn read(&mut self) -> T; - fn write(&mut self, value: T); +/// Prover-side Fiat-Shamir transcript. +/// +/// The prover absorbs messages into the sponge and squeezes challenges. +/// No `receive` — the prover never reads prover messages. +/// +/// ```ignore +/// transcript.send(c0); // absorb prover message +/// transcript.send(c2); // absorb prover message +/// let r = transcript.challenge(); // squeeze challenge +/// ``` +pub trait ProverTranscript { + /// Absorb a prover message into the transcript. + fn send(&mut self, value: T); + + /// Squeeze a verifier challenge from the transcript. + fn challenge(&mut self) -> T; +} + +/// Verifier-side Fiat-Shamir transcript. +/// +/// The verifier reads (absorbs + decodes) prover messages and squeezes +/// the same challenges as the prover. No `send` — the verifier never +/// produces prover messages. +/// +/// ```ignore +/// let c0 = transcript.receive()?; // absorb + decode prover message +/// let c2 = transcript.receive()?; // absorb + decode prover message +/// let r = transcript.challenge(); // squeeze challenge (same as prover) +/// ``` +pub trait VerifierTranscript { + type Error: core::fmt::Debug; + + /// Read a prover message from the transcript. + /// + /// Absorbs + decodes the next prover message. Returns `Err` if the + /// transcript data is malformed or exhausted. + fn receive(&mut self) -> Result; + + /// Squeeze a verifier challenge from the transcript. + /// + /// Deterministic given the absorbed state — same on prover and verifier. + fn challenge(&mut self) -> T; } diff --git a/src/verifier.rs b/src/verifier.rs new file mode 100644 index 00000000..10be0680 --- /dev/null +++ b/src/verifier.rs @@ -0,0 +1,152 @@ +//! Sumcheck verifier (Thaler Proposition 4.1). +//! +//! [`sumcheck_verify()`] checks a sumcheck proof against a claimed sum +//! and returns the challenges along with the final claim. +//! +//! The verifier handles round checks (consistency, Lagrange interpolation). +//! The **oracle check** — verifying `final_claim == g(r)` — is the caller's +//! responsibility. This separation matches the textbook decomposition and +//! every real-world usage pattern: +//! +//! - **Standalone:** compare `result.final_claim == proof.final_value` +//! - **Composed (WHIR, GKR):** pass `result.final_claim` to the next layer +//! - **Custom (WARP):** compute the expected value from `result.challenges` + +extern crate alloc; +use crate::field::SumcheckField; +use crate::proof::SumcheckError; +use crate::transcript::VerifierTranscript; +use alloc::vec; +use alloc::vec::Vec; + +/// Output of [`sumcheck_verify`]: the verifier challenges and final claim. +/// +/// The caller **must** verify `final_claim` — either by direct comparison, +/// PCS opening, or delegation to the next protocol layer. +#[derive(Clone, Debug)] +pub struct SumcheckResult { + /// Verifier challenges `r_1, ..., r_v`. + pub challenges: Vec, + /// The reduced claim after all rounds: `g_v(r_v)`. + /// + /// The caller must verify this equals `g(r_1, ..., r_v)` via an + /// oracle query, polynomial commitment opening, or delegation. + pub final_claim: F, +} + +/// Verify a sum-check proof against a claimed sum. +/// +/// For each round j: +/// 1. Reads `degree + 1` evaluations from the transcript. +/// 2. Checks `g_j(0) + g_j(1) == current_claim`. +/// 3. Invokes `hook(round, transcript)`. +/// 4. Reads the verifier challenge `r_j`. +/// 5. Updates `current_claim = g_j(r_j)` via Lagrange interpolation. +/// +/// Returns [`SumcheckResult`] containing the challenges and final claim. +/// The caller is responsible for the oracle check — verifying that +/// `final_claim == g(r_1, ..., r_v)`. +pub fn sumcheck_verify>( + claimed_sum: F, + expected_degree: usize, + num_rounds: usize, + transcript: &mut T, + mut hook: impl FnMut(usize, &mut T) -> Result<(), SumcheckError>, +) -> Result, SumcheckError> { + let mut claim = claimed_sum; + let mut challenges = Vec::with_capacity(num_rounds); + + for round in 0..num_rounds { + // Receive round polynomial evaluations from the prover. + let num_evals = expected_degree + 1; + let mut evals = Vec::with_capacity(num_evals); + for _ in 0..num_evals { + let v = transcript + .receive() + .map_err(|_| SumcheckError::TranscriptError { round })?; + evals.push(v); + } + + // Consistency check: g_j(0) + g_j(1) == claim. + let sum_01 = evals[0] + evals[1]; + if sum_01 != claim { + return Err(SumcheckError::ConsistencyCheck { round }); + } + + // Per-round hook (e.g., PoW verification for WHIR). + hook(round, transcript)?; + + // Squeeze verifier challenge. + let r = transcript.challenge(); + challenges.push(r); + + // Update claim: g_j(r_j) via Lagrange interpolation. + claim = evaluate_from_evals(&evals, r); + } + + Ok(SumcheckResult { + challenges, + final_claim: claim, + }) +} + +/// Evaluate a univariate polynomial from its evaluations at `{0, 1, ..., d}` +/// at an arbitrary point `r`. +/// +/// Uses Lagrange interpolation: +/// g(r) = Σ_i g(i) · Π_{j≠i} (r − j) / (i − j) +fn evaluate_from_evals(evals: &[F], r: F) -> F { + let d = evals.len(); // degree + 1 + if d == 0 { + return F::ZERO; + } + if d == 1 { + return evals[0]; + } + if d == 2 { + // Linear: g(r) = g(0) + r·(g(1) − g(0)). + return evals[0] + r * (evals[1] - evals[0]); + } + + // General case: Lagrange interpolation at {0, 1, ..., d-1}. + // Precompute (r − j) for all j. + let r_minus: Vec = (0..d).map(|j| r - F::from_u64(j as u64)).collect(); + + // Precompute prefix/suffix products of (r − j). + let mut prefix = vec![F::ONE; d + 1]; + for i in 0..d { + prefix[i + 1] = prefix[i] * r_minus[i]; + } + let mut suffix = vec![F::ONE; d + 1]; + for i in (0..d).rev() { + suffix[i] = suffix[i + 1] * r_minus[i]; + } + + // Precompute 1 / (i − j) for all j ≠ i, accumulated as products. + // barycentric_weight[i] = 1 / Π_{j≠i} (i − j) = 1 / (i! · (d-1-i)! · (-1)^{d-1-i}) + let mut result = F::ZERO; + for i in 0..d { + let numerator = prefix[i] * suffix[i + 1]; // Π_{j≠i} (r − j) + let denom: F = barycentric_weight(i, d); + if let Some(inv) = denom.inverse() { + result += evals[i] * numerator * inv; + } + } + result +} + +/// Barycentric weight: Π_{j≠i, 0≤j(i: usize, d: usize) -> F { + let mut w = F::ONE; + for j in 0..d { + if j != i { + let diff = i as i64 - j as i64; + if diff > 0 { + w *= F::from_u64(diff as u64); + } else { + w *= -F::from_u64((-diff) as u64); + } + } + } + w +} diff --git a/tests/adversarial_verifier.rs b/tests/adversarial_verifier.rs new file mode 100644 index 00000000..1c06b25d --- /dev/null +++ b/tests/adversarial_verifier.rs @@ -0,0 +1,342 @@ +//! Adversarial tests for `sumcheck_verify`. +//! +//! Verifies that honest proofs are accepted and corrupted proofs are rejected +//! across all three prover types: multilinear, inner product, and GKR. + +use ark_ff::{AdditiveGroup, UniformRand}; +use ark_std::rand::{rngs::StdRng, SeedableRng}; + +use effsc::noop_hook_verify; +use effsc::proof::{SumcheckError, SumcheckProof}; +use effsc::provers::inner_product::InnerProductProver; +use effsc::provers::multilinear::MultilinearProver; +use effsc::runner::sumcheck; +use effsc::tests::F64; +use effsc::transcript::{ProverTranscript, VerifierTranscript}; +use effsc::verifier::sumcheck_verify; + +// ─── Replay transcript ──────────────────────────────────────────────────── +// +// Records prover messages and challenges during proving, then replays them +// to sumcheck_verify. Both sides see the same sequence of field elements. + +struct ReplayTranscript { + /// Recorded field elements (round poly evals interleaved with challenges). + tape: Vec, + /// Current read position for the verifier. + cursor: usize, + /// RNG for generating challenges. + rng: StdRng, +} + +impl ReplayTranscript { + fn new(seed: u64) -> Self { + Self { + tape: Vec::new(), + cursor: 0, + rng: StdRng::seed_from_u64(seed), + } + } + + /// Reset the cursor to replay from the beginning. + fn rewind(&mut self) { + self.cursor = 0; + } +} + +impl ProverTranscript for ReplayTranscript { + fn send(&mut self, value: F64) { + self.tape.push(value); + } + + fn challenge(&mut self) -> F64 { + let c = F64::rand(&mut self.rng); + self.tape.push(c); + c + } +} + +impl VerifierTranscript for ReplayTranscript { + type Error = core::convert::Infallible; + + fn receive(&mut self) -> Result { + let v = self.tape[self.cursor]; + self.cursor += 1; + Ok(v) + } + + fn challenge(&mut self) -> F64 { + let v = self.tape[self.cursor]; + self.cursor += 1; + v + } +} + +// ─── Helpers ─────────────────────────────────────────────────────────────── + +const SEED: u64 = 0xAD0E_0001; + +fn make_multilinear_proof( + num_vars: usize, + transcript: &mut ReplayTranscript, +) -> (F64, SumcheckProof) { + let n = 1 << num_vars; + let mut rng = StdRng::seed_from_u64(SEED); + let evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = evals.iter().copied().sum(); + + let mut prover = MultilinearProver::new(evals); + let proof = sumcheck(&mut prover, num_vars, transcript, |_, _| {}); + (claimed_sum, proof) +} + +fn make_inner_product_proof( + num_vars: usize, + transcript: &mut ReplayTranscript, +) -> (F64, SumcheckProof) { + let n = 1 << num_vars; + let mut rng = StdRng::seed_from_u64(SEED); + let a: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let b: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + let claimed_sum: F64 = a.iter().zip(&b).map(|(&x, &y)| x * y).sum(); + + let mut prover = InnerProductProver::new(a, b); + let proof = sumcheck(&mut prover, num_vars, transcript, |_, _| {}); + (claimed_sum, proof) +} + +// ─── Multilinear: honest accept ─────────────────────────────────────────── + +#[test] +fn multilinear_honest_proof_accepted() { + let mut t = ReplayTranscript::new(42); + let (claimed_sum, proof) = make_multilinear_proof(6, &mut t); + + t.rewind(); + let result = sumcheck_verify(claimed_sum, 1, 6, &mut t, noop_hook_verify); + assert!(result.is_ok()); + let r = result.unwrap(); + assert_eq!(r.final_claim, proof.final_value); +} + +// ─── Multilinear: corrupted round poly ──────────────────────────────────── + +#[test] +fn multilinear_corrupted_round_poly_rejected() { + let mut t = ReplayTranscript::new(42); + let (claimed_sum, _proof) = make_multilinear_proof(6, &mut t); + + // Corrupt the first evaluation of round 2 in the tape. + // Each round for degree-1: 2 evals + 1 challenge = 3 elements. + // Round 2 starts at offset 6, corrupt index 6. + t.tape[6] += F64::from(1u64); + + t.rewind(); + let result = sumcheck_verify(claimed_sum, 1, 6, &mut t, noop_hook_verify); + assert!(result.is_err()); + match result.unwrap_err() { + SumcheckError::ConsistencyCheck { round } => assert_eq!(round, 2), + e => panic!("expected ConsistencyCheck at round 2, got {e:?}"), + } +} + +// ─── Multilinear: wrong claimed sum ─────────────────────────────────────── + +#[test] +fn multilinear_wrong_claimed_sum_rejected() { + let mut t = ReplayTranscript::new(42); + let (claimed_sum, _proof) = make_multilinear_proof(6, &mut t); + + t.rewind(); + let result = sumcheck_verify( + claimed_sum + F64::from(1u64), // wrong sum + 1, + 6, + &mut t, + noop_hook_verify, + ); + assert!(result.is_err()); + match result.unwrap_err() { + SumcheckError::ConsistencyCheck { round } => assert_eq!(round, 0), + e => panic!("expected ConsistencyCheck at round 0, got {e:?}"), + } +} + +// ─── Multilinear: caller catches wrong final value ──────────────────────── + +#[test] +fn multilinear_caller_catches_wrong_final_value() { + let mut t = ReplayTranscript::new(42); + let (claimed_sum, proof) = make_multilinear_proof(6, &mut t); + + t.rewind(); + let r = sumcheck_verify(claimed_sum, 1, 6, &mut t, noop_hook_verify).unwrap(); + + // The verifier returns the final claim. The caller checks it. + assert_eq!(r.final_claim, proof.final_value); + + // If the prover lied about final_value, the caller catches it: + let wrong_value = proof.final_value + F64::from(1u64); + assert_ne!(r.final_claim, wrong_value); +} + +// ─── Inner product: honest accept ───────────────────────────────────────── + +#[test] +fn inner_product_honest_proof_accepted() { + let mut t = ReplayTranscript::new(77); + let (claimed_sum, proof) = make_inner_product_proof(6, &mut t); + + t.rewind(); + let r = sumcheck_verify(claimed_sum, 2, 6, &mut t, noop_hook_verify).unwrap(); + assert_eq!(r.final_claim, proof.final_value); +} + +// ─── Inner product: corrupted round poly ────────────────────────────────── + +#[test] +fn inner_product_corrupted_round_poly_rejected() { + let mut t = ReplayTranscript::new(77); + let (claimed_sum, _proof) = make_inner_product_proof(6, &mut t); + + // Degree-2: 3 evals + 1 challenge = 4 elements per round. + // Corrupt round 1, first eval at offset 4. + t.tape[4] += F64::from(1u64); + + t.rewind(); + let result = sumcheck_verify(claimed_sum, 2, 6, &mut t, noop_hook_verify); + assert!(result.is_err()); + match result.unwrap_err() { + SumcheckError::ConsistencyCheck { round } => assert_eq!(round, 1), + e => panic!("expected ConsistencyCheck at round 1, got {e:?}"), + } +} + +// ─── Inner product: caller catches wrong final value ────────────────────── + +#[test] +fn inner_product_caller_catches_wrong_final_value() { + let mut t = ReplayTranscript::new(77); + let (claimed_sum, proof) = make_inner_product_proof(6, &mut t); + + t.rewind(); + let r = sumcheck_verify(claimed_sum, 2, 6, &mut t, noop_hook_verify).unwrap(); + assert_eq!(r.final_claim, proof.final_value); + + let wrong_value = proof.final_value + F64::from(1u64); + assert_ne!(r.final_claim, wrong_value); +} + +// ─── GKR: adversarial tests with deferred oracle check ─────────────────── + +#[cfg(feature = "arkworks")] +mod gkr_tests { + use super::*; + use effsc::provers::gkr::GkrProver; + + const GKR_SEED: u64 = 0xBEEF_0A70; + + fn make_gkr_proof( + k: usize, + transcript: &mut ReplayTranscript, + ) -> (F64, SumcheckProof, GkrProver) { + let n = 1 << k; + let n_bc = n * n; + let mut rng = StdRng::seed_from_u64(GKR_SEED); + + let add_evals: Vec = (0..n_bc).map(|_| F64::rand(&mut rng)).collect(); + let mult_evals: Vec = (0..n_bc).map(|_| F64::rand(&mut rng)).collect(); + let w_evals: Vec = (0..n).map(|_| F64::rand(&mut rng)).collect(); + + let mut expected_sum = F64::ZERO; + for b in 0..n { + for c in 0..n { + let idx = b * n + c; + let wb = w_evals[b]; + let wc = w_evals[c]; + expected_sum += add_evals[idx] * (wb + wc) + mult_evals[idx] * (wb * wc); + } + } + + let mut prover = GkrProver::new(add_evals, mult_evals, w_evals); + let num_rounds = 2 * k; + let proof = sumcheck(&mut prover, num_rounds, transcript, |_, _| {}); + (expected_sum, proof, prover) + } + + #[test] + fn gkr_honest_proof_accepted() { + let k = 3; + let mut t = ReplayTranscript::new(99); + let (claimed_sum, proof, _prover) = make_gkr_proof(k, &mut t); + + t.rewind(); + let r = sumcheck_verify(claimed_sum, 2, 2 * k, &mut t, noop_hook_verify).unwrap(); + + // GKR: final_claim is correct, caller passes it to the next layer. + assert_eq!(r.final_claim, proof.final_value); + } + + #[test] + fn gkr_corrupted_round_poly_rejected() { + let k = 3; + let mut t = ReplayTranscript::new(99); + let (claimed_sum, _proof, _prover) = make_gkr_proof(k, &mut t); + + // Degree-2: 3 evals + 1 challenge = 4 elements per round. + // Corrupt round 3, first eval at offset 12. + t.tape[12] += F64::from(1u64); + + t.rewind(); + let result = sumcheck_verify(claimed_sum, 2, 2 * k, &mut t, noop_hook_verify); + assert!(result.is_err()); + match result.unwrap_err() { + SumcheckError::ConsistencyCheck { round } => assert_eq!(round, 3), + e => panic!("expected ConsistencyCheck at round 3, got {e:?}"), + } + } + + #[test] + fn gkr_wrong_claimed_sum_rejected() { + let k = 3; + let mut t = ReplayTranscript::new(99); + let (claimed_sum, _proof, _prover) = make_gkr_proof(k, &mut t); + + t.rewind(); + let result = sumcheck_verify( + claimed_sum + F64::from(1u64), + 2, + 2 * k, + &mut t, + noop_hook_verify, + ); + assert!(result.is_err()); + match result.unwrap_err() { + SumcheckError::ConsistencyCheck { round } => assert_eq!(round, 0), + e => panic!("expected ConsistencyCheck at round 0, got {e:?}"), + } + } + + /// Demonstrates that the caller is responsible for the oracle check. + /// sumcheck_verify returns final_claim; the caller must verify it. + #[test] + fn gkr_caller_must_check_final_claim() { + let k = 3; + let mut t = ReplayTranscript::new(99); + let (claimed_sum, proof, _prover) = make_gkr_proof(k, &mut t); + + t.rewind(); + let r = sumcheck_verify(claimed_sum, 2, 2 * k, &mut t, noop_hook_verify).unwrap(); + + // The verifier gives the caller the final_claim. + // A composed caller (GKR) passes it to the next layer. + // A standalone caller checks it directly: + assert_eq!(r.final_claim, proof.final_value); + + // If the caller forgets to check, a lying prover could claim + // a different final_value. The round checks still pass — only + // the oracle check catches it. + let lying_final_value = proof.final_value + F64::from(1u64); + assert_ne!(r.final_claim, lying_final_value); + } +} diff --git a/tests/canonical_api.rs b/tests/canonical_api.rs new file mode 100644 index 00000000..63a751d7 --- /dev/null +++ b/tests/canonical_api.rs @@ -0,0 +1,253 @@ +//! Integration tests for the canonical sumcheck API (Thaler §4.1). +//! +//! Tests the `SumcheckProver` trait, `sumcheck()` runner, +//! `MultilinearProver`, and `InnerProductProver`. + +use ark_ff::{AdditiveGroup, Field, UniformRand}; +use ark_std::rand::{rngs::StdRng, SeedableRng}; + +use effsc::provers::inner_product::InnerProductProver; +use effsc::provers::multilinear::MultilinearProver; +use effsc::runner::sumcheck; +use effsc::tests::F64; +use effsc::transcript::SanityTranscript; + +const SEED: u64 = 0xDEAD_BEEF; + +fn rng() -> StdRng { + StdRng::seed_from_u64(SEED) +} + +/// Independent MLE evaluation (MSB half-split fold). +fn mle_eval(evals: &[F64], point: &[F64]) -> F64 { + let mut current = evals.to_vec(); + for &r in point { + let half = current.len() / 2; + let (low, high) = current.split_at(half); + current = low + .iter() + .zip(high) + .map(|(l, h)| *l + (*h - *l) * r) + .collect(); + } + current[0] +} + +/// Evaluate degree-d polynomial from evaluations at {0,1,...,d} at point r +/// via Lagrange interpolation. +fn lagrange_eval(evals: &[F64], r: F64) -> F64 { + let d = evals.len(); + let mut result = F64::ZERO; + for i in 0..d { + let mut basis = ::ONE; + for j in 0..d { + if j != i { + let ni = F64::from(i as u64); + let nj = F64::from(j as u64); + basis *= (r - nj) / (ni - nj); + } + } + result += evals[i] * basis; + } + result +} + +/// Verify a SumcheckProof by checking consistency equations. +fn verify_proof(claimed_sum: F64, round_polys: &[Vec], challenges: &[F64], final_value: F64) { + let num_rounds = round_polys.len(); + assert_eq!(challenges.len(), num_rounds); + + let mut claim = claimed_sum; + for (j, (rp, &r)) in round_polys.iter().zip(challenges).enumerate() { + // q_j(0) + q_j(1) == claim + let sum_01 = rp[0] + rp[1]; + assert_eq!(sum_01, claim, "round {j}: consistency check failed"); + // Update claim = q_j(r_j). + claim = lagrange_eval(rp, r); + } + assert_eq!(claim, final_value, "final value mismatch"); +} + +// ─── MultilinearProver ───────────────────────────────────────────────────── + +#[test] +fn multilinear_full_roundtrip() { + let num_vars = 8; + let n = 1 << num_vars; + + let mut r = rng(); + let evals: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let claimed_sum: F64 = evals.iter().copied().sum(); + + let mut prover = MultilinearProver::new(evals.clone()); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, num_vars, &mut t, |_, _| {}); + + // Proof structure. + assert_eq!(proof.round_polys.len(), num_vars); + assert_eq!(proof.challenges.len(), num_vars); + for rp in &proof.round_polys { + assert_eq!(rp.len(), 2, "degree-1 round poly should have 2 evaluations"); + } + + // Verify consistency. + verify_proof( + claimed_sum, + &proof.round_polys, + &proof.challenges, + proof.final_value, + ); + + // Final value matches independent MLE evaluation. + assert_eq!(proof.final_value, mle_eval(&evals, &proof.challenges)); + + // Prover post-state. + assert_eq!(prover.evals().len(), 1); + assert_eq!(prover.evals()[0], proof.final_value); +} + +#[test] +fn multilinear_partial_then_continue() { + let num_vars = 8; + let split = 3; + let n = 1 << num_vars; + + let mut r = rng(); + let evals: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + // Full. + let mut prover_full = MultilinearProver::new(evals.clone()); + let mut trng = rng(); + let mut t_full = SanityTranscript::new(&mut trng); + let full = sumcheck(&mut prover_full, num_vars, &mut t_full, |_, _| {}); + + // Split. + let mut prover = MultilinearProver::new(evals); + let mut trng2 = rng(); + let mut t_split = SanityTranscript::new(&mut trng2); + let first = sumcheck(&mut prover, split, &mut t_split, |_, _| {}); + let second = sumcheck(&mut prover, num_vars - split, &mut t_split, |_, _| {}); + + // Round polys concatenate to match full. + let mut combined_polys = first.round_polys.clone(); + combined_polys.extend(second.round_polys.iter().cloned()); + assert_eq!(combined_polys, full.round_polys); + + let mut combined_challenges = first.challenges.clone(); + combined_challenges.extend(second.challenges.iter().copied()); + assert_eq!(combined_challenges, full.challenges); + + assert_eq!(second.final_value, full.final_value); +} + +#[test] +fn multilinear_hook_fires() { + use std::cell::RefCell; + let num_vars = 5; + let n = 1 << num_vars; + + let mut r = rng(); + let evals: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let mut prover = MultilinearProver::new(evals); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + + let calls = RefCell::new(Vec::::new()); + let _ = sumcheck(&mut prover, num_vars, &mut t, |round, _| { + calls.borrow_mut().push(round); + }); + assert_eq!(calls.into_inner(), (0..num_vars).collect::>()); +} + +#[test] +fn multilinear_non_pow2() { + let n = 13_usize; + let num_rounds = n.next_power_of_two().trailing_zeros() as usize; + + let mut r = rng(); + let evals: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut prover = MultilinearProver::new(evals); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, num_rounds, &mut t, |_, _| {}); + + assert_eq!(proof.round_polys.len(), num_rounds); + assert_eq!(prover.evals().len(), 1); +} + +// ─── InnerProductProver ──────────────────────────────────────────────────── + +#[test] +fn inner_product_full_roundtrip() { + let num_vars = 8; + let n = 1 << num_vars; + + let mut r = rng(); + let a: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let claimed_sum: F64 = a.iter().zip(&b).map(|(&x, &y)| x * y).sum(); + + let mut prover = InnerProductProver::new(a.clone(), b.clone()); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let proof = sumcheck(&mut prover, num_vars, &mut t, |_, _| {}); + + // Proof structure. + assert_eq!(proof.round_polys.len(), num_vars); + assert_eq!(proof.challenges.len(), num_vars); + for rp in &proof.round_polys { + assert_eq!(rp.len(), 3, "degree-2 round poly should have 3 evaluations"); + } + + // Verify consistency. + verify_proof( + claimed_sum, + &proof.round_polys, + &proof.challenges, + proof.final_value, + ); + + // Final value == f(r) * g(r). + let (fa, fb) = prover.final_evaluations(); + assert_eq!(proof.final_value, fa * fb); + + // Independent MLE check. + assert_eq!(fa, mle_eval(&a, &proof.challenges)); + assert_eq!(fb, mle_eval(&b, &proof.challenges)); +} + +#[test] +fn inner_product_partial_then_continue() { + let num_vars = 8; + let split = 3; + let n = 1 << num_vars; + + let mut r = rng(); + let a: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + // Full. + let mut prover_full = InnerProductProver::new(a.clone(), b.clone()); + let mut trng = rng(); + let mut t_full = SanityTranscript::new(&mut trng); + let full = sumcheck(&mut prover_full, num_vars, &mut t_full, |_, _| {}); + + // Split. + let mut prover = InnerProductProver::new(a, b); + let mut trng2 = rng(); + let mut t_split = SanityTranscript::new(&mut trng2); + let first = sumcheck(&mut prover, split, &mut t_split, |_, _| {}); + let second = sumcheck(&mut prover, num_vars - split, &mut t_split, |_, _| {}); + + let mut combined_polys = first.round_polys.clone(); + combined_polys.extend(second.round_polys.iter().cloned()); + assert_eq!(combined_polys, full.round_polys); + + let mut combined_challenges = first.challenges.clone(); + combined_challenges.extend(second.challenges.iter().copied()); + assert_eq!(combined_challenges, full.challenges); + + assert_eq!(second.final_value, full.final_value); +} diff --git a/tests/inner_product_sumcheck.rs b/tests/inner_product_sumcheck.rs new file mode 100644 index 00000000..6f654bef --- /dev/null +++ b/tests/inner_product_sumcheck.rs @@ -0,0 +1,354 @@ +//! Integration tests for the MSB fused inner-product sumcheck. + +use ark_ff::{AdditiveGroup, Field, UniformRand}; +use ark_std::rand::{rngs::StdRng, SeedableRng}; + +use effsc::tests::F64; +use effsc::transcript::{ProverTranscript, SanityTranscript}; +use effsc::{inner_product_sumcheck, inner_product_sumcheck_partial, ProductSumcheck}; + +const SEED: u64 = 0xA110C8ED; + +fn rng() -> StdRng { + StdRng::seed_from_u64(SEED) +} + +/// Evaluate the multilinear extension of `evals` at `point` with MSB +/// ordering (pop the top half each round). +fn multilinear_extend(evals: &[F], point: &[F]) -> F { + assert_eq!(evals.len(), 1 << point.len()); + let mut current = evals.to_vec(); + for &r in point { + let half = current.len() / 2; + let (low, high) = current.split_at(half); + current = low + .iter() + .zip(high) + .map(|(l, h)| *l + (*h - *l) * r) + .collect(); + } + current[0] +} + +#[test] +fn test_power_of_two_roundtrip() { + let num_vars = 8; + let n = 1 << num_vars; + + let mut r = rng(); + let a_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut prover_rng = rng(); + let mut a = a_orig.clone(); + let mut b = b_orig.clone(); + let mut t_prove = SanityTranscript::new(&mut prover_rng); + let result: ProductSumcheck = + inner_product_sumcheck(&mut a, &mut b, &mut t_prove, |_, _| {}); + + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + assert_eq!(result.prover_messages.len(), num_vars); + assert_eq!(result.verifier_messages.len(), num_vars); + assert_eq!(result.final_evaluations, (a[0], b[0])); + + // Folded values match an independent MLE evaluation at the challenge point. + assert_eq!(multilinear_extend(&a_orig, &result.verifier_messages), a[0]); + assert_eq!(multilinear_extend(&b_orig, &result.verifier_messages), b[0]); +} + +#[test] +fn test_non_power_of_two_partial_runs() { + let initial_size = 13_usize; + let padded = initial_size.next_power_of_two(); + let num_rounds = padded.trailing_zeros() as usize; + + let mut r = rng(); + let a_orig: Vec = (0..initial_size).map(|_| F64::rand(&mut r)).collect(); + let b_orig: Vec = (0..initial_size).map(|_| F64::rand(&mut r)).collect(); + + let mut prover_rng = rng(); + let mut a = a_orig.clone(); + let mut b = b_orig.clone(); + let mut t = SanityTranscript::new(&mut prover_rng); + let result = inner_product_sumcheck_partial(&mut a, &mut b, &mut t, num_rounds, |_, _| {}); + assert_eq!(result.prover_messages.len(), num_rounds); + assert_eq!(result.verifier_messages.len(), num_rounds); + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); +} + +#[test] +fn test_partial_split_matches_full() { + let num_vars = 8; + let n = 1 << num_vars; + let split_at = 3; + + let mut r = rng(); + let a_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut a_full = a_orig.clone(); + let mut b_full = b_orig.clone(); + let mut full_rng = rng(); + let mut t_full = SanityTranscript::new(&mut full_rng); + let full = inner_product_sumcheck(&mut a_full, &mut b_full, &mut t_full, |_, _| {}); + + let mut a = a_orig.clone(); + let mut b = b_orig.clone(); + let mut split_rng = rng(); + let mut t_split = SanityTranscript::new(&mut split_rng); + let first = inner_product_sumcheck_partial(&mut a, &mut b, &mut t_split, split_at, |_, _| {}); + let second = inner_product_sumcheck_partial( + &mut a, + &mut b, + &mut t_split, + num_vars - split_at, + |_, _| {}, + ); + + let mut split_prover = first.prover_messages.clone(); + split_prover.extend(second.prover_messages.iter().copied()); + let mut split_verifier = first.verifier_messages.clone(); + split_verifier.extend(second.verifier_messages.iter().copied()); + + assert_eq!(split_prover, full.prover_messages); + assert_eq!(split_verifier, full.verifier_messages); + assert_eq!(second.final_evaluations, full.final_evaluations); + assert_eq!(first.final_evaluations, (F64::ZERO, F64::ZERO)); +} + +#[test] +fn test_hook_called_once_per_round() { + use std::cell::RefCell; + let num_vars = 6; + let n = 1 << num_vars; + + let mut r = rng(); + let mut a: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let mut b: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + + let calls = RefCell::new(Vec::::new()); + let result = inner_product_sumcheck(&mut a, &mut b, &mut t, |round, _| { + calls.borrow_mut().push(round); + }); + assert_eq!(result.prover_messages.len(), num_vars); + assert_eq!(calls.into_inner(), (0..num_vars).collect::>()); +} + +#[test] +fn test_zero_rounds_is_identity() { + let mut r = rng(); + let a_orig: Vec = (0..8).map(|_| F64::rand(&mut r)).collect(); + let b_orig: Vec = (0..8).map(|_| F64::rand(&mut r)).collect(); + let mut a = a_orig.clone(); + let mut b = b_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + + let result = inner_product_sumcheck_partial(&mut a, &mut b, &mut t, 0, |_, _| {}); + assert!(result.prover_messages.is_empty()); + assert!(result.verifier_messages.is_empty()); + assert_eq!(a, a_orig); + assert_eq!(b, b_orig); +} + +#[test] +fn test_prover_msg_is_difference_form() { + // Round-0 (c0, c2) in difference form: + // c0 = Σ a_lo · b_lo (= q(0)) + // c2 = Σ (a_hi − a_lo)·(b_hi − b_lo) (= [x²] q(x)) + let n = 16_usize; + let mut r = rng(); + let a: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut a_mut = a.clone(); + let mut b_mut = b.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let result = inner_product_sumcheck_partial(&mut a_mut, &mut b_mut, &mut t, 1, |_, _| {}); + let (c0, c2) = result.prover_messages[0]; + + let half = n / 2; + let expected_c0: F64 = a[..half].iter().zip(&b[..half]).map(|(x, y)| *x * *y).sum(); + let expected_c2: F64 = a[..half] + .iter() + .zip(&a[half..]) + .zip(b[..half].iter().zip(&b[half..])) + .map(|((a0, a1), (b0, b1))| (*a1 - *a0) * (*b1 - *b0)) + .sum(); + assert_eq!(c0, expected_c0); + assert_eq!(c2, expected_c2); +} + +#[test] +fn test_deterministic_under_same_seed() { + let n = 1 << 5; + let mut r = rng(); + let a_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let run = || -> _ { + let mut a = a_orig.clone(); + let mut b = b_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + inner_product_sumcheck(&mut a, &mut b, &mut t, |_, _| {}) + }; + let r1 = run(); + let r2 = run(); + assert_eq!(r1.prover_messages, r2.prover_messages); + assert_eq!(r1.verifier_messages, r2.verifier_messages); + assert_eq!(r1.final_evaluations, r2.final_evaluations); +} + +/// Reference unfused half-split prover. Runs the protocol by folding then +/// computing each round with plain scalar loops. Transcript must match the +/// fused path bit-for-bit. +fn reference_unfused(a_orig: &[F64], b_orig: &[F64]) -> ProductSumcheck { + let mut a = a_orig.to_vec(); + let mut b = b_orig.to_vec(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let num_rounds = if a.is_empty() { + 0 + } else { + a.len().next_power_of_two().trailing_zeros() as usize + }; + + let mut prover_messages = Vec::with_capacity(num_rounds); + let mut verifier_messages = Vec::with_capacity(num_rounds); + let mut w: Option = None; + + for _ in 0..num_rounds { + if let Some(weight) = w { + fold_in_place(&mut a, weight); + fold_in_place(&mut b, weight); + } + let (c0, c2) = compute_ref(&a, &b); + prover_messages.push((c0, c2)); + t.send(c0); + t.send(c2); + let r: F64 = t.challenge(); + verifier_messages.push(r); + w = Some(r); + } + if let Some(weight) = w { + fold_in_place(&mut a, weight); + fold_in_place(&mut b, weight); + } + + let final_evaluations = if a.len() == 1 { + (a[0], b[0]) + } else { + (F64::ZERO, F64::ZERO) + }; + ProductSumcheck { + prover_messages, + verifier_messages, + final_evaluations, + } +} + +fn fold_in_place(values: &mut Vec, weight: F) { + if values.len() <= 1 { + return; + } + let half = values.len().next_power_of_two() >> 1; + let (low, high) = values.split_at_mut(half); + let (low, tail) = low.split_at_mut(high.len()); + for (lo, hi) in low.iter_mut().zip(high.iter()) { + *lo += (*hi - *lo) * weight; + } + for x in tail.iter_mut() { + *x *= F::ONE - weight; + } + values.truncate(half); +} + +fn compute_ref(a: &[F], b: &[F]) -> (F, F) { + let non_padded = a.len().min(b.len()); + let a = &a[..non_padded]; + let b = &b[..non_padded]; + if a.is_empty() { + return (F::ZERO, F::ZERO); + } + if a.len() == 1 { + return (a[0] * b[0], F::ZERO); + } + let half = a.len().next_power_of_two() >> 1; + let (a0, a1) = a.split_at(half); + let (b0, b1) = b.split_at(half); + let (a0, a0_tail) = a0.split_at(a1.len()); + let (b0, b0_tail) = b0.split_at(a1.len()); + let mut c0 = F::ZERO; + let mut c2 = F::ZERO; + for ((&x0, &x1), (&y0, &y1)) in a0.iter().zip(a1).zip(b0.iter().zip(b1)) { + c0 += x0 * y0; + c2 += (x1 - x0) * (y1 - y0); + } + let tail: F = a0_tail.iter().zip(b0_tail).map(|(x, y)| *x * *y).sum(); + (c0 + tail, c2 + tail) +} + +#[test] +fn test_fused_matches_unfused_reference_pow2() { + for &num_vars in &[1_usize, 2, 4, 7, 10] { + let n = 1 << num_vars; + let mut r = rng(); + let a_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let ref_result = reference_unfused(&a_orig, &b_orig); + + let mut a = a_orig.clone(); + let mut b = b_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let fused = inner_product_sumcheck(&mut a, &mut b, &mut t, |_, _| {}); + + assert_eq!(fused.prover_messages, ref_result.prover_messages, "n={n}"); + assert_eq!( + fused.verifier_messages, ref_result.verifier_messages, + "n={n}" + ); + assert_eq!( + fused.final_evaluations, ref_result.final_evaluations, + "n={n}" + ); + } +} + +#[test] +fn test_fused_matches_unfused_reference_non_pow2() { + for &n in &[3_usize, 5, 13, 33, 100] { + let mut r = rng(); + let a_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let b_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let ref_result = reference_unfused(&a_orig, &b_orig); + + let mut a = a_orig.clone(); + let mut b = b_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let fused = inner_product_sumcheck(&mut a, &mut b, &mut t, |_, _| {}); + + assert_eq!(fused.prover_messages, ref_result.prover_messages, "n={n}"); + assert_eq!( + fused.verifier_messages, ref_result.verifier_messages, + "n={n}" + ); + assert_eq!( + fused.final_evaluations, ref_result.final_evaluations, + "n={n}" + ); + } +} + +// Silence unused-import warning when built without tests touching AdditiveGroup. +const _: F64 = ::ZERO; diff --git a/tests/multilinear_sumcheck.rs b/tests/multilinear_sumcheck.rs new file mode 100644 index 00000000..d3f82003 --- /dev/null +++ b/tests/multilinear_sumcheck.rs @@ -0,0 +1,299 @@ +//! Integration tests for the MSB fused multilinear sumcheck. + +use ark_ff::{AdditiveGroup, Field, UniformRand}; +use ark_std::rand::{rngs::StdRng, SeedableRng}; + +use effsc::tests::F64; +use effsc::transcript::{ProverTranscript, SanityTranscript}; +use effsc::{multilinear_sumcheck, multilinear_sumcheck_partial, Sumcheck}; + +const SEED: u64 = 0xA110C8ED; + +fn rng() -> StdRng { + StdRng::seed_from_u64(SEED) +} + +fn multilinear_extend(evals: &[F], point: &[F]) -> F { + assert_eq!(evals.len(), 1 << point.len()); + let mut current = evals.to_vec(); + for &r in point { + let half = current.len() / 2; + let (low, high) = current.split_at(half); + current = low + .iter() + .zip(high) + .map(|(l, h)| *l + (*h - *l) * r) + .collect(); + } + current[0] +} + +#[test] +fn test_power_of_two_roundtrip() { + let num_vars = 8; + let n = 1 << num_vars; + + let mut r = rng(); + let v_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut prover_rng = rng(); + let mut v = v_orig.clone(); + let mut t_prove = SanityTranscript::new(&mut prover_rng); + let result: Sumcheck = multilinear_sumcheck(&mut v, &mut t_prove, |_, _| {}); + + assert_eq!(v.len(), 1); + assert_eq!(result.prover_messages.len(), num_vars); + assert_eq!(result.verifier_messages.len(), num_vars); + assert_eq!(result.final_evaluation, v[0]); + + // Folded value matches an independent MLE evaluation. + assert_eq!(multilinear_extend(&v_orig, &result.verifier_messages), v[0]); + + // Round-0 consistency: s0 + s1 == Σ v. + let claim: F64 = v_orig.iter().copied().sum(); + let (s0, s1) = result.prover_messages[0]; + assert_eq!(s0 + s1, claim); +} + +#[test] +fn test_non_power_of_two_partial_runs() { + let initial_size = 13_usize; + let padded = initial_size.next_power_of_two(); + let num_rounds = padded.trailing_zeros() as usize; + + let mut r = rng(); + let v_orig: Vec = (0..initial_size).map(|_| F64::rand(&mut r)).collect(); + + let mut prover_rng = rng(); + let mut v = v_orig.clone(); + let mut t = SanityTranscript::new(&mut prover_rng); + let result = multilinear_sumcheck_partial(&mut v, &mut t, num_rounds, |_, _| {}); + assert_eq!(result.prover_messages.len(), num_rounds); + assert_eq!(result.verifier_messages.len(), num_rounds); + assert_eq!(v.len(), 1); +} + +#[test] +fn test_partial_split_matches_full() { + let num_vars = 8; + let n = 1 << num_vars; + let split_at = 3; + + let mut r = rng(); + let v_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut v_full = v_orig.clone(); + let mut full_rng = rng(); + let mut t_full = SanityTranscript::new(&mut full_rng); + let full = multilinear_sumcheck(&mut v_full, &mut t_full, |_, _| {}); + + let mut v = v_orig.clone(); + let mut split_rng = rng(); + let mut t_split = SanityTranscript::new(&mut split_rng); + let first = multilinear_sumcheck_partial(&mut v, &mut t_split, split_at, |_, _| {}); + let second = multilinear_sumcheck_partial(&mut v, &mut t_split, num_vars - split_at, |_, _| {}); + + let mut split_prover = first.prover_messages.clone(); + split_prover.extend(second.prover_messages.iter().copied()); + let mut split_verifier = first.verifier_messages.clone(); + split_verifier.extend(second.verifier_messages.iter().copied()); + + assert_eq!(split_prover, full.prover_messages); + assert_eq!(split_verifier, full.verifier_messages); + assert_eq!(second.final_evaluation, full.final_evaluation); + assert_eq!(first.final_evaluation, F64::ZERO); +} + +#[test] +fn test_hook_called_once_per_round() { + use std::cell::RefCell; + let num_vars = 6; + let n = 1 << num_vars; + + let mut r = rng(); + let mut v: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + + let calls = RefCell::new(Vec::::new()); + let result = multilinear_sumcheck(&mut v, &mut t, |round, _| { + calls.borrow_mut().push(round); + }); + assert_eq!(result.prover_messages.len(), num_vars); + assert_eq!(calls.into_inner(), (0..num_vars).collect::>()); +} + +#[test] +fn test_zero_rounds_is_identity() { + let mut r = rng(); + let v_orig: Vec = (0..8).map(|_| F64::rand(&mut r)).collect(); + let mut v = v_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + + let result = multilinear_sumcheck_partial(&mut v, &mut t, 0, |_, _| {}); + assert!(result.prover_messages.is_empty()); + assert!(result.verifier_messages.is_empty()); + assert_eq!(v, v_orig); +} + +#[test] +fn test_round0_msg_is_half_sums() { + let n = 16_usize; + let mut r = rng(); + let v: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let mut v_mut = v.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let result = multilinear_sumcheck_partial(&mut v_mut, &mut t, 1, |_, _| {}); + let (s0, s1) = result.prover_messages[0]; + + let half = n / 2; + let expected_s0: F64 = v[..half].iter().copied().sum(); + let expected_s1: F64 = v[half..].iter().copied().sum(); + assert_eq!(s0, expected_s0); + assert_eq!(s1, expected_s1); +} + +#[test] +fn test_deterministic_under_same_seed() { + let n = 1 << 5; + let mut r = rng(); + let v_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let run = || -> _ { + let mut v = v_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + multilinear_sumcheck(&mut v, &mut t, |_, _| {}) + }; + let r1 = run(); + let r2 = run(); + assert_eq!(r1.prover_messages, r2.prover_messages); + assert_eq!(r1.verifier_messages, r2.verifier_messages); + assert_eq!(r1.final_evaluation, r2.final_evaluation); +} + +/// Reference: unfused half-split prover. Runs fold then compute each round. +fn reference_unfused(v_orig: &[F64]) -> Sumcheck { + let mut v = v_orig.to_vec(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let num_rounds = if v.is_empty() { + 0 + } else { + v.len().next_power_of_two().trailing_zeros() as usize + }; + + let mut prover_messages = Vec::with_capacity(num_rounds); + let mut verifier_messages = Vec::with_capacity(num_rounds); + let mut w: Option = None; + + for _ in 0..num_rounds { + if let Some(weight) = w { + fold_in_place(&mut v, weight); + } + let (s0, s1) = compute_ref(&v); + prover_messages.push((s0, s1)); + t.send(s0); + t.send(s1); + let r: F64 = t.challenge(); + verifier_messages.push(r); + w = Some(r); + } + if let Some(weight) = w { + fold_in_place(&mut v, weight); + } + + let final_evaluation = if v.len() == 1 { v[0] } else { F64::ZERO }; + Sumcheck { + prover_messages, + verifier_messages, + final_evaluation, + } +} + +fn fold_in_place(values: &mut Vec, weight: F) { + if values.len() <= 1 { + return; + } + let half = values.len().next_power_of_two() >> 1; + let (low, high) = values.split_at_mut(half); + let (low, tail) = low.split_at_mut(high.len()); + for (lo, hi) in low.iter_mut().zip(high.iter()) { + *lo += (*hi - *lo) * weight; + } + for x in tail.iter_mut() { + *x *= F::ONE - weight; + } + values.truncate(half); +} + +fn compute_ref(values: &[F]) -> (F, F) { + if values.is_empty() { + return (F::ZERO, F::ZERO); + } + if values.len() == 1 { + return (values[0], F::ZERO); + } + let half = values.len().next_power_of_two() >> 1; + let (lo, hi) = values.split_at(half); + let (lo, lo_tail) = lo.split_at(hi.len()); + let mut s0 = F::ZERO; + let mut s1 = F::ZERO; + for (&l, &h) in lo.iter().zip(hi) { + s0 += l; + s1 += h; + } + let tail: F = lo_tail.iter().copied().sum(); + (s0 + tail, s1) +} + +#[test] +fn test_fused_matches_unfused_reference_pow2() { + for &num_vars in &[1_usize, 2, 4, 7, 10] { + let n = 1 << num_vars; + let mut r = rng(); + let v_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let ref_result = reference_unfused(&v_orig); + + let mut v = v_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let fused = multilinear_sumcheck(&mut v, &mut t, |_, _| {}); + + assert_eq!(fused.prover_messages, ref_result.prover_messages, "n={n}"); + assert_eq!( + fused.verifier_messages, ref_result.verifier_messages, + "n={n}" + ); + assert_eq!(fused.final_evaluation, ref_result.final_evaluation, "n={n}"); + } +} + +#[test] +fn test_fused_matches_unfused_reference_non_pow2() { + for &n in &[3_usize, 5, 13, 33, 100] { + let mut r = rng(); + let v_orig: Vec = (0..n).map(|_| F64::rand(&mut r)).collect(); + + let ref_result = reference_unfused(&v_orig); + + let mut v = v_orig.clone(); + let mut trng = rng(); + let mut t = SanityTranscript::new(&mut trng); + let fused = multilinear_sumcheck(&mut v, &mut t, |_, _| {}); + + assert_eq!(fused.prover_messages, ref_result.prover_messages, "n={n}"); + assert_eq!( + fused.verifier_messages, ref_result.verifier_messages, + "n={n}" + ); + assert_eq!(fused.final_evaluation, ref_result.final_evaluation, "n={n}"); + } +} + +// Silence unused-import warning when built without tests touching AdditiveGroup. +const _: F64 = ::ZERO; diff --git a/tests/plonky3_roundtrip.rs b/tests/plonky3_roundtrip.rs new file mode 100644 index 00000000..7dbfb87f --- /dev/null +++ b/tests/plonky3_roundtrip.rs @@ -0,0 +1,181 @@ +//! Roundtrip test demonstrating effsc with a Plonky3 field. +//! +//! Shows that any ecosystem's field type works with the sumcheck library +//! via a thin `SumcheckField` impl — no arkworks dependency required. + +use effsc::field::SumcheckField; +use effsc::noop_hook; +use effsc::provers::multilinear_lsb::MultilinearProverLSB; +use effsc::runner::sumcheck; +use effsc::transcript::{ProverTranscript, VerifierTranscript}; +use effsc::verifier::sumcheck_verify; + +use p3_field::integers::QuotientMap; +use p3_field::Field; +use p3_goldilocks::Goldilocks; + +use core::fmt; +use core::iter::Sum; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +// ─── Newtype wrapper ─────────────────────────────────────────────────────── + +/// Thin wrapper around Plonky3's Goldilocks to implement `SumcheckField`. +#[derive(Copy, Clone, Debug, PartialEq)] +struct P3Goldilocks(Goldilocks); + +impl P3Goldilocks { + fn new(val: u64) -> Self { + Self(Goldilocks::from_int(val)) + } +} + +impl fmt::Display for P3Goldilocks { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl Add for P3Goldilocks { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0) + } +} + +impl Sub for P3Goldilocks { + type Output = Self; + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0) + } +} + +impl Mul for P3Goldilocks { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + Self(self.0 * rhs.0) + } +} + +impl Neg for P3Goldilocks { + type Output = Self; + fn neg(self) -> Self { + Self(-self.0) + } +} + +impl AddAssign for P3Goldilocks { + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + } +} + +impl SubAssign for P3Goldilocks { + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + } +} + +impl MulAssign for P3Goldilocks { + fn mul_assign(&mut self, rhs: Self) { + self.0 *= rhs.0; + } +} + +impl Sum for P3Goldilocks { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |acc, x| acc + x) + } +} + +impl SumcheckField for P3Goldilocks { + const ZERO: Self = Self(Goldilocks::new(0)); + const ONE: Self = Self(Goldilocks::new(1)); + + fn from_u64(val: u64) -> Self { + Self(Goldilocks::from_int(val)) + } + + fn inverse(&self) -> Option { + self.0.try_inverse().map(Self) + } +} + +// ─── Minimal transcript ──────────────────────────────────────────────────── + +/// Deterministic transcript for the roundtrip test. +struct CounterTranscript { + counter: u64, + tape: Vec, + cursor: usize, +} + +impl CounterTranscript { + fn new() -> Self { + Self { + counter: 1, + tape: Vec::new(), + cursor: 0, + } + } + + fn rewind(&mut self) { + self.cursor = 0; + } +} + +impl ProverTranscript for CounterTranscript { + fn send(&mut self, value: P3Goldilocks) { + self.tape.push(value); + } + + fn challenge(&mut self) -> P3Goldilocks { + self.counter += 1; + let c = P3Goldilocks::new(self.counter); + self.tape.push(c); + c + } +} + +impl VerifierTranscript for CounterTranscript { + type Error = core::convert::Infallible; + + fn receive(&mut self) -> Result { + let v = self.tape[self.cursor]; + self.cursor += 1; + Ok(v) + } + + fn challenge(&mut self) -> P3Goldilocks { + let v = self.tape[self.cursor]; + self.cursor += 1; + v + } +} + +// ─── Test ────────────────────────────────────────────────────────────────── + +#[test] +fn plonky3_multilinear_roundtrip() { + let num_vars = 4; + let n = 1 << num_vars; + + // Create evaluations: f(x) = x + 1 for x in 0..16. + let evals: Vec = (0..n).map(|i| P3Goldilocks::new(i as u64 + 1)).collect(); + let claimed_sum: P3Goldilocks = evals.iter().copied().sum(); + + // Prove. + let mut prover = MultilinearProverLSB::new(evals); + let mut t = CounterTranscript::new(); + let proof = sumcheck(&mut prover, num_vars, &mut t, noop_hook); + + assert_eq!(proof.round_polys.len(), num_vars); + assert_eq!(proof.challenges.len(), num_vars); + + // Verify. + t.rewind(); + let result = sumcheck_verify(claimed_sum, 1, num_vars, &mut t, |_, _| Ok(())) + .expect("verification should pass"); + + assert_eq!(result.final_claim, proof.final_value); +}