diff --git a/diffsl/benches/evaluation.rs b/diffsl/benches/evaluation.rs index d42bf78..380282a 100644 --- a/diffsl/benches/evaluation.rs +++ b/diffsl/benches/evaluation.rs @@ -55,7 +55,7 @@ fn execute( let n = N; let compiler = setup::(n, f_text, "execute"); let mut data = compiler.get_new_data(); - compiler.set_inputs(&[], data.as_mut_slice()); + compiler.set_inputs(&[], data.as_mut_slice(), 0); let mut u = vec![1.0; n]; compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); let mut rr = vec![0.0; n]; diff --git a/diffsl/src/ast/mod.rs b/diffsl/src/ast/mod.rs index 5ea2433..e369b4f 100644 --- a/diffsl/src/ast/mod.rs +++ b/diffsl/src/ast/mod.rs @@ -580,10 +580,18 @@ impl<'a> Ast<'a> { } AstKind::CallArg(arg) => Self::new_call_arg(arg.name, arg.expression.tangent()), AstKind::Name(name) => { - if name.name == "t" { + if name.name == "t" || name.name == "N" { Self::new_number(0.0) } else { - Self::new_name(name.name, name.indices.clone(), true) + Ast { + kind: AstKind::Name(Name { + name: name.name, + indices: name.indices.clone(), + indice: name.indice.clone(), + is_tangent: true, + }), + span: self.span, + } } } AstKind::Number(_) => Self::new_number(0.0), @@ -883,9 +891,10 @@ impl<'a> Ast<'a> { AstKind::Name(found_name) => { // if the name is indexed by a single indice, don't add that index to the list if let Some(indice) = &found_name.indice { - let indice = indice.as_ref().kind.as_indice().unwrap(); - if indice.sep.is_none() || indice.last.is_none() { - return; + if let Some(indice) = indice.as_ref().kind.as_indice() { + if indice.sep.is_none() || indice.last.is_none() { + return; + } } } indices.extend(found_name.indices.iter().cloned()); diff --git a/diffsl/src/discretise/discrete_model.rs b/diffsl/src/discretise/discrete_model.rs index 5249d70..a7a95c8 100644 --- a/diffsl/src/discretise/discrete_model.rs +++ b/diffsl/src/discretise/discrete_model.rs @@ -132,6 +132,7 @@ impl<'s> DiscreteModel<'s> { let reserved_names = [ "u0", "t", + "N", "data", "root", 
"thread_id", @@ -502,7 +503,11 @@ impl<'s> DiscreteModel<'s> { let dependent_on_time = env_entry.is_time_dependent(); let dependent_on_dudt = env_entry.is_dstatedt_dependent(); let dependent_on_input = env_entry.is_input_dependent(); - if !dependent_on_time && !dependent_on_input { + let dependent_on_model = env_entry.is_model_dependent(); + if !dependent_on_time + && !dependent_on_input + && !dependent_on_model + { ret.constant_defns.push(built); } else if !dependent_on_time { ret.input_dep_defns.push(built); @@ -1401,6 +1406,7 @@ mod tests { tensor_fail_tests!( error_scalar: "r {1, 2}" errors ["cannot have more than one element in a scalar",], + error_reserved_name_n: "N { 1 }" errors ["N is a reserved name",], error_cannot_find: "r { k }" errors ["cannot find variable k",], error_different_shape: "a_i { 1, 2 } b_i { 1, 2, 3 } c_i { a_i + b_i }" errors ["cannot broadcast shapes: [2], [3]",], too_many_indices: "A_i { 1, 2 } B_i { (0:2): A_ij }" errors ["too many permutation indices",], @@ -1447,6 +1453,9 @@ mod tests { index: "A_i { 0.0, 1.0, 2.0 } B { A_i[1] }" expect "B" = "B { (): A_i[1] }", index2: "A_i { 0.0, 1.0, 2.0 } B_i { A_i[1:3] }" expect "B" = "B_i(2) { (0)(2):A_i[1:3](2) }", index3: "A_ij { (0:2, 0:2): 1 } g_i { 0, 1, 2 } b_i { A_ij * g_j[0:2] }" expect "b" = "b_i (2) { (0)(2): A_ij * g_j[0:2] (2, 2) }", + index_n_scalar: "A_i { 0.0, 1.0, 2.0 } B { A_i[N % 2] }" expect "B" = "B { (): A_i[N % 2] }", + index_n_range: "A_i { 0.0, 1.0, 2.0, 3.0 } B_i { A_i[N:N+2] }" expect "B" = "B_i(2) { (0)(2):A_i[N:N + 2](2) }", + index_n_range_contraction: "A_ij { (0:2, 0:2): 1 } g_i { 0, 1, 2 } b_i { A_ij * g_j[N:N+2] }" expect "b" = "b_i (2) { (0)(2): A_ij * g_j[N:N + 2] (2, 2) }", prefix_minus: "A { 1.0 / -2.0 }" expect "A" = "A { (): 1 / (-2) }", time: "A_i { t }" expect "A" = "A_i (1) { (0)(1): t }", named_blk: "A_i { (0:3): y = 1, 2 }" expect "A" = "A_i (4) { (0)(3): y = 1, (3)(1): 2 }", diff --git a/diffsl/src/discretise/env.rs b/diffsl/src/discretise/env.rs 
index 46b5ebe..f6c08bb 100644 --- a/diffsl/src/discretise/env.rs +++ b/diffsl/src/discretise/env.rs @@ -19,6 +19,7 @@ pub struct EnvVar { is_state_dependent: bool, is_dstatedt_dependent: bool, is_input_dependent: bool, + is_model_dependent: bool, is_algebraic: bool, } @@ -43,6 +44,10 @@ impl EnvVar { self.is_input_dependent } + pub fn is_model_dependent(&self) -> bool { + self.is_model_dependent + } + pub fn layout(&self) -> &Layout { self.layout.as_ref() } @@ -57,6 +62,127 @@ pub struct Env { } impl Env { + fn eval_const_integer_expr(expr: &Ast) -> Option<i64> { + match &expr.kind { + AstKind::Integer(v) => Some(*v), + AstKind::Number(v) => { + if v.fract() == 0.0 { + Some(*v as i64) + } else { + None + } + } + AstKind::Monop(op) => { + let child = Self::eval_const_integer_expr(op.child.as_ref())?; + match op.op { + '+' => Some(child), + '-' => child.checked_neg(), + _ => None, + } + } + AstKind::Binop(op) => { + let left = Self::eval_const_integer_expr(op.left.as_ref())?; + let right = Self::eval_const_integer_expr(op.right.as_ref())?; + match op.op { + '+' => left.checked_add(right), + '-' => left.checked_sub(right), + '*' => left.checked_mul(right), + '/' => { + if right == 0 { + None + } else { + Some(left / right) + } + } + '%' => { + if right == 0 { + None + } else { + Some(left % right) + } + } + _ => None, + } + } + _ => None, + } + } + + fn eval_integer_expr_with_n(expr: &Ast, n: i64) -> Option<i64> { + match &expr.kind { + AstKind::Integer(v) => Some(*v), + AstKind::Number(v) => { + if v.fract() == 0.0 { + Some(*v as i64) + } else { + None + } + } + AstKind::Name(name) => { + if name.name == "N" { + Some(n) + } else { + None + } + } + AstKind::Monop(op) => { + let child = Self::eval_integer_expr_with_n(op.child.as_ref(), n)?; + match op.op { + '+' => Some(child), + '-' => child.checked_neg(), + _ => None, + } + } + AstKind::Binop(op) => { + let left = Self::eval_integer_expr_with_n(op.left.as_ref(), n)?; + let right =
Self::eval_integer_expr_with_n(op.right.as_ref(), n)?; + match op.op { + '+' => left.checked_add(right), + '-' => left.checked_sub(right), + '*' => left.checked_mul(right), + '/' => { + if right == 0 { + None + } else { + Some(left / right) + } + } + '%' => { + if right == 0 { + None + } else { + Some(left % right) + } + } + _ => None, + } + } + _ => None, + } + } + + fn eval_constant_range_width(start: &Ast, end: &Ast) -> Option<i64> { + if let (Some(first), Some(last)) = ( + Self::eval_const_integer_expr(start), + Self::eval_const_integer_expr(end), + ) { + return last.checked_sub(first); + } + + let mut width = None; + for n in [0_i64, 1, 2, 3, 7, 16] { + let first_n = Self::eval_integer_expr_with_n(start, n)?; + let last_n = Self::eval_integer_expr_with_n(end, n)?; + let width_n = last_n.checked_sub(first_n)?; + match width { + Some(prev) if prev != width_n => return None, + Some(_) => {} + None => width = Some(width_n), + } + } + width + } + pub fn new() -> Self { let mut vars = HashMap::new(); vars.insert( @@ -67,6 +193,19 @@ impl Env { is_time_dependent: false, is_state_dependent: false, is_dstatedt_dependent: false, is_input_dependent: false, + is_model_dependent: false, is_algebraic: true, }, ); + vars.insert( + "N".to_string(), + EnvVar { + layout: ArcLayout::new(Layout::new_scalar()), + is_time_dependent: false, + is_state_dependent: false, + is_dstatedt_dependent: false, + is_input_dependent: false, + is_model_dependent: true, is_algebraic: true, }, ); vars.insert( @@ -109,6 +248,19 @@ impl Env { self.is_tensor_dependent_on(tensor, "in") } + pub fn is_tensor_model_dependent(&self, tensor: &Tensor) -> bool { + if tensor.name() == "N" { + return true; + } + tensor.elmts().iter().any(|block| { + block + .expr() + .get_dependents() + .iter() + .any(|&dep| dep == "N" || self.vars[dep].is_model_dependent()) + }) + } + pub fn is_tensor_dstatedt_dependent(&self, tensor: &Tensor) -> bool { self.is_tensor_dependent_on(tensor, "dudt") } @@ -141,6 +293,7 @@ impl Env { is_state_dependent:
self.is_tensor_state_dependent(var), is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var), is_input_dependent: self.is_tensor_input_dependent(var), + is_model_dependent: self.is_tensor_model_dependent(var), }, ); } @@ -155,6 +308,7 @@ impl Env { is_state_dependent: self.is_tensor_state_dependent(var), is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var), is_input_dependent: self.is_tensor_input_dependent(var), + is_model_dependent: self.is_tensor_model_dependent(var), }, ); } @@ -280,32 +434,79 @@ impl Env { // if the indice is a single integer then the resulting layout is a scalar if indice.sep.is_none() { let mut new_layout = Layout::new_scalar(); - let first = indice.first.kind.as_integer().unwrap(); - let last = first + 1; + let (first, last) = + if let Some(first) = Self::eval_const_integer_expr(indice.first.as_ref()) { + (first, first + 1) + } else { + // Dynamic integer indices (e.g. involving N) cannot be resolved at compile time. + // Conservatively keep dependencies for the full 1D extent. 
+ let axis = layout_permuted + .shape() + .iter() + .position(|&d| d != 1) + .unwrap_or(0); + let dim = *layout_permuted.shape().get(axis).unwrap_or(&1); + (0, i64::try_from(dim).unwrap()) + }; new_layout.filter_deps_from(layout_permuted, first, last); return Some(new_layout); } else { // if the indice is a range then the resulting layout is a dense layout with shape given by the range // along the only non-unit dimension of the variable - let first = indice.first.kind.as_integer().unwrap(); - let last = indice.last.as_ref().unwrap().kind.as_integer().unwrap(); + let end_expr = indice.last.as_ref().unwrap().as_ref(); + let Some(width) = Self::eval_constant_range_width(indice.first.as_ref(), end_expr) + else { + self.errs.push(ValidationError::new( + "range indice width must be an integer constant (independent of N)" + .to_string(), + ast.span, + )); + return None; + }; // make sure the range is valid - if last < first { + if width < 0 { self.errs.push(ValidationError::new( - format!( - "invalid range indice: start {} is greater than end {}", - first, last - ), + format!("invalid range indice: width {} is negative", width), ast.span, )); return None; } - let dim = usize::try_from(last - first).unwrap(); + let dim = usize::try_from(width).unwrap(); let shape = layout_permuted .shape() .map(|&d| if d != 1 { dim } else { 1 }); let mut new_layout = Layout::new_dense(Shape::from(shape)); - new_layout.filter_deps_from(layout_permuted, first, last); + let first_const = Self::eval_const_integer_expr(indice.first.as_ref()); + if let Some(first) = first_const { + let last = first + width; + new_layout.filter_deps_from(layout_permuted, first, last); + } else if dim != 0 { + // For dynamic starts (e.g. N:N+2), union dependencies over all valid windows. 
+ let axis = layout_permuted + .shape() + .iter() + .position(|&d| d != 1) + .unwrap_or(0); + let source_dim = + i64::try_from(*layout_permuted.shape().get(axis).unwrap_or(&1)).unwrap(); + let max_start = source_dim.saturating_sub(width); + let mut merged_layout: Option = None; + for start in 0..=max_start { + let mut window_layout = Layout::new_dense(new_layout.shape().clone()); + window_layout.filter_deps_from( + layout_permuted.clone(), + start, + start + width, + ); + merged_layout = Some(match merged_layout { + Some(accumulated) => accumulated.union(window_layout), + None => window_layout, + }); + } + if let Some(merged_layout) = merged_layout { + new_layout = merged_layout; + } + } return Some(new_layout); } } diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index 32c5461..bf76dd0 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -796,10 +796,10 @@ impl Compiler { ) } - pub fn set_inputs(&self, inputs: &[T], data: &mut [T]) { + pub fn set_inputs(&self, inputs: &[T], data: &mut [T], model_index: u32) { self.check_inputs_len(inputs, "inputs"); self.check_data_len(data, "data"); - unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr()) }; + unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr(), model_index) }; } pub fn get_inputs(&self, inputs: &mut [T], data: &[T]) { @@ -808,7 +808,14 @@ impl Compiler { unsafe { (self.jit_functions.get_inputs)(inputs.as_mut_ptr(), data.as_ptr()) }; } - pub fn set_inputs_grad(&self, inputs: &[T], dinputs: &[T], data: &[T], ddata: &mut [T]) { + pub fn set_inputs_grad( + &self, + inputs: &[T], + dinputs: &[T], + data: &[T], + ddata: &mut [T], + model_index: u32, + ) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); self.check_data_len(data, "data"); @@ -819,11 +826,19 @@ impl Compiler { dinputs.as_ptr(), data.as_ptr(), ddata.as_mut_ptr(), + model_index, ) }; } - pub fn 
set_inputs_rgrad(&self, inputs: &[T], dinputs: &mut [T], data: &[T], ddata: &mut [T]) { + pub fn set_inputs_rgrad( + &self, + inputs: &[T], + dinputs: &mut [T], + data: &[T], + ddata: &mut [T], + model_index: u32, + ) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); self.check_data_len(data, "data"); @@ -838,6 +853,7 @@ impl Compiler { dinputs.as_mut_ptr(), data.as_ptr(), ddata.as_mut_ptr(), + model_index, ) }; } @@ -947,7 +963,7 @@ mod tests { assert_relative_eq!(a2[0], T::zero()); // set the inputs and u0 let inputs = vec![T::one()]; - compiler.set_inputs(&inputs, data.as_mut_slice()); + compiler.set_inputs(&inputs, data.as_mut_slice(), 0); let mut u0 = vec![T::zero()]; compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); // now a and a2 should be set @@ -1054,6 +1070,215 @@ mod tests { generate_tests!(test_out_depends_on_internal_tensor); + generate_tests!(test_model_index_n_depends_on_model_index); + generate_tests!(test_model_index_n_dynamic_index_grad); + generate_tests!(test_model_index_n_dynamic_range_width_const); + + #[allow(dead_code)] + fn test_model_index_n_depends_on_model_index< + M: CodegenModuleCompile + CodegenModuleJit, + T: Scalar + RelativeEq, + >() { + // RED test for issue #112: + // - `N` is a reserved model index + // - `%` is supported in expressions + // - tensor indexing can use expression indices, e.g. amp_i[N % 2] + // + // Expected behavior once implemented: + // - N is taken from model_index. 
+ // - model_index = 0 => N % 2 = 0 + // - model_index = 1 => N % 2 = 1 + let full_text = " + amp_i { 0, 10 } + dur_i { 10, 5 } + u_i { x = 1, tclock = 0 } + F_i { amp_i[N % 2] - x, 1 } + stop_i { dur_i[N % 2] - tclock } + out_i { x, tclock } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero(); 2]; + let mut rr0 = vec![T::zero(); 2]; + let mut rr1 = vec![T::zero(); 2]; + let mut stop0 = vec![T::zero(); 1]; + let mut stop1 = vec![T::zero(); 1]; + let mut data = compiler.get_new_data(); + + compiler.set_inputs(&[], data.as_mut_slice(), 0); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr0.as_mut_slice(), + ); + compiler.calc_stop( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + stop0.as_mut_slice(), + ); + + compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr1.as_mut_slice(), + ); + compiler.calc_stop( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + stop1.as_mut_slice(), + ); + + assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap()); + assert_relative_eq!(u0[1], T::from_f64(0.0).unwrap()); + + assert_relative_eq!(rr0[0], T::from_f64(-1.0).unwrap()); + assert_relative_eq!(rr0[1], T::from_f64(1.0).unwrap()); + assert_relative_eq!(stop0[0], T::from_f64(10.0).unwrap()); + + assert_relative_eq!(rr1[0], T::from_f64(9.0).unwrap()); + assert_relative_eq!(rr1[1], T::from_f64(1.0).unwrap()); + assert_relative_eq!(stop1[0], T::from_f64(5.0).unwrap()); + + assert_ne!(rr0[0], rr1[0]); + assert_ne!(stop0[0], stop1[0]); + } + + #[allow(dead_code)] + fn test_model_index_n_dynamic_index_grad< + M: CodegenModuleCompile + CodegenModuleJit, + T: Scalar + RelativeEq, + >() { + let full_text 
= " + u_i { y = 3, z = 5 } + F_i { u_i[N % 2], 0 } + out_i { y, z } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero(); 2]; + let mut rr0 = vec![T::zero(); 2]; + let mut rr1 = vec![T::zero(); 2]; + let mut drr0 = vec![T::zero(); 2]; + let mut drr1 = vec![T::zero(); 2]; + let mut data = compiler.get_new_data(); + let mut ddata0 = compiler.get_new_data(); + let mut ddata1 = compiler.get_new_data(); + + compiler.set_inputs(&[], data.as_mut_slice(), 0); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr0.as_mut_slice(), + ); + compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr1.as_mut_slice(), + ); + + assert_relative_eq!(rr0[0], T::from_f64(3.0).unwrap()); + assert_relative_eq!(rr1[0], T::from_f64(5.0).unwrap()); + + let dyy0 = vec![T::one(), T::zero()]; + let dyy1 = vec![T::zero(), T::one()]; + compiler.set_inputs(&[], data.as_mut_slice(), 0); + compiler.rhs_grad( + T::zero(), + u0.as_slice(), + dyy0.as_slice(), + data.as_slice(), + ddata0.as_mut_slice(), + rr0.as_slice(), + drr0.as_mut_slice(), + ); + compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.rhs_grad( + T::zero(), + u0.as_slice(), + dyy1.as_slice(), + data.as_slice(), + ddata1.as_mut_slice(), + rr1.as_slice(), + drr1.as_mut_slice(), + ); + + assert_relative_eq!(drr0[0], T::one()); + assert_relative_eq!(drr0[1], T::zero()); + assert_relative_eq!(drr1[0], T::one()); + assert_relative_eq!(drr1[1], T::zero()); + } + + #[allow(dead_code)] + fn test_model_index_n_dynamic_range_width_const< + M: CodegenModuleCompile + CodegenModuleJit, + T: Scalar + RelativeEq, + >() { + let full_text = " + amp_i { 0, 10, 20 } + u_i { x 
= 1, y = 1 } + F_i { amp_i[N:N+2] - u_i } + out_i { x, y } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero(); 2]; + let mut rr0 = vec![T::zero(); 2]; + let mut rr1 = vec![T::zero(); 2]; + let mut data = compiler.get_new_data(); + + compiler.set_inputs(&[], data.as_mut_slice(), 0); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr0.as_mut_slice(), + ); + compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr1.as_mut_slice(), + ); + + assert_relative_eq!(rr0[0], T::from_f64(-1.0).unwrap()); + assert_relative_eq!(rr0[1], T::from_f64(9.0).unwrap()); + assert_relative_eq!(rr1[0], T::from_f64(9.0).unwrap()); + assert_relative_eq!(rr1[1], T::from_f64(19.0).unwrap()); + } + #[allow(dead_code)] fn test_out_depends_on_internal_tensor< M: CodegenModuleCompile + CodegenModuleJit, @@ -1183,7 +1408,7 @@ mod tests { let mut results = Vec::new(); let inputs = vec![T::one(); n_inputs]; let mut out = vec![T::zero(); n_outputs]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs( T::zero(), @@ -1225,6 +1450,7 @@ mod tests { dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), + 0, ); compiler.set_u0_grad( u0.as_mut_slice(), @@ -1300,7 +1526,7 @@ mod tests { let mut ddata = compiler.get_new_data(); let mut dres = vec![T::zero(); n_states]; let dinputs = vec![T::one(); n_inputs]; - compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0); compiler.rhs_sgrad( T::zero(), 
u0.as_slice(), @@ -1319,7 +1545,7 @@ mod tests { // forward mode sens (calc_out) let mut ddata = compiler.get_new_data(); let dinputs = vec![T::one(); n_inputs]; - compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0); compiler.calc_out_sgrad( T::zero(), u0.as_slice(), @@ -1356,6 +1582,7 @@ mod tests { dinputs.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); results.push(dinputs.to_vec()); @@ -1379,6 +1606,7 @@ mod tests { dinputs.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); results.push(dinputs.to_vec()); } else { @@ -1604,6 +1832,8 @@ mod tests { abs_function: "r { abs(-2) }" expect "r" vec![f64::abs(-2.0)], min_function: "r { min(2, 3) }" expect "r" vec![2.0], max_function: "r { max(2, 3) }" expect "r" vec![3.0], + n_expression_scalar: "r { 2.0 * N + 1.0 }" expect "r" vec![1.0], + n_expression_vector: "r_i { 2.0 * N + 1.0, 3.0 * N + 2.0 }" expect "r" vec![1.0, 2.0], scalar: "r {2}" expect "r" vec![2.0,], constant: "r_i {2, 3}" expect "r" vec![2., 3.], derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.], @@ -1898,7 +2128,7 @@ mod tests { let mut ddata = compiler.get_new_data(); let (_n_states, n_inputs, _n_outputs, _n_data, _n_stop, _has_mass) = compiler.get_dims(); let inputs = vec![T::from_f64(2.).unwrap(); n_inputs]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs( T::zero(), @@ -1914,6 +2144,7 @@ mod tests { dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), + 0, ); compiler.set_u0_grad( u0.as_mut_slice(), @@ -2042,7 +2273,7 @@ mod tests { let mut data = compiler.get_new_data(); let inputs = vec![1.1]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); let inputs = compiler.get_tensor_data("in", 
data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.1].as_slice()); @@ -2093,7 +2324,7 @@ mod tests { .unwrap(); let mut data = compiler.get_new_data(); let inputs = vec![1.0, 2.0, 3.0, 4.0]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.0, 2.0, 3.0, 4.0].as_slice()); @@ -2109,7 +2340,7 @@ mod tests { .unwrap(); let mut data = compiler.get_new_data(); let inputs = vec![1.0, 2.0, 3.0, 4.0]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.0, 2.0, 3.0, 4.0].as_slice()); @@ -2217,12 +2448,13 @@ mod tests { let mut ddata = compiler.get_new_data(); let a = vec![T::from_f64(0.6).unwrap()]; let da = vec![T::one()]; - compiler.set_inputs(a.as_slice(), data.as_mut_slice()); + compiler.set_inputs(a.as_slice(), data.as_mut_slice(), 0); compiler.set_inputs_grad( a.as_slice(), da.as_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); let mut u0 = vec![T::zero()]; let mut du0 = vec![T::zero()]; diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index 61bb1dc..b0d3f02 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -35,6 +35,7 @@ pub struct CraneliftModule { indices_id: DataId, constants_id: DataId, + model_index_id: DataId, thread_counter: Option, //triple: Triple, @@ -300,11 +301,24 @@ impl CraneliftModule { self.real_ptr_type, self.real_ptr_type, self.real_ptr_type, + self.int_type, ]; - let arg_names = &["inputs", "dinputs", "data", "ddata"]; + let arg_names = &["inputs", "dinputs", "data", "ddata", "model_index"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, 
arg_types); + let model_index_ptr = codegen + .builder + .ins() + .global_value(codegen.int_ptr_type, codegen.model_index_global); + let model_index = codegen + .builder + .use_var(*codegen.variables.get("model_index").unwrap()); + codegen + .builder + .ins() + .store(codegen.mem_flags, model_index, model_index_ptr, 0); + let base_data_ptr = codegen.variables.get("ddata").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); codegen.jit_compile_inputs(model, base_data_ptr, true, false); @@ -426,6 +440,11 @@ impl CraneliftModule { let indices_id = module.declare_data("indices", Linkage::Local, false, false)?; module.define_data(indices_id, &data_description)?; + let mut data_description = DataDescription::new(); + data_description.define_zeroinit(int_type.bytes().try_into().unwrap()); + let model_index_id = module.declare_data("model_index", Linkage::Local, true, false)?; + module.define_data(model_index_id, &data_description)?; + let mut thread_counter = None; if threaded { let mut data_description = DataDescription::new(); @@ -442,6 +461,7 @@ impl CraneliftModule { module: Mutex::new(module), indices_id, constants_id, + model_index_id, int_type, real_type: real_type_cranelift, real_ptr_type: ptr_type, @@ -634,7 +654,6 @@ impl CraneliftModule { // F let res = *codegen.variables.get("rr").unwrap(); codegen.jit_compile_tensor(model.rhs(), Some(res), false)?; - codegen.builder.ins().return_(&[]); codegen.builder.finalize(); } @@ -784,11 +803,23 @@ impl CraneliftModule { } fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result { - let arg_types = &[self.real_ptr_type, self.real_ptr_type]; - let arg_names = &["inputs", "data"]; + let arg_types = &[self.real_ptr_type, self.real_ptr_type, self.int_type]; + let arg_names = &["inputs", "data", "model_index"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + let model_index_ptr = codegen + .builder + .ins() + .global_value(codegen.int_ptr_type, 
codegen.model_index_global); + let model_index = codegen + .builder + .use_var(*codegen.variables.get("model_index").unwrap()); + codegen + .builder + .ins() + .store(codegen.mem_flags, model_index, model_index_ptr, 0); + let base_data_ptr = codegen.variables.get("data").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); codegen.jit_compile_inputs(model, base_data_ptr, false, false); @@ -1025,6 +1056,7 @@ struct CraneliftCodeGen<'a, M: Module> { layout: &'a DataLayout, indices: GlobalValue, constants: GlobalValue, + model_index_global: GlobalValue, threaded: bool, } @@ -1119,6 +1151,20 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { } AstKind::Number(value) => Ok(self.fconst(*value)), AstKind::Name(iname) => { + if iname.name == "N" { + if iname.is_tangent { + return Ok(self.fconst(0.0)); + } + let var = self + .variables + .get("model_index") + .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?; + let model_index = self.builder.use_var(*var); + return Ok(self + .builder + .ins() + .fcvt_from_sint(self.real_type, model_index)); + } let ptr = if iname.is_tangent { // tangent of a constant is zero if self.layout.is_constant(iname.name) { @@ -1149,11 +1195,12 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { .position(|x| x == c) .unwrap_or(elmt.indices().len()); // if we are indexing, add the start indice to index[pi] - if let Some(indice) = - iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap()) - { - let start = indice.first.as_ref().kind.as_integer().unwrap(); - let start_intval = self.builder.ins().iconst(self.int_type, start); + if let Some(indice_ast) = iname.indice.as_ref() { + let Some(indice) = indice_ast.kind.as_indice() else { + return Err(anyhow!("invalid index expression '{}'", indice_ast)); + }; + let start_intval = + self.jit_compile_integer_expr(indice.first.as_ref())?; // if we are indexing a single element, the index may be out of bounds let index_pi = if pi >= index.len() { 
self.builder.ins().iconst(self.int_type, 0) } else { index[pi] }; let index_pi = self.builder.ins().iadd(start_intval, index_pi); iname_index.push(index_pi); } else { - iname_index.push(index[pi]); + let index_pi = if pi >= index.len() { + self.builder.ins().iconst(self.int_type, 0) + } else { + index[pi] + }; + iname_index.push(index_pi); } no_transform = no_transform && pi == i; } @@ -1317,6 +1369,56 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { } } + fn jit_compile_integer_expr(&mut self, expr: &Ast) -> Result<Value> { + match &expr.kind { + AstKind::Integer(value) => Ok(self.builder.ins().iconst(self.int_type, *value)), + AstKind::Number(value) => { + if value.fract() != 0.0 { + return Err(anyhow!( + "non-integer value '{}' in integer expression", + value + )); + } + Ok(self.builder.ins().iconst(self.int_type, *value as i64)) + } + AstKind::Name(iname) => { + if iname.name == "N" { + let var = self + .variables + .get("model_index") + .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?; + Ok(self.builder.use_var(*var)) + } else { + Err(anyhow!( + "unsupported name '{}' in integer expression", + iname.name + )) + } + } + AstKind::Monop(monop) => { + let child = self.jit_compile_integer_expr(monop.child.as_ref())?; + match monop.op { + '+' => Ok(child), + '-' => Ok(self.builder.ins().ineg(child)), + _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)), + } + } + AstKind::Binop(binop) => { + let lhs = self.jit_compile_integer_expr(binop.left.as_ref())?; + let rhs = self.jit_compile_integer_expr(binop.right.as_ref())?; + match binop.op { + '+' => Ok(self.builder.ins().iadd(lhs, rhs)), + '-' => Ok(self.builder.ins().isub(lhs, rhs)), + '*' => Ok(self.builder.ins().imul(lhs, rhs)), + '/' => Ok(self.builder.ins().sdiv(lhs, rhs)), + '%' => Ok(self.builder.ins().srem(lhs, rhs)), + _ => Err(anyhow!("unknown integer binary op '{}'", binop.op)), + } + } + _ => Err(anyhow!("unsupported integer
expression '{}'", expr)), + } + } + fn get_function_name(name: &str, is_tangent: bool) -> String { if is_tangent { format!("{name}__tangent__") @@ -2254,6 +2356,12 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { .unwrap() .declare_data_in_func(module.constants_id, builder.func); + let model_index_global = module + .module + .lock() + .unwrap() + .declare_data_in_func(module.model_index_id, builder.func); + // Create the entry block, to start emitting code in. let entry_block = builder.create_block(); @@ -2286,6 +2394,7 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { functions: HashMap::new(), layout: &module.layout, threaded: module.threaded, + model_index_global, }; // insert arg vars @@ -2294,6 +2403,19 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { codegen.declare_variable(*arg_type, arg_name, val); } + if !codegen.variables.contains_key("model_index") { + let model_index_ptr = codegen + .builder + .ins() + .global_value(codegen.int_ptr_type, codegen.model_index_global); + let model_index = + codegen + .builder + .ins() + .load(codegen.int_type, codegen.mem_flags, model_index_ptr, 0); + codegen.declare_variable(codegen.int_type, "model_index", model_index); + } + // insert u if it exists in args if let Some(u) = codegen.variables.get("u") { let u_ptr = codegen.builder.use_var(*u); diff --git a/diffsl/src/execution/data_layout.rs b/diffsl/src/execution/data_layout.rs index 1df40f8..be470c4 100644 --- a/diffsl/src/execution/data_layout.rs +++ b/diffsl/src/execution/data_layout.rs @@ -44,6 +44,12 @@ impl DataLayout { // add layout info for "t" let t_layout = ArcLayout::new(Layout::new_scalar()); layout_map.insert("t".to_string(), t_layout); + is_constant_map.insert("t".to_string(), false); + + // add layout info for model index "N" + let n_layout = ArcLayout::new(Layout::new_scalar()); + layout_map.insert("N".to_string(), n_layout); + is_constant_map.insert("N".to_string(), false); let mut add_tensor = |tensor: &Tensor, in_data: bool, in_constants: 
bool| { // insert the data (non-zeros) for each tensor diff --git a/diffsl/src/execution/external/mod.rs b/diffsl/src/execution/external/mod.rs index e554bba..9a2ee0c 100644 --- a/diffsl/src/execution/external/mod.rs +++ b/diffsl/src/execution/external/mod.rs @@ -203,7 +203,7 @@ macro_rules! define_symbol_module { has_mass: *mut UIntType, ); #[link_name = "set_inputs"] - pub fn set_inputs(inputs: *const $ty, data: *mut $ty); + pub fn set_inputs(inputs: *const $ty, data: *mut $ty, model_index: UIntType); #[link_name = "get_inputs"] pub fn get_inputs(inputs: *mut $ty, data: *const $ty); #[link_name = "set_inputs_grad"] @@ -212,6 +212,7 @@ macro_rules! define_symbol_module { dinputs: *const $ty, data: *const $ty, ddata: *mut $ty, + model_index: UIntType, ); #[link_name = "set_inputs_rgrad"] pub fn set_inputs_rgrad( @@ -219,6 +220,7 @@ macro_rules! define_symbol_module { dinputs: *mut $ty, data: *const $ty, ddata: *mut $ty, + model_index: UIntType, ); } } diff --git a/diffsl/src/execution/interface.rs b/diffsl/src/execution/interface.rs index f9b1495..9ae9eeb 100644 --- a/diffsl/src/execution/interface.rs +++ b/diffsl/src/execution/interface.rs @@ -168,12 +168,23 @@ pub type GetDimsFunc = unsafe extern "C" fn( stop: *mut UIntType, has_mass: *mut UIntType, ); -pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const T, data: *mut T); +pub type SetInputsFunc = + unsafe extern "C" fn(inputs: *const T, data: *mut T, model_index: UIntType); pub type GetInputsFunc = unsafe extern "C" fn(inputs: *mut T, data: *const T); -pub type SetInputsGradFunc = - unsafe extern "C" fn(inputs: *const T, dinputs: *const T, data: *const T, ddata: *mut T); -pub type SetInputsRevGradFunc = - unsafe extern "C" fn(inputs: *const T, dinputs: *mut T, data: *const T, ddata: *mut T); +pub type SetInputsGradFunc = unsafe extern "C" fn( + inputs: *const T, + dinputs: *const T, + data: *const T, + ddata: *mut T, + model_index: UIntType, +); +pub type SetInputsRevGradFunc = unsafe extern "C" fn( + 
inputs: *const T, + dinputs: *mut T, + data: *const T, + ddata: *mut T, + model_index: UIntType, +); pub type SetIdFunc = unsafe extern "C" fn(id: *mut T); pub type GetTensorFunc = unsafe extern "C" fn(data: *const T, tensor_data: *mut *mut T, tensor_size: *mut UIntType); diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index eff7b1d..650f197 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -364,6 +364,7 @@ impl CodegenModuleCompile for LlvmModule { &[ CompileGradientArgType::DupNoNeed, CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, ], CompileMode::Forward, "set_inputs_grad", @@ -427,6 +428,7 @@ impl CodegenModuleCompile for LlvmModule { &[ CompileGradientArgType::DupNoNeed, CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, ], CompileMode::Reverse, "set_inputs_rgrad", @@ -547,6 +549,7 @@ struct Globals<'ctx> { indices: Option>, constants: Option>, thread_counter: Option>, + model_index: GlobalValue<'ctx>, } impl<'ctx> Globals<'ctx> { @@ -612,10 +615,19 @@ impl<'ctx> Globals<'ctx> { indices.set_initializer(&indices_value); Some(indices) }; + let model_index = module.add_global( + int_type, + Some(AddressSpace::default()), + "enzyme_const_model_index", + ); + model_index.set_visibility(GlobalVisibility::Hidden); + model_index.set_constant(false); + model_index.set_initializer(&int_type.const_zero()); Self { indices, thread_counter, constants, + model_index, } } } @@ -1303,6 +1315,7 @@ impl<'ctx> CodeGen<'ctx> { } fn insert_data(&mut self, model: &DiscreteModel) { + self.insert_model_index(); self.insert_constants(model); if let Some(input) = model.input() { @@ -1343,6 +1356,10 @@ impl<'ctx> CodeGen<'ctx> { } } + fn insert_model_index(&mut self) { + self.insert_param("model_index", self.globals.model_index.as_pointer_value()); + } + fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) { self.variables.insert(name.to_owned(), 
value); } @@ -2602,8 +2619,22 @@ impl<'ctx> CodeGen<'ctx> { } AstKind::Number(value) => Ok(self.real_type.const_float(*value)), AstKind::Name(iname) => { - let ptr = self.get_param(iname.name); - let layout = self.layout.get_layout(iname.name).unwrap(); + if iname.name == "N" { + if iname.is_tangent { + return Ok(self.real_type.const_float(0.0)); + } + let model_index = self + .build_load(self.int_type, *self.get_param("model_index"), "model_index")? + .into_int_value(); + let n_value = self.builder.build_signed_int_to_float( + model_index, + self.real_type, + "n_as_real", + )?; + return Ok(n_value); + } + let ptr = *self.get_param(iname.name); + let layout = self.layout.get_layout(iname.name).unwrap().clone(); let iname_elmt_index = if layout.is_dense() { // permute indices based on the index chars of this tensor let mut no_transform = true; @@ -2617,14 +2648,12 @@ impl<'ctx> CodeGen<'ctx> { .position(|x| x == c) .unwrap_or(elmt.indices().len()); // if we are indexing, add the start indice to index[pi] - if let Some(indice) = - iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap()) - { - let start = indice.first.as_ref().kind.as_integer().unwrap(); - let start_intval = self - .context - .i32_type() - .const_int(start.try_into().unwrap(), false); + if let Some(indice_ast) = iname.indice.as_ref() { + let Some(indice) = indice_ast.kind.as_indice() else { + return Err(anyhow!("invalid index expression '{}'", indice_ast)); + }; + let start_intval = + self.jit_compile_integer_expr(indice.first.as_ref(), name)?; // if we are indexing a single element, the index may be out of bounds let index_pi = if pi >= index.len() { self.context.i32_type().const_int(0, false) @@ -2635,7 +2664,12 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_int_add(index_pi, start_intval, name)?; iname_index.push(index_pi); } else { - iname_index.push(index[pi]); + let index_pi = if pi >= index.len() { + self.context.i32_type().const_int(0, false) + } else { + index[pi] + }; + 
iname_index.push(index_pi); } no_transform = no_transform && pi == i; } @@ -2677,7 +2711,7 @@ impl<'ctx> CodeGen<'ctx> { } else if layout.is_sparse() || layout.is_diagonal() { let expr_layout = elmt.expr_layout(); - if expr_layout != layout { + if expr_layout != &layout { // get correct index from binary layout map, ie. indices[ binary_layout_index + expr_index ] // if its a -1 then return a 0 // ie. expr_index = binary_layout[expr_index] @@ -2685,10 +2719,10 @@ impl<'ctx> CodeGen<'ctx> { //. otherwise load the value at that index // we are doing an if statement so I think we need to return early here let permutation = - DataLayout::permutation(elmt, iname.indices.as_slice(), layout); + DataLayout::permutation(elmt, iname.indices.as_slice(), &layout); if let Some(base_binary_layout_index) = self.layout - .get_binary_layout_index(layout, expr_layout, permutation) + .get_binary_layout_index(&layout, expr_layout, permutation) { let binary_layout_index = self.builder.build_int_add( self.int_type @@ -2736,7 +2770,7 @@ impl<'ctx> CodeGen<'ctx> { let value_ptr = Self::get_ptr_to_index( &self.builder, self.real_type, - ptr, + &ptr, mapped_index, name, ); @@ -2767,7 +2801,7 @@ impl<'ctx> CodeGen<'ctx> { let value_ptr = Self::get_ptr_to_index( &self.builder, self.real_type, - ptr, + &ptr, iname_elmt_index, name, ); @@ -2789,6 +2823,69 @@ impl<'ctx> CodeGen<'ctx> { } } + fn jit_compile_integer_expr(&mut self, expr: &Ast, name: &str) -> Result> { + match &expr.kind { + AstKind::Integer(value) => Ok(self.int_type.const_int(*value as u64, true)), + AstKind::Number(value) => { + if value.fract() != 0.0 { + return Err(anyhow!( + "non-integer value '{}' in integer expression", + value + )); + } + Ok(self.int_type.const_int(*value as u64, true)) + } + AstKind::Name(iname) => { + if iname.name == "N" { + Ok(self + .build_load(self.int_type, *self.get_param("model_index"), name)? 
+ .into_int_value()) + } else { + Err(anyhow!( + "unsupported name '{}' in integer expression", + iname.name + )) + } + } + AstKind::Monop(monop) => { + let child = self.jit_compile_integer_expr(monop.child.as_ref(), name)?; + match monop.op { + '+' => Ok(child), + '-' => self.builder.build_int_neg(child, name).map_err(Into::into), + _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)), + } + } + AstKind::Binop(binop) => { + let lhs = self.jit_compile_integer_expr(binop.left.as_ref(), name)?; + let rhs = self.jit_compile_integer_expr(binop.right.as_ref(), name)?; + match binop.op { + '+' => self + .builder + .build_int_add(lhs, rhs, name) + .map_err(Into::into), + '-' => self + .builder + .build_int_sub(lhs, rhs, name) + .map_err(Into::into), + '*' => self + .builder + .build_int_mul(lhs, rhs, name) + .map_err(Into::into), + '/' => self + .builder + .build_int_signed_div(lhs, rhs, name) + .map_err(Into::into), + '%' => self + .builder + .build_int_signed_rem(lhs, rhs, name) + .map_err(Into::into), + _ => Err(anyhow!("unknown integer binary op '{}'", binop.op)), + } + } + _ => Err(anyhow!("unsupported integer expression '{}'", expr)), + } + } + fn clear(&mut self) { self.variables.clear(); //self.functions.clear(); @@ -2993,7 +3090,7 @@ impl<'ctx> CodeGen<'ctx> { // add noalias let alias_id = Attribute::get_named_enum_kind_id("noalias"); let noalign = self.context.create_enum_attribute(alias_id, 0); - for i in &[1, 2] { + for i in &[1, 2, 3] { function.add_attribute(AttributeLoc::Param(*i), noalign); } @@ -3319,8 +3416,8 @@ impl<'ctx> CodeGen<'ctx> { // F let res_ptr = self.get_param("rr"); self.jit_compile_tensor(model.rhs(), Some(*res_ptr), code)?; - let barrier_num = self.int_type.const_int(nbarriers + 1, false); let total_barriers_val = self.int_type.const_int(total_barriers, false); + let barrier_num = self.int_type.const_int(nbarriers + 1, false); self.jit_compile_call_barrier(barrier_num, total_barriers_val); self.builder.build_return(None)?; @@ 
-3439,7 +3536,8 @@ impl<'ctx> CodeGen<'ctx> { let mut start_param_index: Vec = Vec::new(); let mut ptr_arg_indices: Vec = Vec::new(); for (i, arg) in original_function.get_param_iter().enumerate() { - start_param_index.push(u32::try_from(fn_type.len()).unwrap()); + let start_index = u32::try_from(fn_type.len()).unwrap(); + start_param_index.push(start_index); let arg_type = arg.get_type(); fn_type.push(arg_type.into()); @@ -3448,13 +3546,16 @@ impl<'ctx> CodeGen<'ctx> { enzyme_fn_type.push(arg.get_type().into()); if arg_type.is_pointer_type() { - ptr_arg_indices.push(u32::try_from(i).unwrap()); + ptr_arg_indices.push(start_index); } match args_type[i] { CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => { fn_type.push(arg.get_type().into()); enzyme_fn_type.push(arg.get_type().into()); + if arg_type.is_pointer_type() { + ptr_arg_indices.push(start_index + 1); + } } CompileGradientArgType::Const => {} } @@ -3868,14 +3969,21 @@ impl<'ctx> CodeGen<'ctx> { ) -> Result> { self.clear(); let function_name = if is_get { "get_inputs" } else { "set_inputs" }; - let fn_arg_names = &["inputs", "data"]; - let function = self.add_function( - function_name, - fn_arg_names, - &[self.real_ptr_type.into(), self.real_ptr_type.into()], - None, - false, - ); + let fn_arg_names: &[&str] = if is_get { + &["inputs", "data"] + } else { + &["inputs", "data", "model_index"] + }; + let fn_arg_types: &[BasicMetadataTypeEnum<'ctx>] = if is_get { + &[self.real_ptr_type.into(), self.real_ptr_type.into()] + } else { + &[ + self.real_ptr_type.into(), + self.real_ptr_type.into(), + self.int_type.into(), + ] + }; + let function = self.add_function(function_name, fn_arg_names, fn_arg_types, None, false); let block = self.start_function(function, None); for (i, arg) in function.get_param_iter().enumerate() { @@ -3884,6 +3992,14 @@ impl<'ctx> CodeGen<'ctx> { self.insert_param(name, alloca); } + if !is_get { + let model_index = self + .build_load(self.int_type, 
*self.get_param("model_index"), "model_index")? + .into_int_value(); + self.builder + .build_store(self.globals.model_index.as_pointer_value(), model_index)?; + } + if let Some(input) = model.input() { let name = input.name(); self.insert_tensor(input, false); diff --git a/diffsl/src/parser/ds_grammar.pest b/diffsl/src/parser/ds_grammar.pest index c6bf354..edf66e5 100644 --- a/diffsl/src/parser/ds_grammar.pest +++ b/diffsl/src/parser/ds_grammar.pest @@ -9,14 +9,20 @@ assignment = { name ~ "=" ~ expression } expression = { term ~ (term_op ~ term)* } term = { factor ~ (factor_op ~ factor)* } factor = { sign? ~ ( call | real | integer | name_ij_index | name_ij | "(" ~ expression ~ ")" ) } +integer_expression = { integer_term ~ (term_op ~ integer_term)* } +integer_term = { integer_factor ~ (integer_factor_op ~ integer_factor)* } +integer_factor = { sign? ~ ( integer | integer_name | "(" ~ integer_expression ~ ")" ) } call = { name ~ "(" ~ call_arg ~ ("," ~ call_arg )* ~ ")" } call_arg = { expression } name_ij = ${ name ~ ("_" ~ name)? } -name_ij_index = ${ name_ij ~ "[" ~ indice ~ "]" } +index_indice = { integer_expression ~ ( range_sep ~ integer_expression )? } +name_ij_index = { name_ij ~ "[" ~ index_indice ~ "]" } range_sep = @{ ".." | ":" } sign = @{ ("-"|"+") } term_op = @{ "-"|"+" } factor_op = @{ "*"|"/" } +integer_factor_op = @{ "*"|"/"|"%" } +integer_name = @{ "N" } name = @{ ( 'a'..'z' | 'A'..'Z' ) ~ ('a'..'z' | 'A'..'Z' | '0'..'9' )* } integer = @{ ('0'..'9')+ } real = @{ ('0'..'9')+ ~ ( "." ~ ('0'..'9')* )? ~ ( "e" ~ sign? ~ integer )? 
} @@ -26,4 +32,3 @@ COMMENT = _{ "/*" ~ (!"*/" ~ ANY)* ~ "*/" | "//" ~ (!NEWLINE ~ ANY)* } - diff --git a/diffsl/src/parser/ds_parser.rs b/diffsl/src/parser/ds_parser.rs index 8f9c710..6a228ba 100644 --- a/diffsl/src/parser/ds_parser.rs +++ b/diffsl/src/parser/ds_parser.rs @@ -37,8 +37,9 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { pos_end: pair.as_span().end(), }); match pair.as_rule() { - // name = @{ 'a'..'z' ~ ("_" | 'a'..'z' | 'A'..'Z' | '0'..'9')* } - Rule::name => Ast { + // name = @{ ... } + // integer_name = @{ "N" } + Rule::name | Rule::integer_name => Ast { kind: AstKind::Name(ast::Name { name: pair.as_str(), indice: None, @@ -85,12 +86,13 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { } } - //expression = { term ~ (term_op ~ term)* } - Rule::expression => { + // expression = { term ~ (term_op ~ term)* } + // integer_expression = { integer_term ~ (term_op ~ integer_term)* } + Rule::expression | Rule::integer_expression => { let mut inner = pair.into_inner(); let mut head_term = parse_value(inner.next().unwrap()); while inner.peek().is_some() { - //term_op = @{ "-"|"+" } + // term_op = @{ "-"|"+" } let term_op = parse_sign(inner.next().unwrap()); let rhs_term = parse_value(inner.next().unwrap()); let subspan = Some(StringSpan { @@ -109,8 +111,9 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { head_term } - //term = { factor ~ (factor_op ~ factor)* } - Rule::term => { + // term = { factor ~ (factor_op ~ factor)* } + // integer_term = { integer_factor ~ (integer_factor_op ~ integer_factor)* } + Rule::term | Rule::integer_term => { let mut inner = pair.into_inner(); let mut head_factor = parse_value(inner.next().unwrap()); while inner.peek().is_some() { @@ -132,8 +135,9 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { head_factor } - // factor = { sign? ~ (call | name | real | integer | "(" ~ expression ~ ")" ) } - Rule::factor => { + // factor = { sign? 
~ (call | real | integer | name_ij_index | name_ij | "(" ~ expression ~ ")" ) } + // integer_factor = { sign? ~ ( integer | integer_name | "(" ~ integer_expression ~ ")" ) } + Rule::factor | Rule::integer_factor => { let mut inner = pair.into_inner(); let sign = if inner.peek().unwrap().as_rule() == Rule::sign { Some(parse_sign(inner.next().unwrap())) @@ -152,7 +156,7 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { } } - // name_ij_index = ${ name_ij ~ "[" ~ indice ~ "]" } + // name_ij_index = { name_ij ~ "[" ~ index_indice ~ "]" } Rule::name_ij_index => { let mut inner = pair.into_inner(); let mut name_ij = parse_value(inner.next().unwrap()); @@ -232,7 +236,8 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { } // indice = { integer ~ ( range_sep ~ integer )? } - Rule::indice => { + // index_indice = { integer_expression ~ ( range_sep ~ integer_expression )? } + Rule::indice | Rule::index_indice => { let mut inner = pair.into_inner(); let first = Box::new(parse_value(inner.next().unwrap())); if inner.peek().is_some() { @@ -488,4 +493,64 @@ mod tests { assert_eq!(assignment.name, "y"); assert_eq!(assignment.expr.to_string(), "3"); } + + #[test] + fn index_expression_with_modulo() { + const TEXT: &str = " + amp_i { 0, 10 } + u_i { x = 1, tclock = 0 } + F_i { amp_i[((N + 4 / 2) * 3 - 1) % 2] - x, 1 } + stop_i { 10 - tclock } + "; + let model = parse_string(TEXT).unwrap(); + assert_eq!(model.tensors.len(), 4); + let f_tensor = model.tensors[2].kind.as_tensor().unwrap(); + let f_expr = f_tensor.elmts()[0] + .kind + .as_tensor_elmt() + .unwrap() + .expr + .to_string(); + assert!(f_expr.contains("%")); + } + + #[test] + fn float_not_allowed_in_integer_expression_index() { + const TEXT: &str = " + amp_i { 0, 10 } + u_i { x = 1, tclock = 0 } + F_i { amp_i[N % 1.1] - x, 1 } + stop_i { 10 - tclock } + "; + assert!(parse_string(TEXT).is_err()); + } + + #[test] + fn non_n_name_not_allowed_in_integer_expression_index() { + const TEXT: &str = " + amp_i { 0, 10 } + u_i { x = 1, 
tclock = 0 } + F_i { amp_i[x % 2] - x, 1 } + stop_i { 10 - tclock } + "; + assert!(parse_string(TEXT).is_err()); + } + + #[test] + fn slice_index_still_parses() { + const TEXT: &str = " + a_i { 0.0, 1.0, 2.0, 3.0 } + r_i { a_i[1:3] } + "; + let model = parse_string(TEXT).unwrap(); + assert_eq!(model.tensors.len(), 2); + let r_tensor = model.tensors[1].kind.as_tensor().unwrap(); + let expr = r_tensor.elmts()[0] + .kind + .as_tensor_elmt() + .unwrap() + .expr + .to_string(); + assert!(expr.contains("[1:3]")); + } } diff --git a/diffsl/tests/pybamm_dfn.rs b/diffsl/tests/pybamm_dfn.rs index 1d57331..94e62a3 100644 --- a/diffsl/tests/pybamm_dfn.rs +++ b/diffsl/tests/pybamm_dfn.rs @@ -34,7 +34,7 @@ fn test_dfn_model_initialization() { let mut data = compiler.get_new_data(); let (n_states, n_inputs, _, _, _, _) = compiler.get_dims(); let inputs = vec![1.0; n_inputs]; - compiler.set_inputs(&inputs, &mut data); + compiler.set_inputs(&inputs, &mut data, 0); let mut u = vec![0.0; n_states]; compiler.set_u0(&mut u, &mut data); let mut rr = vec![0.0; n_states]; diff --git a/diffsl/tests/support/external_test_macros.rs b/diffsl/tests/support/external_test_macros.rs index 0807ed9..0054f02 100644 --- a/diffsl/tests/support/external_test_macros.rs +++ b/diffsl/tests/support/external_test_macros.rs @@ -333,7 +333,7 @@ macro_rules! define_external_test { } #[no_mangle] - pub unsafe extern "C" fn set_inputs(inputs: *const $ty, data: *mut $ty) { + pub unsafe extern "C" fn set_inputs(inputs: *const $ty, data: *mut $ty, _model_index: u32) { if inputs.is_null() || data.is_null() { return; } @@ -354,6 +354,7 @@ macro_rules! define_external_test { dinputs: *const $ty, _data: *const $ty, ddata: *mut $ty, + _model_index: u32, ) { if dinputs.is_null() || ddata.is_null() { return; @@ -367,6 +368,7 @@ macro_rules! 
define_external_test { dinputs: *mut $ty, _data: *const $ty, ddata: *mut $ty, + _model_index: u32, ) { if dinputs.is_null() || ddata.is_null() { return; @@ -390,7 +392,7 @@ macro_rules! define_external_test { let mut data = vec![-1.0 as $ty; n_data]; let inputs = vec![1.0 as $ty; n_inputs]; - compiler.set_inputs(&inputs, &mut data); + compiler.set_inputs(&inputs, &mut data, 0); let mut inputs_out = vec![-2.0 as $ty; n_inputs]; compiler.get_inputs(&mut inputs_out, &data); @@ -433,21 +435,13 @@ macro_rules! define_external_test { assert_eq!(ddata[0], 0.0 as $ty); let mut dinputs = vec![1.0 as $ty; n_inputs]; - compiler.set_inputs_grad(&inputs, &dinputs, &data, &mut ddata); + compiler.set_inputs_grad(&inputs, &dinputs, &data, &mut ddata, 0); assert_eq!(ddata[0], 1.0 as $ty); let mut du_rev = vec![-11.0 as $ty; n_states]; let mut ddata_rev = vec![-12.0 as $ty; n_data]; let mut drr_rev = vec![1.0 as $ty; n_states]; - compiler.rhs_rgrad( - 0.0 as $ty, - &u, - &mut du_rev, - &data, - &mut ddata_rev, - &rr, - &mut drr_rev, - ); + compiler.rhs_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &rr, &mut drr_rev); assert_eq!(du_rev[0], -12.0 as $ty); assert_eq!(ddata_rev[0], -12.0 as $ty); @@ -457,18 +451,10 @@ macro_rules! define_external_test { assert_eq!(dv[0], -12.0 as $ty); let mut dout_rev = vec![1.0 as $ty; n_outputs]; - compiler.calc_out_rgrad( - 0.0 as $ty, - &u, - &mut du_rev, - &data, - &mut ddata_rev, - &out, - &mut dout_rev, - ); + compiler.calc_out_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &out, &mut dout_rev); assert_eq!(du_rev[0], -11.0 as $ty); - compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev); + compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev, 0); assert_eq!(dinputs[0], -11.0 as $ty); let mut ddata_s = vec![-14.0 as $ty; n_data];