diff --git a/crates/air/benches/constraint_bench.rs b/crates/air/benches/constraint_bench.rs index f122f92..12156c7 100644 --- a/crates/air/benches/constraint_bench.rs +++ b/crates/air/benches/constraint_bench.rs @@ -8,6 +8,10 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, Benchmark use zp1_primitives::M31; use zp1_air::rv32im::{CpuTraceRow, ConstraintEvaluator}; +fn sum_constraints(constraints: Vec) -> M31 { + constraints.into_iter().fold(M31::ZERO, |acc, c| acc + c) +} + /// Create a test row for bitwise AND operation. fn create_and_row() -> CpuTraceRow { let mut row = CpuTraceRow::default(); @@ -107,7 +111,7 @@ fn bench_and_bit_based(c: &mut Criterion) { let row = create_and_row(); c.bench_function("AND_bit_based", |b| { - b.iter(|| ConstraintEvaluator::and_constraint(black_box(&row))) + b.iter(|| sum_constraints(ConstraintEvaluator::and_constraint(black_box(&row)))) }); } @@ -123,7 +127,7 @@ fn bench_xor_bit_based(c: &mut Criterion) { let row = create_xor_row(); c.bench_function("XOR_bit_based", |b| { - b.iter(|| ConstraintEvaluator::xor_constraint(black_box(&row))) + b.iter(|| sum_constraints(ConstraintEvaluator::xor_constraint(black_box(&row)))) }); } @@ -142,7 +146,7 @@ fn bench_bitwise_comparison(c: &mut Criterion) { let xor_row = create_xor_row(); group.bench_function("AND/bit_based", |b| { - b.iter(|| ConstraintEvaluator::and_constraint(black_box(&and_row))) + b.iter(|| sum_constraints(ConstraintEvaluator::and_constraint(black_box(&and_row)))) }); group.bench_function("AND/lookup_based", |b| { @@ -150,7 +154,7 @@ fn bench_bitwise_comparison(c: &mut Criterion) { }); group.bench_function("XOR/bit_based", |b| { - b.iter(|| ConstraintEvaluator::xor_constraint(black_box(&xor_row))) + b.iter(|| sum_constraints(ConstraintEvaluator::xor_constraint(black_box(&xor_row)))) }); group.bench_function("XOR/lookup_based", |b| { @@ -174,7 +178,7 @@ fn bench_batch_evaluation(c: &mut Criterion) { b.iter(|| { let mut sum = M31::ZERO; for row in rows.iter() { - sum = sum + ConstraintEvaluator::and_constraint(black_box(row)); + sum = sum + sum_constraints(ConstraintEvaluator::and_constraint(black_box(row))); } sum }) diff --git a/crates/air/src/rv32im.rs b/crates/air/src/rv32im.rs index 8e98498..64d1bc4 100644 --- a/crates/air/src/rv32im.rs +++ b/crates/air/src/rv32im.rs @@ -160,12 +160,12 @@ impl Rv32imAir { Self { constraints } } - + /// Get all constraints. pub fn constraints(&self) -> &[Constraint] { &self.constraints } - + /// Get constraint count. pub fn num_constraints(&self) -> usize { self.constraints.len() @@ -221,6 +221,8 @@ pub struct CpuTraceRow { pub rs2_val_hi: M31, // Immediate value + pub imm_lo: M31, + pub imm_hi: M31, pub imm: M31, // Instruction selectors (one-hot encoded) @@ -321,7 +323,9 @@ impl CpuTraceRow { let two_16 = M31::new(1 << 16); // Recombine split fields - let imm = cols[8] + cols[9] * two_16; + let imm_lo = cols[8]; + let imm_hi = cols[9]; + let imm = imm_lo + imm_hi * two_16; Self { pc: cols[1], @@ -335,6 +339,8 @@ impl CpuTraceRow { rs1_val_hi: cols[13], rs2_val_lo: cols[14], rs2_val_hi: cols[15], + imm_lo, + imm_hi, imm, is_add: cols[16], @@ -506,7 +512,7 @@ impl ConstraintEvaluator { // For now, assume pre-processing ensures x0 writes are NOPs M31::ZERO } - + /// PC increment for sequential instructions. #[inline] pub fn pc_increment(row: &CpuTraceRow) -> M31 { @@ -517,7 +523,7 @@ impl ConstraintEvaluator { is_sequential * (row.next_pc - row.pc - four) } - + /// ADD: rd = rs1 + rs2. #[inline] pub fn add_constraint(row: &CpuTraceRow) -> (M31, M31) { @@ -554,137 +560,81 @@ impl ConstraintEvaluator { (c1, c2) } - /// AND: rd = rs1 & rs2. - /// Uses bit decomposition with 4 verification steps: - /// 1. rs1 = sum(rs1_bits[i] * 2^i) - /// 2. rs2 = sum(rs2_bits[i] * 2^i) - /// 3. and_bits[i] = rs1_bits[i] * rs2_bits[i] (boolean AND) - /// 4. rd = sum(and_bits[i] * 2^i) #[inline] - pub fn and_constraint(row: &CpuTraceRow) -> M31 { - if row.is_and == M31::ZERO { - return M31::ZERO; + fn reconstruct_16_bits(bits: &[M31; 32], start: usize) -> M31 { + let mut reconstructed = M31::ZERO; + let mut pow = M31::ONE; + for bit in bits.iter().skip(start).take(16) { + reconstructed += *bit * pow; + pow += pow; } + reconstructed + } - let two_16 = M31::new(1 << 16); - let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; - let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; - let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - - let mut rs1_reconstructed = M31::ZERO; - let mut rs2_reconstructed = M31::ZERO; - let mut rd_reconstructed = M31::ZERO; - let mut and_check = M31::ZERO; - - for i in 0..31 { // First 31 bits (fit in M31) - let pow2 = M31::new(1 << i); - rs1_reconstructed += row.rs1_bits[i] * pow2; - rs2_reconstructed += row.rs2_bits[i] * pow2; - rd_reconstructed += row.and_bits[i] * pow2; - // AND logic: and_bits[i] = rs1_bits[i] * rs2_bits[i] - and_check += row.and_bits[i] - row.rs1_bits[i] * row.rs2_bits[i]; + #[inline] + fn push_word_bit_constraints(selector: M31, lo: M31, hi: M31, bits: &[M31; 32], constraints: &mut Vec) { + constraints.push(selector * (lo - Self::reconstruct_16_bits(bits, 0))); + constraints.push(selector * (hi - Self::reconstruct_16_bits(bits, 16))); + for bit in bits { + constraints.push(selector * *bit * (*bit - M31::ONE)); } - // Bit 31 separately to handle field overflow - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2); - and_check += row.and_bits[31] - row.rs1_bits[31] * row.rs2_bits[31]; - - // All 4 checks in one constraint: - // (rs1 reconstruction) + (rs2 reconstruction) + (AND logic) + (rd reconstruction) - row.is_and * ( - (rs1_full - rs1_reconstructed) + - (rs2_full - rs2_reconstructed) + - and_check + - (rd_full - rd_reconstructed) - ) } - - /// OR: rd = rs1 | rs2. - /// Uses bit decomposition: or_bit[i] = rs1_bit[i] + rs2_bit[i] - rs1_bit[i]*rs2_bit[i]. + #[inline] - pub fn or_constraint(row: &CpuTraceRow) -> M31 { - if row.is_or == M31::ZERO { - return M31::ZERO; + fn push_and_logic_constraints(selector: M31, lhs: &[M31; 32], rhs: &[M31; 32], out: &[M31; 32], constraints: &mut Vec) { + for i in 0..32 { + constraints.push(selector * (out[i] - lhs[i] * rhs[i])); } + } - let two_16 = M31::new(1 << 16); - let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; - let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; - let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - - let mut rs1_reconstructed = M31::ZERO; - let mut rs2_reconstructed = M31::ZERO; - let mut rd_reconstructed = M31::ZERO; - let mut or_check = M31::ZERO; - - for i in 0..31 { - let pow2 = M31::new(1 << i); - rs1_reconstructed += row.rs1_bits[i] * pow2; - rs2_reconstructed += row.rs2_bits[i] * pow2; - rd_reconstructed += row.or_bits[i] * pow2; - // OR logic: or_bit = a + b - ab - let expected_or = row.rs1_bits[i] + row.rs2_bits[i] - row.rs1_bits[i] * row.rs2_bits[i]; - or_check += row.or_bits[i] - expected_or; + #[inline] + fn push_or_logic_constraints(selector: M31, lhs: &[M31; 32], rhs: &[M31; 32], out: &[M31; 32], constraints: &mut Vec) { + for i in 0..32 { + let expected = lhs[i] + rhs[i] - lhs[i] * rhs[i]; + constraints.push(selector * (out[i] - expected)); } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2); - let expected_or = row.rs1_bits[31] + row.rs2_bits[31] - row.rs1_bits[31] * row.rs2_bits[31]; - or_check += row.or_bits[31] - expected_or; - - row.is_or * ( - (rs1_full - rs1_reconstructed) + - (rs2_full - rs2_reconstructed) + - or_check + - (rd_full - rd_reconstructed) - ) } - - /// XOR: rd = rs1 ^ rs2. - /// Uses bit decomposition: xor_bit[i] = rs1_bit[i] + rs2_bit[i] - 2*rs1_bit[i]*rs2_bit[i]. + #[inline] - pub fn xor_constraint(row: &CpuTraceRow) -> M31 { - if row.is_xor == M31::ZERO { - return M31::ZERO; + fn push_xor_logic_constraints(selector: M31, lhs: &[M31; 32], rhs: &[M31; 32], out: &[M31; 32], constraints: &mut Vec) { + for i in 0..32 { + let expected = lhs[i] + rhs[i] - M31::new(2) * lhs[i] * rhs[i]; + constraints.push(selector * (out[i] - expected)); } + } - let two_16 = M31::new(1 << 16); - let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; - let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; - let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - - let mut rs1_reconstructed = M31::ZERO; - let mut rs2_reconstructed = M31::ZERO; - let mut rd_reconstructed = M31::ZERO; - let mut xor_check = M31::ZERO; - - for i in 0..31 { - let pow2 = M31::new(1 << i); - rs1_reconstructed += row.rs1_bits[i] * pow2; - rs2_reconstructed += row.rs2_bits[i] * pow2; - rd_reconstructed += row.xor_bits[i] * pow2; - // XOR logic: xor_bit = a + b - 2ab - let expected_xor = row.rs1_bits[i] + row.rs2_bits[i] - M31::new(2) * row.rs1_bits[i] * row.rs2_bits[i]; - xor_check += row.xor_bits[i] - expected_xor; - } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2); - let expected_xor = row.rs1_bits[31] + row.rs2_bits[31] - M31::new(2) * row.rs1_bits[31] * row.rs2_bits[31]; - xor_check += row.xor_bits[31] - expected_xor; - - row.is_xor * ( - (rs1_full - rs1_reconstructed) + - (rs2_full - rs2_reconstructed) + - xor_check + - (rd_full - rd_reconstructed) - ) + /// AND: rd = rs1 & rs2. + /// Binds bit witnesses to low/high 16-bit limbs to avoid 32-bit/M31 collisions. + #[inline] + pub fn and_constraint(row: &CpuTraceRow) -> Vec { + let mut constraints = Vec::with_capacity(134); + Self::push_word_bit_constraints(row.is_and, row.rs1_val_lo, row.rs1_val_hi, &row.rs1_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_and, row.rs2_val_lo, row.rs2_val_hi, &row.rs2_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_and, row.rd_val_lo, row.rd_val_hi, &row.and_bits, &mut constraints); + Self::push_and_logic_constraints(row.is_and, &row.rs1_bits, &row.rs2_bits, &row.and_bits, &mut constraints); + constraints + } + + /// OR: rd = rs1 | rs2. + #[inline] + pub fn or_constraint(row: &CpuTraceRow) -> Vec { + let mut constraints = Vec::with_capacity(134); + Self::push_word_bit_constraints(row.is_or, row.rs1_val_lo, row.rs1_val_hi, &row.rs1_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_or, row.rs2_val_lo, row.rs2_val_hi, &row.rs2_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_or, row.rd_val_lo, row.rd_val_hi, &row.or_bits, &mut constraints); + Self::push_or_logic_constraints(row.is_or, &row.rs1_bits, &row.rs2_bits, &row.or_bits, &mut constraints); + constraints + } + + /// XOR: rd = rs1 ^ rs2. + #[inline] + pub fn xor_constraint(row: &CpuTraceRow) -> Vec { + let mut constraints = Vec::with_capacity(134); + Self::push_word_bit_constraints(row.is_xor, row.rs1_val_lo, row.rs1_val_hi, &row.rs1_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_xor, row.rs2_val_lo, row.rs2_val_hi, &row.rs2_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_xor, row.rd_val_lo, row.rd_val_hi, &row.xor_bits, &mut constraints); + Self::push_xor_logic_constraints(row.is_xor, &row.rs1_bits, &row.rs2_bits, &row.xor_bits, &mut constraints); + constraints } // ==================== LOOKUP-BASED CONSTRAINTS ==================== @@ -894,128 +844,37 @@ impl ConstraintEvaluator { /// ANDI: rd = rs1 & imm. /// Uses rs1_bits and imm_bits witnesses for proper verification. #[inline] - pub fn andi_constraint(row: &CpuTraceRow) -> M31 { - if row.is_andi == M31::ZERO { - return M31::ZERO; - } - - let two_16 = M31::new(1 << 16); - let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; - let imm_full = row.imm; - let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - - let mut rs1_reconstructed = M31::ZERO; - let mut imm_reconstructed = M31::ZERO; - let mut rd_reconstructed = M31::ZERO; - let mut and_check = M31::ZERO; - - for i in 0..31 { - let pow2 = M31::new(1 << i); - rs1_reconstructed += row.rs1_bits[i] * pow2; - imm_reconstructed += row.imm_bits[i] * pow2; - rd_reconstructed += row.and_bits[i] * pow2; - // AND logic: and_bits[i] = rs1_bits[i] * imm_bits[i] - and_check += row.and_bits[i] - row.rs1_bits[i] * row.imm_bits[i]; - } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2); - and_check += row.and_bits[31] - row.rs1_bits[31] * row.imm_bits[31]; - - row.is_andi * ( - (rs1_full - rs1_reconstructed) + - (imm_full - imm_reconstructed) + - and_check + - (rd_full - rd_reconstructed) - ) + pub fn andi_constraint(row: &CpuTraceRow) -> Vec { + let mut constraints = Vec::with_capacity(134); + Self::push_word_bit_constraints(row.is_andi, row.rs1_val_lo, row.rs1_val_hi, &row.rs1_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_andi, row.imm_lo, row.imm_hi, &row.imm_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_andi, row.rd_val_lo, row.rd_val_hi, &row.and_bits, &mut constraints); + Self::push_and_logic_constraints(row.is_andi, &row.rs1_bits, &row.imm_bits, &row.and_bits, &mut constraints); + constraints } /// ORI: rd = rs1 | imm. /// Uses rs1_bits and imm_bits witnesses for proper verification. #[inline] - pub fn ori_constraint(row: &CpuTraceRow) -> M31 { - if row.is_ori == M31::ZERO { - return M31::ZERO; - } - - let two_16 = M31::new(1 << 16); - let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; - let imm_full = row.imm; - let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - - let mut rs1_reconstructed = M31::ZERO; - let mut imm_reconstructed = M31::ZERO; - let mut rd_reconstructed = M31::ZERO; - let mut or_check = M31::ZERO; - - for i in 0..31 { - let pow2 = M31::new(1 << i); - rs1_reconstructed += row.rs1_bits[i] * pow2; - imm_reconstructed += row.imm_bits[i] * pow2; - rd_reconstructed += row.or_bits[i] * pow2; - // OR logic: or_bit = a + b - ab - let expected_or = row.rs1_bits[i] + row.imm_bits[i] - row.rs1_bits[i] * row.imm_bits[i]; - or_check += row.or_bits[i] - expected_or; - } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2); - let expected_or = row.rs1_bits[31] + row.imm_bits[31] - row.rs1_bits[31] * row.imm_bits[31]; - or_check += row.or_bits[31] - expected_or; - - row.is_ori * ( - (rs1_full - rs1_reconstructed) + - (imm_full - imm_reconstructed) + - or_check + - (rd_full - rd_reconstructed) - ) + pub fn ori_constraint(row: &CpuTraceRow) -> Vec { + let mut constraints = Vec::with_capacity(134); + Self::push_word_bit_constraints(row.is_ori, row.rs1_val_lo, row.rs1_val_hi, &row.rs1_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_ori, row.imm_lo, row.imm_hi, &row.imm_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_ori, row.rd_val_lo, row.rd_val_hi, &row.or_bits, &mut constraints); + Self::push_or_logic_constraints(row.is_ori, &row.rs1_bits, &row.imm_bits, &row.or_bits, &mut constraints); + constraints } /// XORI: rd = rs1 ^ imm. /// Uses rs1_bits and imm_bits witnesses for proper verification. #[inline] - pub fn xori_constraint(row: &CpuTraceRow) -> M31 { - if row.is_xori == M31::ZERO { - return M31::ZERO; - } - - let two_16 = M31::new(1 << 16); - let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; - let imm_full = row.imm; - let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - - let mut rs1_reconstructed = M31::ZERO; - let mut imm_reconstructed = M31::ZERO; - let mut rd_reconstructed = M31::ZERO; - let mut xor_check = M31::ZERO; - - for i in 0..31 { - let pow2 = M31::new(1 << i); - rs1_reconstructed += row.rs1_bits[i] * pow2; - imm_reconstructed += row.imm_bits[i] * pow2; - rd_reconstructed += row.xor_bits[i] * pow2; - // XOR logic: xor_bit = a + b - 2ab - let expected_xor = row.rs1_bits[i] + row.imm_bits[i] - M31::new(2) * row.rs1_bits[i] * row.imm_bits[i]; - xor_check += row.xor_bits[i] - expected_xor; - } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2); - let expected_xor = row.rs1_bits[31] + row.imm_bits[31] - M31::new(2) * row.rs1_bits[31] * row.imm_bits[31]; - xor_check += row.xor_bits[31] - expected_xor; - - row.is_xori * ( - (rs1_full - rs1_reconstructed) + - (imm_full - imm_reconstructed) + - xor_check + - (rd_full - rd_reconstructed) - ) + pub fn xori_constraint(row: &CpuTraceRow) -> Vec { + let mut constraints = Vec::with_capacity(134); + Self::push_word_bit_constraints(row.is_xori, row.rs1_val_lo, row.rs1_val_hi, &row.rs1_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_xori, row.imm_lo, row.imm_hi, &row.imm_bits, &mut constraints); + Self::push_word_bit_constraints(row.is_xori, row.rd_val_lo, row.rd_val_hi, &row.xor_bits, &mut constraints); + Self::push_xor_logic_constraints(row.is_xori, &row.rs1_bits, &row.imm_bits, &row.xor_bits, &mut constraints); + constraints } /// SLTI: rd = (rs1 < imm) ? 1 : 0 (signed). @@ -1486,9 +1345,9 @@ impl ConstraintEvaluator { constraints.push(sub_c1); constraints.push(sub_c2); - constraints.push(ConstraintEvaluator::and_constraint(row)); - constraints.push(ConstraintEvaluator::or_constraint(row)); - constraints.push(ConstraintEvaluator::xor_constraint(row)); + constraints.extend(ConstraintEvaluator::and_constraint(row)); + constraints.extend(ConstraintEvaluator::or_constraint(row)); + constraints.extend(ConstraintEvaluator::xor_constraint(row)); constraints.push(ConstraintEvaluator::sll_constraint(row)); constraints.push(ConstraintEvaluator::srl_constraint(row)); constraints.push(ConstraintEvaluator::sra_constraint(row)); @@ -1497,9 +1356,9 @@ impl ConstraintEvaluator { constraints.push(ConstraintEvaluator::signed_lt_constraint(row)); constraints.push(ConstraintEvaluator::addi_constraint(row)); - constraints.push(ConstraintEvaluator::andi_constraint(row)); - constraints.push(ConstraintEvaluator::ori_constraint(row)); - constraints.push(ConstraintEvaluator::xori_constraint(row)); + constraints.extend(ConstraintEvaluator::andi_constraint(row)); + constraints.extend(ConstraintEvaluator::ori_constraint(row)); + constraints.extend(ConstraintEvaluator::xori_constraint(row)); constraints.push(ConstraintEvaluator::slti_constraint(row)); constraints.push(ConstraintEvaluator::sltiu_constraint(row)); constraints.push(ConstraintEvaluator::slli_constraint(row)); @@ -1551,6 +1410,31 @@ impl ConstraintEvaluator { #[cfg(test)] mod tests { use super::*; + + fn set_word_limbs(lo: &mut M31, hi: &mut M31, value: u32) { + *lo = M31::new(value & 0xffff); + *hi = M31::new(value >> 16); + } + + fn set_bits(bits: &mut [M31; 32], value: u32) { + for (i, bit) in bits.iter_mut().enumerate() { + *bit = M31::new((value >> i) & 1); + } + } + + fn assert_all_zero(constraints: &[M31]) { + assert!( + constraints.iter().all(|c| *c == M31::ZERO), + "expected all constraints to be zero: {constraints:?}" + ); + } + + fn assert_any_nonzero(constraints: &[M31]) { + assert!( + constraints.iter().any(|c| *c != M31::ZERO), + "expected at least one non-zero constraint" + ); + } #[test] fn test_rv32im_air_creation() { @@ -1614,6 +1498,71 @@ mod tests { assert_eq!(c1, M31::ZERO); assert_eq!(c2, M31::ZERO); } + + #[test] + fn test_xor_constraint_rejects_m31_word_collision() { + let mut row = CpuTraceRow::default(); + row.is_xor = M31::ONE; + + set_word_limbs(&mut row.rs1_val_lo, &mut row.rs1_val_hi, 0x8000_0000); + set_word_limbs(&mut row.rs2_val_lo, &mut row.rs2_val_hi, 0); + set_word_limbs(&mut row.rd_val_lo, &mut row.rd_val_hi, 0x8000_0000); + + // Forged witnesses for bit value 1 collide with 0x8000_0000 when a full + // 32-bit word is reconstructed in M31, because 2^31 == 1 mod p. + row.rs1_bits[0] = M31::ONE; + row.xor_bits[0] = M31::ONE; + + assert_any_nonzero(&ConstraintEvaluator::xor_constraint(&row)); + } + + #[test] + fn test_bitwise_constraints_accept_rv32_high_bit() { + let mut xor_row = CpuTraceRow::default(); + xor_row.is_xor = M31::ONE; + set_word_limbs(&mut xor_row.rs1_val_lo, &mut xor_row.rs1_val_hi, 0x8000_0000); + set_word_limbs(&mut xor_row.rs2_val_lo, &mut xor_row.rs2_val_hi, 0); + set_word_limbs(&mut xor_row.rd_val_lo, &mut xor_row.rd_val_hi, 0x8000_0000); + set_bits(&mut xor_row.rs1_bits, 0x8000_0000); + set_bits(&mut xor_row.rs2_bits, 0); + set_bits(&mut xor_row.xor_bits, 0x8000_0000); + assert_all_zero(&ConstraintEvaluator::xor_constraint(&xor_row)); + + let mut or_row = CpuTraceRow::default(); + or_row.is_or = M31::ONE; + set_word_limbs(&mut or_row.rs1_val_lo, &mut or_row.rs1_val_hi, 0x8000_0000); + set_word_limbs(&mut or_row.rs2_val_lo, &mut or_row.rs2_val_hi, 0x0000_0001); + set_word_limbs(&mut or_row.rd_val_lo, &mut or_row.rd_val_hi, 0x8000_0001); + set_bits(&mut or_row.rs1_bits, 0x8000_0000); + set_bits(&mut or_row.rs2_bits, 0x0000_0001); + set_bits(&mut or_row.or_bits, 0x8000_0001); + assert_all_zero(&ConstraintEvaluator::or_constraint(&or_row)); + + let mut and_row = CpuTraceRow::default(); + and_row.is_and = M31::ONE; + set_word_limbs(&mut and_row.rs1_val_lo, &mut and_row.rs1_val_hi, 0x8000_0001); + set_word_limbs(&mut and_row.rs2_val_lo, &mut and_row.rs2_val_hi, 0x8000_0000); + set_word_limbs(&mut and_row.rd_val_lo, &mut and_row.rd_val_hi, 0x8000_0000); + set_bits(&mut and_row.rs1_bits, 0x8000_0001); + set_bits(&mut and_row.rs2_bits, 0x8000_0000); + set_bits(&mut and_row.and_bits, 0x8000_0000); + assert_all_zero(&ConstraintEvaluator::and_constraint(&and_row)); + } + + #[test] + fn test_xori_constraint_uses_immediate_limbs() { + let mut row = CpuTraceRow::default(); + row.is_xori = M31::ONE; + + set_word_limbs(&mut row.rs1_val_lo, &mut row.rs1_val_hi, 0); + set_word_limbs(&mut row.imm_lo, &mut row.imm_hi, 0xffff_ffff); + set_word_limbs(&mut row.rd_val_lo, &mut row.rd_val_hi, 0xffff_ffff); + set_bits(&mut row.rs1_bits, 0); + set_bits(&mut row.imm_bits, 0xffff_ffff); + set_bits(&mut row.xor_bits, 0xffff_ffff); + + assert_all_zero(&ConstraintEvaluator::xori_constraint(&row)); + } #[test] fn test_lui_constraint() { diff --git a/crates/prover/src/stark.rs b/crates/prover/src/stark.rs index 42d2477..d7a3e65 100644 --- a/crates/prover/src/stark.rs +++ b/crates/prover/src/stark.rs @@ -84,6 +84,12 @@ use crate::{ use zp1_primitives::{M31, QM31, CirclePoint}; use zp1_air::{CpuTraceRow, ConstraintEvaluator as AirConstraintEvaluator}; +fn constraint_alpha_count() -> usize { + // Four explicit boundary constraints, all intra-row AIR constraints, and + // one inter-row PC consistency constraint. + 4 + AirConstraintEvaluator::evaluate_all(&CpuTraceRow::default()).len() + 1 +} + /// Configuration for the STARK prover. #[derive(Clone, Debug)] pub struct StarkConfig { @@ -424,9 +430,8 @@ impl StarkProver { } /// Squeeze random coefficients for combining constraints. - fn squeeze_constraint_alphas(&mut self, num_cols: usize) -> Vec { - // Generate enough alphas for boundary + transition constraints - let num_constraints = num_cols * 2; // boundary + transition per column + fn squeeze_constraint_alphas(&mut self, _num_cols: usize) -> Vec { + let num_constraints = constraint_alpha_count(); (0..num_constraints) .map(|_| self.channel.squeeze_challenge()) .collect() @@ -854,9 +859,8 @@ impl StarkVerifier { // Absorb trace commitment channel.absorb(&proof.trace_commitment); - // Get constraint alphas (must match prover - use same count as num_cols * 2) - let num_cols = proof.ood_values.trace_at_z.len(); - let _constraint_alphas: Vec = (0..num_cols * 2) + // Get constraint alphas (must match prover). + let _constraint_alphas: Vec = (0..constraint_alpha_count()) .map(|_| channel.squeeze_challenge()) .collect();