Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 17 additions & 36 deletions extensions/native/circuit/src/sumcheck/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,16 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
)
.eval(builder, header_row);

// Read max_round
self.memory_bridge
.read(
MemoryAddress::new(native_as, register_ptrs[0]),
[max_round],
first_timestamp + AB::F::from_canonical_usize(7),
&header_row_specific.read_records[7],
)
.eval(builder, header_row);

// Write final result
self.memory_bridge
.write(
Expand All @@ -348,20 +358,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
let next_prod_row_specific: &ProdSpecificCols<AB::Var> =
next.specific[..ProdSpecificCols::<AB::Var>::width()].borrow();

self.memory_bridge
.read(
MemoryAddress::new(
native_as,
register_ptrs[0]
+ AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN)
+ (curr_prod_n - AB::F::ONE),
), // curr_prod_n starts at 1.
[max_round],
start_timestamp,
&prod_row_specific.read_records[0],
)
.eval(builder, prod_row);

// prod_row * within_round_limit =
// prod_in_round_evaluation + prod_next_round_evaluation
builder
Expand All @@ -385,8 +381,8 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
.read(
MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr),
prod_row_specific.p,
start_timestamp + AB::F::ONE,
&prod_row_specific.read_records[1],
start_timestamp,
&prod_row_specific.read_records[0],
)
.eval(builder, prod_row * within_round_limit);

Expand All @@ -402,7 +398,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG),
),
prod_row_specific.p_evals,
start_timestamp + AB::F::TWO,
start_timestamp + AB::F::ONE,
&prod_row_specific.write_record,
)
.eval(builder, prod_row * within_round_limit);
Expand Down Expand Up @@ -449,21 +445,6 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
let next_logup_row_specfic: &LogupSpecificCols<AB::Var> =
next.specific[..LogupSpecificCols::<AB::Var>::width()].borrow();

self.memory_bridge
.read(
MemoryAddress::new(
native_as,
register_ptrs[0]
+ AB::F::from_canonical_usize(EXT_DEG * 2)
+ num_prod_spec
+ (curr_logup_n - AB::F::ONE),
), // curr_logup_n starts at 1.
[max_round],
start_timestamp,
&logup_row_specific.read_records[0],
)
.eval(builder, logup_row);

// logup_row * within_round_limit =
// logup_in_round_evaluation + logup_next_round_evaluation
builder
Expand All @@ -488,8 +469,8 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
.read(
MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr),
logup_row_specific.pq,
start_timestamp + AB::F::ONE,
&logup_row_specific.read_records[1],
start_timestamp,
&logup_row_specific.read_records[0],
)
.eval(builder, logup_row * within_round_limit);

Expand All @@ -513,7 +494,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
+ (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG),
),
logup_row_specific.p_evals,
start_timestamp + AB::F::TWO,
start_timestamp + AB::F::ONE,
&logup_row_specific.write_records[0],
)
.eval(builder, logup_row * within_round_limit);
Expand All @@ -528,7 +509,7 @@ impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir {
* AB::F::from_canonical_usize(EXT_DEG),
),
logup_row_specific.q_evals,
start_timestamp + AB::F::from_canonical_usize(3),
start_timestamp + AB::F::TWO,
&logup_row_specific.write_records[1],
)
.eval(builder, logup_row * within_round_limit);
Expand Down
38 changes: 12 additions & 26 deletions extensions/native/circuit/src/sumcheck/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ where
let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32());
let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32())
.map(|x: F| x.as_canonical_u32());

let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] =
ctx;
// allocate n rows
Expand Down Expand Up @@ -198,19 +197,22 @@ where
r_evals_reg.as_canonical_u32(),
head_specific.read_records[4].as_mut(),
);

let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper(
state.memory,
ctx_ptr.as_canonical_u32(),
head_specific.read_records[5].as_mut(),
);

let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper(
state.memory,
challenges_ptr.as_canonical_u32(),
head_specific.read_records[6].as_mut(),
);
cur_timestamp += 7; // 5 register reads + ctx read + challenges read
let [max_round]: [F; 1] = tracing_read_native_helper(
state.memory,
ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32,
head_specific.read_records[7].as_mut()
);
cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read
head_row.challenges.copy_from_slice(&challenges);

// challenges = [alpha, c1=r, c2=1-r]
Expand All @@ -221,7 +223,7 @@ where
let mut eval_acc = elem_to_ext(F::from_canonical_u32(0));
let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1));

// all rows share same register values, ctx, challenges
// all rows share same register values, ctx, challenges, max_round
for row in rows.iter_mut() {
// c1, c2 are same during the entire execution
row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]);
Expand All @@ -236,6 +238,7 @@ where
row.register_ptrs[2] = prod_evals_ptr;
row.register_ptrs[3] = logup_evals_ptr;
row.register_ptrs[4] = r_evals_ptr;
row.max_round = max_round;
}

// product rows
Expand All @@ -256,15 +259,6 @@ where
};
prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1
prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp);

// read max_round
let [max_round]: [F; 1] = tracing_read_native_helper(
state.memory,
ctx_ptr.as_canonical_u32() + (CONTEXT_ARR_BASE_LEN + i) as u32,
prod_specific.read_records[0].as_mut(),
);
cur_timestamp += 1;

prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc);
prod_row.max_round = max_round;

Expand All @@ -285,7 +279,7 @@ where
let ps: [F; EXT_DEG * 2] = tracing_read_native_helper(
state.memory,
prod_evals_ptr.as_canonical_u32() + start,
prod_specific.read_records[1].as_mut(),
prod_specific.read_records[0].as_mut(),
);
let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap();
let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap();
Expand Down Expand Up @@ -350,15 +344,7 @@ where
};
logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1
logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp);

let [max_round]: [F; 1] = tracing_read_native_helper(
state.memory,
ctx_ptr.as_canonical_u32() + num_prod_spec + (CONTEXT_ARR_BASE_LEN + i) as u32,
logup_specific.read_records[0].as_mut(),
);
logup_row.max_round = max_round;
cur_timestamp += 1;

let alpha_numerator = alpha_acc;
let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha);
logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc);
Expand All @@ -380,7 +366,7 @@ where
let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper(
state.memory,
logup_evals_ptr.as_canonical_u32() + start,
logup_specific.read_records[1].as_mut(),
logup_specific.read_records[0].as_mut(),
);
let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap();
let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap();
Expand Down Expand Up @@ -545,7 +531,7 @@ impl<F: PrimeField32> TraceFiller<F> for NativeSumcheckFiller {
mem_fill_helper(
mem_helper,
start_timestamp + 1,
prod_row_specific.read_records[1].as_mut(),
prod_row_specific.read_records[0].as_mut(),
);
// write p_eval
mem_fill_helper(
Expand All @@ -569,7 +555,7 @@ impl<F: PrimeField32> TraceFiller<F> for NativeSumcheckFiller {
mem_fill_helper(
mem_helper,
start_timestamp + 1,
logup_row_specific.read_records[1].as_mut(),
logup_row_specific.read_records[0].as_mut(),
);
// write p_eval
mem_fill_helper(
Expand Down
12 changes: 6 additions & 6 deletions extensions/native/circuit/src/sumcheck/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ pub struct NativeSumcheckCols<T> {
pub struct HeaderSpecificCols<T> {
pub pc: T,
pub registers: [T; 5],
/// 5 register reads + ctx read + challenges read
pub read_records: [MemoryReadAuxCols<T>; 7],
/// 5 register reads + ctx read + max round read + challenges read
pub read_records: [MemoryReadAuxCols<T>; 8],
/// Write the final evaluation
pub write_records: MemoryWriteAuxCols<T, EXT_DEG>,
}
Expand All @@ -105,8 +105,8 @@ pub struct ProdSpecificCols<T> {
pub data_ptr: T,
/// 2 extension elements
pub p: [T; EXT_DEG * 2],
/// read max varibale and 2 p values
pub read_records: [MemoryReadAuxCols<T>; 2],
/// read 2 p values
pub read_records: [MemoryReadAuxCols<T>; 1],
/// Calculated p evals
pub p_evals: [T; EXT_DEG],
/// write p_evals
Expand All @@ -122,8 +122,8 @@ pub struct LogupSpecificCols<T> {
pub data_ptr: T,
/// 4 extension elements
pub pq: [T; EXT_DEG * 4],
/// read max variable and 4 values: p1, p2, q1, q2
pub read_records: [MemoryReadAuxCols<T>; 2],
/// read 4 values: p1, p2, q1, q2
pub read_records: [MemoryReadAuxCols<T>; 1],
/// Calculated p evals
pub p_evals: [T; EXT_DEG],
/// Calculated q evals
Expand Down
10 changes: 2 additions & 8 deletions extensions/native/circuit/src/sumcheck/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
ctx;
let challenges: [F; EXT_DEG * 4] =
exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32());
let [max_round]: [u32; 1] = exec_state
.vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32);
let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap();
let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap();
let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap();
Expand All @@ -224,10 +226,6 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(

let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32;
for i in 0..num_prod_spec {
let [max_round]: [u32; 1] = exec_state
.vm_read(NATIVE_AS, prod_offset + i)
.map(|x: F| x.as_canonical_u32());

let start = calculate_3d_ext_idx(
prod_specs_inner_inner_len,
prod_specs_inner_len,
Expand Down Expand Up @@ -266,10 +264,6 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(

let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec;
for i in 0..num_logup_spec {
// read max_round
let [max_round]: [u32; 1] = exec_state
.vm_read(NATIVE_AS, logup_offset + i)
.map(|x: F| x.as_canonical_u32());
let start = calculate_3d_ext_idx(
logup_specs_inner_inner_len,
logup_specs_inner_len,
Expand Down
Loading