Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions crates/lean_compiler/snark_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ def parallel_range(a: int, b: int):
return range(a, b)


# dynamic_unroll(start, end, n_bits) returns range(start, end) for Python execution
def dynamic_unroll(start: int, end: int, n_bits: int):
_ = n_bits
return range(start, end)


# Array - simulates write-once memory with pointer arithmetic
class Array:
def __init__(self, size: int):
Expand Down Expand Up @@ -184,6 +178,10 @@ def hint_log2_ceil(n):
return log2_ceil(n)


def hint_div_floor(a, b, q_ptr, r_ptr):
_ = a, b, q_ptr, r_ptr


def hint_witness(name, destination):
"""Write the next witness entry for `name` into `destination`."""
_ = (name, destination)
252 changes: 4 additions & 248 deletions crates/lean_compiler/src/a_simplify_lang/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
use crate::{
CompilationFlags, F,
a_simplify_lang::post_optimization::propagate_copies,
lang::*,
parser::{ConstArrayValue, parse_program},
};
use crate::{F, a_simplify_lang::post_optimization::propagate_copies, lang::*, parser::ConstArrayValue};
use backend::PrimeCharacteristicRing;
use lean_vm::{
ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName,
Expand Down Expand Up @@ -755,43 +750,6 @@ fn compile_time_transform_in_lines(
}
}

Line::ForLoop {
loop_kind: LoopKind::DynamicUnroll { n_bits },
iterator,
start,
end,
body,
location,
} => {
let Some(start_val) = start.compile_time_eval(const_arrays, &vector_len_tracker) else {
return Err(format!(
"line {location}: dynamic_unroll start must be a compile-time constant"
));
};
let start_val = start_val.to_usize();
let Some(n_bits_val) = n_bits.compile_time_eval(const_arrays, &vector_len_tracker) else {
return Err(format!(
"line {location}: dynamic_unroll n_bits must be a compile-time constant"
));
};
let n_bits_val = n_bits_val.to_usize();
if n_bits_val < 1 {
return Err(format!(
"line {location}: dynamic_unroll n_bits must be >= 1, got {n_bits_val}"
));
}
let expanded = expand_dynamic_unroll(
&iterator.clone(),
&end.clone(),
n_bits_val,
start_val,
&body.clone(),
*location,
unroll_counter,
);
lines.splice(i..=i, expanded);
continue;
}
Line::ForLoop {
iterator,
start,
Expand Down Expand Up @@ -828,7 +786,7 @@ fn compile_time_transform_in_lines(
Line::IfCondition { .. }
| Line::Match { .. }
| Line::ForLoop {
loop_kind: LoopKind::Unroll | LoopKind::DynamicUnroll { .. },
loop_kind: LoopKind::Unroll,
..
}
) {
Expand Down Expand Up @@ -1578,14 +1536,11 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) {
start,
end,
body,
loop_kind,
loop_kind: _,
location: _,
} => {
check_expr_scoping(start, ctx);
check_expr_scoping(end, ctx);
if let LoopKind::DynamicUnroll { n_bits } = loop_kind {
check_expr_scoping(n_bits, ctx);
}
let mut new_scope_vars = BTreeSet::new();
new_scope_vars.insert(iterator.clone());
ctx.scopes.push(Scope { vars: new_scope_vars });
Expand Down Expand Up @@ -2778,7 +2733,7 @@ fn simplify_lines(
} => {
assert!(
matches!(loop_kind, LoopKind::Range | LoopKind::ParallelRange),
"Unrolled/dynamic_unroll loops should have been handled already"
"Unrolled loops should have been handled already"
);

let is_parallel = loop_kind.is_parallel();
Expand Down Expand Up @@ -3253,7 +3208,6 @@ pub fn find_variable_usage(
start,
end,
body,
loop_kind,
..
} => {
let (body_internal, body_external) = find_variable_usage(body, const_arrays);
Expand All @@ -3262,9 +3216,6 @@ pub fn find_variable_usage(
external_vars.extend(body_external.difference(&internal_vars).cloned());
on_new_expr(start, &internal_vars, &mut external_vars);
on_new_expr(end, &internal_vars, &mut external_vars);
if let LoopKind::DynamicUnroll { n_bits } = loop_kind {
on_new_expr(n_bits, &internal_vars, &mut external_vars);
}
}
Line::Panic { .. } | Line::LocationReport { .. } => {}
Line::VecDeclaration { var, elements, .. } => {
Expand Down Expand Up @@ -3703,201 +3654,6 @@ fn replace_vars_for_unroll(
transform_vars_in_lines(lines, &transform);
}

/// Chunk size threshold for splitting large unrolls into hybrid loops.
/// Bits k where 2^k > CHUNK_SIZE will use a runtime outer loop with CHUNK_SIZE inner unroll.
const DYNAMIC_UNROLL_CHUNK_SIZE: usize = 1 << 9; // 512

/// Expands `for idx in dynamic_unroll(start, a, n_bits): body` into:
/// 1. Bit decomposition of `a - start` (with constraints)
/// 2. Conditional execution of `body` for each index start..a
///
/// Computes `n_iters = end - start_val`, decomposes into bits, and offsets
/// each iterator value by the compile-time `start_val`.
///
/// The expansion template is written in zkDSL for readability, then parsed
/// and post-processed (variable renaming, body splicing, location fixup).
fn expand_dynamic_unroll(
iterator: &Var,
runtime_end: &Expression,
n_bits: usize,
start_val: usize,
body: &[Line],
location: SourceLocation,
unroll_counter: &mut Counter,
) -> Vec<Line> {
let id = unroll_counter.get_next();
let pfx = format!("@du{id}");
let ps_len = n_bits + 1;

// The template is the zkDSL expansion of dynamic_unroll, with `end` as the
// runtime bound and `__iter` as a placeholder for the iterator assignment.
//
// Bits are stored in big-endian order: bits[0] is the most significant bit
// (weight 2^(n_bits-1)), bits[n_bits-1] is the least significant (weight 2^0).
// ps has n_bits+1 elements: ps[0]=0, ps[k+1] = ps[k] + bits[k]*2^(n_bits-1-k).
// So ps[k] is the offset (number of indices below bit k), and ps[n_bits] == n_iters.
//
// For large bits (block_size > CHUNK_SIZE), we split into chunks to reduce bytecode size:
// - outer runtime loop over n_chunks = block_size / CHUNK_SIZE
// - inner unroll over CHUNK_SIZE iterations
// For small bits, the range loop has minimal overhead.

// Build the template with per-bit chunking logic.
// Pre-compute __base_k = start_val + ps[k] once per activated bit,
// so the inner loop stays at 1 ADD per iteration.
let mut loop_body = String::new();
for k in 0..n_bits {
let block_size = 1usize << (n_bits - 1 - k);
if block_size <= DYNAMIC_UNROLL_CHUNK_SIZE {
// Small block: fully unroll
loop_body.push_str(&format!(
r#"
if bits[{k}] == 1:
__base_{k} = {start_val} + ps[{k}]
for j in unroll(0, {block_size}):
__iter = __base_{k} + j
"#
));
} else {
// Large block: hybrid loop (runtime outer, unroll inner)
// Use an offset variable to avoid MUL per iteration
let n_chunks = block_size / DYNAMIC_UNROLL_CHUNK_SIZE;
loop_body.push_str(&format!(
r#"
if bits[{k}] == 1:
__offset_{k}: Mut = {start_val} + ps[{k}]
for chunk in range(0, {n_chunks}):
for j in unroll(0, {DYNAMIC_UNROLL_CHUNK_SIZE}):
__iter = __offset_{k} + j
__offset_{k} = __offset_{k} + {DYNAMIC_UNROLL_CHUNK_SIZE}
"#
));
}
}
let template = format!(
r#"
def __dynamic_unroll_template(end):
n_iters = end - {start_val}
bits = Array({n_bits})
hint_decompose_bits(n_iters, bits, {n_bits})
ps = Array({ps_len})
ps[0] = 0
for k in unroll(0, {n_bits}):
b = bits[k]
assert b * (1 - b) == 0
ps[k + 1] = ps[k] + b * 2**({n_bits} - 1 - k)
assert n_iters == ps[{n_bits}]
{loop_body}
return
"#
);

let program = parse_program(&crate::ProgramSource::Raw(template), CompilationFlags::default()).unwrap();
assert_eq!(program.functions.len(), 1);
let func = program.functions.values().next().unwrap();
let mut lines = func.body.clone();

// Strip trailing return + its LocationReport
while matches!(
lines.last(),
Some(Line::FunctionRet { .. } | Line::LocationReport { .. })
) {
lines.pop();
}
// Strip LocationReport lines (they carry template line numbers, not the real ones)
strip_location_reports(&mut lines);

// Rename all internal variables with @du{id}_ prefix.
// __iter is renamed directly to the user's iterator variable.
let internals: BTreeSet<String> = ["bits", "ps", "k", "j", "b", "chunk", "n_iters"]
.iter()
.map(|s| s.to_string())
.collect();
transform_vars_in_lines(&mut lines, &|var: &Var| {
if var == "__iter" {
VarTransform::Rename(iterator.clone())
} else if var == "end" || internals.contains(var) || var.starts_with("__offset_") || var.starts_with("__base_")
{
VarTransform::Rename(format!("{pfx}_{var}"))
} else {
VarTransform::Keep
}
});

// Prepend: @du{id}_end = runtime_end
lines.insert(
0,
Line::Statement {
targets: vec![AssignmentTarget::Var {
var: format!("{pfx}_end"),
is_mutable: false,
}],
value: runtime_end.clone(),
location,
},
);

// Insert body after every `{iterator} = ...` assignment (the renamed __iter lines)
insert_body_after_var(&mut lines, iterator, body);

// Fix all source locations to point to the actual dynamic_unroll call site
set_locations_recursive(&mut lines, location);

lines
}

fn strip_location_reports(lines: &mut Vec<Line>) {
lines.retain(|l| !matches!(l, Line::LocationReport { .. }));
for line in lines.iter_mut() {
for block in line.nested_blocks_mut() {
strip_location_reports(block);
}
}
}

/// In every nested block, insert `body` lines after each statement that assigns to `var`.
fn insert_body_after_var(lines: &mut [Line], var: &str, body: &[Line]) {
for line in lines.iter_mut() {
for block in line.nested_blocks_mut() {
let mut i = 0;
while i < block.len() {
if matches!(&block[i], Line::Statement { targets, .. }
if targets.iter().any(|t| matches!(t, AssignmentTarget::Var { var: v, .. } if v == var)))
{
let insert_pos = i + 1;
for (j, body_line) in body.iter().enumerate() {
block.insert(insert_pos + j, body_line.clone());
}
i += 1 + body.len();
} else {
i += 1;
}
}
insert_body_after_var(block, var, body);
}
}
}

fn set_locations_recursive(lines: &mut [Line], location: SourceLocation) {
for line in lines {
match line {
Line::Statement { location: loc, .. }
| Line::Assert { location: loc, .. }
| Line::IfCondition { location: loc, .. }
| Line::ForLoop { location: loc, .. }
| Line::Match { location: loc, .. }
| Line::LocationReport { location: loc }
| Line::VecDeclaration { location: loc, .. }
| Line::Push { location: loc, .. }
| Line::Pop { location: loc, .. } => *loc = location,
Line::ForwardDeclaration { .. } | Line::FunctionRet { .. } | Line::Panic { .. } => {}
}
for block in line.nested_blocks_mut() {
set_locations_recursive(block, location);
}
}
}

fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>) {
match expr {
Expression::Value(value) => match &value {
Expand Down
3 changes: 1 addition & 2 deletions crates/lean_compiler/src/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,10 @@ elif_clause = { "elif" ~ condition ~ ":" ~ newline ~ statement* ~ end_block }

else_clause = { "else" ~ ":" ~ newline ~ statement* ~ end_block }

for_statement = { "for" ~ identifier ~ "in" ~ (dynamic_unroll_range | unroll_range | parallel_range | range) ~ ":" ~ newline ~ statement* ~ end_block }
for_statement = { "for" ~ identifier ~ "in" ~ (unroll_range | parallel_range | range) ~ ":" ~ newline ~ statement* ~ end_block }
range = { "range" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
parallel_range = { "parallel_range" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
unroll_range = { "unroll" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
dynamic_unroll_range = { "dynamic_unroll" ~ "(" ~ expression ~ "," ~ expression ~ "," ~ expression ~ ")" }

match_statement = { "match" ~ expression ~ ":" ~ newline ~ match_arm* ~ end_block }
match_arm = { "case" ~ pattern ~ ":" ~ newline ~ statement* ~ end_block }
Expand Down
Loading
Loading