diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 1cae3847ca..c22c51035e 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -330,6 +330,16 @@ impl Air for NativeSumcheckAir { ) .eval(builder, header_row); + // Read max_round + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[0]), + [max_round], + first_timestamp + AB::F::from_canonical_usize(7), + &header_row_specific.read_records[7], + ) + .eval(builder, header_row); + // Write final result self.memory_bridge .write( @@ -348,20 +358,6 @@ impl Air for NativeSumcheckAir { let next_prod_row_specific: &ProdSpecificCols = next.specific[..ProdSpecificCols::::width()].borrow(); - self.memory_bridge - .read( - MemoryAddress::new( - native_as, - register_ptrs[0] - + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) - + (curr_prod_n - AB::F::ONE), - ), // curr_prod_n starts at 1. - [max_round], - start_timestamp, - &prod_row_specific.read_records[0], - ) - .eval(builder, prod_row); - // prod_row * within_round_limit = // prod_in_round_evaluation + prod_next_round_evaluation builder @@ -385,8 +381,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), prod_row_specific.p, - start_timestamp + AB::F::ONE, - &prod_row_specific.read_records[1], + start_timestamp, + &prod_row_specific.read_records[0], ) .eval(builder, prod_row * within_round_limit); @@ -402,7 +398,7 @@ impl Air for NativeSumcheckAir { register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + AB::F::TWO, + start_timestamp + AB::F::ONE, &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); @@ -449,21 +445,6 @@ impl Air for NativeSumcheckAir { let next_logup_row_specfic: &LogupSpecificCols = next.specific[..LogupSpecificCols::::width()].borrow(); - self.memory_bridge - .read( - MemoryAddress::new( - native_as, - register_ptrs[0] - + AB::F::from_canonical_usize(EXT_DEG * 2) - + num_prod_spec - + (curr_logup_n - AB::F::ONE), - ), // curr_logup_n starts at 1. - [max_round], - start_timestamp, - &logup_row_specific.read_records[0], - ) - .eval(builder, logup_row); - // logup_row * within_round_limit = // logup_in_round_evaluation + logup_next_round_evaluation builder @@ -488,8 +469,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), logup_row_specific.pq, - start_timestamp + AB::F::ONE, - &logup_row_specific.read_records[1], + start_timestamp, + &logup_row_specific.read_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -513,7 +494,7 @@ impl Air for NativeSumcheckAir { + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + AB::F::TWO, + start_timestamp + AB::F::ONE, &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -528,7 +509,7 @@ impl Air for NativeSumcheckAir { * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + AB::F::from_canonical_usize(3), + start_timestamp + AB::F::TWO, &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index bb7cfa7080..873ba2de5a 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -144,7 +144,6 @@ where let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32()); let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()) .map(|x: F| x.as_canonical_u32()); - let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = ctx; // allocate n rows @@ -198,19 +197,22 @@ where r_evals_reg.as_canonical_u32(), head_specific.read_records[4].as_mut(), ); - let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32(), head_specific.read_records[5].as_mut(), ); - let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, challenges_ptr.as_canonical_u32(), head_specific.read_records[6].as_mut(), ); - cur_timestamp += 7; // 5 register reads + ctx read + challenges read + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, + head_specific.read_records[7].as_mut() + ); + cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); // challenges = [alpha, c1=r, c2=1-r] @@ -221,7 +223,7 @@ where let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - // all rows share same register values, ctx, challenges + // all rows share same register values, ctx, challenges, max_round for row in rows.iter_mut() { // c1, c2 are same during the entire execution row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); @@ -236,6 +238,7 @@ where row.register_ptrs[2] = prod_evals_ptr; row.register_ptrs[3] = logup_evals_ptr; row.register_ptrs[4] = r_evals_ptr; + row.max_round = max_round; } // product rows @@ -256,15 +259,6 @@ where }; prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1 prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp); - - // read max_round - let [max_round]: [F; 1] = tracing_read_native_helper( - state.memory, - ctx_ptr.as_canonical_u32() + (CONTEXT_ARR_BASE_LEN + i) as u32, - prod_specific.read_records[0].as_mut(), - ); - cur_timestamp += 1; - prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); prod_row.max_round = max_round; @@ -285,7 +279,7 @@ where let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( state.memory, prod_evals_ptr.as_canonical_u32() + start, - prod_specific.read_records[1].as_mut(), + prod_specific.read_records[0].as_mut(), ); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); @@ -350,15 +344,7 @@ where }; logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1 logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp); - - let [max_round]: [F; 1] = tracing_read_native_helper( - state.memory, - ctx_ptr.as_canonical_u32() + num_prod_spec + (CONTEXT_ARR_BASE_LEN + i) as u32, - logup_specific.read_records[0].as_mut(), - ); logup_row.max_round = max_round; - cur_timestamp += 1; - let alpha_numerator = alpha_acc; let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); @@ -380,7 +366,7 @@ where let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, logup_evals_ptr.as_canonical_u32() + start, - logup_specific.read_records[1].as_mut(), + logup_specific.read_records[0].as_mut(), ); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); @@ -545,7 +531,7 @@ impl TraceFiller for NativeSumcheckFiller { mem_fill_helper( mem_helper, start_timestamp + 1, - prod_row_specific.read_records[1].as_mut(), + prod_row_specific.read_records[0].as_mut(), ); // write p_eval mem_fill_helper( @@ -569,7 +555,7 @@ impl TraceFiller for NativeSumcheckFiller { mem_fill_helper( mem_helper, start_timestamp + 1, - logup_row_specific.read_records[1].as_mut(), + logup_row_specific.read_records[0].as_mut(), ); // write p_eval mem_fill_helper( diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index b3e6bf4f25..51eb6d39cf 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -92,8 +92,8 @@ pub struct NativeSumcheckCols { pub struct HeaderSpecificCols { pub pc: T, pub registers: [T; 5], - /// 5 register reads + ctx read + challenges read - pub read_records: [MemoryReadAuxCols; 7], + /// 5 register reads + ctx read + max round read + challenges read + pub read_records: [MemoryReadAuxCols; 8], /// Write the final evaluation pub write_records: MemoryWriteAuxCols, } @@ -105,8 +105,8 @@ pub struct ProdSpecificCols { pub data_ptr: T, /// 2 extension elements pub p: [T; EXT_DEG * 2], - /// read max varibale and 2 p values - pub read_records: [MemoryReadAuxCols; 2], + /// read 2 p values + pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// write p_evals @@ -122,8 +122,8 @@ pub struct LogupSpecificCols { pub data_ptr: T, /// 4 extension elements pub pq: [T; EXT_DEG * 4], - /// read max variable and 4 values: p1, p2, q1, q2 - pub read_records: [MemoryReadAuxCols; 2], + /// read 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// Calculated q evals diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index a475bf9e49..9eb51a3987 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -214,6 +214,8 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); @@ -224,10 +226,6 @@ unsafe fn execute_e12_impl( let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32; for i in 0..num_prod_spec { - let [max_round]: [u32; 1] = exec_state - .vm_read(NATIVE_AS, prod_offset + i) - .map(|x: F| x.as_canonical_u32()); - let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, prod_specs_inner_len, @@ -266,10 +264,6 @@ unsafe fn execute_e12_impl( let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec; for i in 0..num_logup_spec { - // read max_round - let [max_round]: [u32; 1] = exec_state - .vm_read(NATIVE_AS, logup_offset + i) - .map(|x: F| x.as_canonical_u32()); let start = calculate_3d_ext_idx( logup_specs_inner_inner_len, logup_specs_inner_len,