From 1db45ac72fc71bae40f8f184911bf1bfba76a42d Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sun, 8 Mar 2026 00:54:57 +0000 Subject: [PATCH 1/3] feat: add integer expressions for indexing --- diffsl/benches/evaluation.rs | 2 +- diffsl/benches/pybamm_dfn.rs | 3 +- diffsl/src/ast/mod.rs | 17 +- diffsl/src/discretise/env.rs | 190 +++++++++++- diffsl/src/execution/compiler.rs | 296 ++++++++++++++++++- diffsl/src/execution/cranelift/codegen.rs | 106 ++++++- diffsl/src/execution/external/mod.rs | 11 + diffsl/src/execution/interface.rs | 11 + diffsl/src/execution/llvm/codegen.rs | 157 ++++++++-- diffsl/src/parser/ds_grammar.pest | 9 +- diffsl/src/parser/ds_parser.rs | 87 +++++- diffsl/tests/pybamm_dfn.rs | 4 +- diffsl/tests/support/external_test_macros.rs | 49 ++- 13 files changed, 841 insertions(+), 101 deletions(-) diff --git a/diffsl/benches/evaluation.rs b/diffsl/benches/evaluation.rs index d42bf78..e38d848 100644 --- a/diffsl/benches/evaluation.rs +++ b/diffsl/benches/evaluation.rs @@ -62,7 +62,7 @@ fn execute( let t = 0.0; bencher.bench_local(|| { - compiler.rhs(t, &u, &mut data, &mut rr); + compiler.rhs(t, &u, &mut data, &mut rr, 0); }); } diff --git a/diffsl/benches/pybamm_dfn.rs b/diffsl/benches/pybamm_dfn.rs index 7dd9af6..e7cd228 100644 --- a/diffsl/benches/pybamm_dfn.rs +++ b/diffsl/benches/pybamm_dfn.rs @@ -35,6 +35,7 @@ fn pybamm_dfn_execute_rhs_grad(bench ddata.as_mut_slice(), rr.as_mut_slice(), drr.as_mut_slice(), + 0, ); }); } @@ -67,7 +68,7 @@ fn pybamm_dfn_execute_rhs(bencher: B let mut data = compiler.get_new_data(); let mut rr = vec![0.0; n_states]; bencher.bench_local(move || { - compiler.rhs(t, y.as_slice(), data.as_mut_slice(), rr.as_mut_slice()); + compiler.rhs(t, y.as_slice(), data.as_mut_slice(), rr.as_mut_slice(), 0); }); } diff --git a/diffsl/src/ast/mod.rs b/diffsl/src/ast/mod.rs index 5ea2433..667fd7f 100644 --- a/diffsl/src/ast/mod.rs +++ b/diffsl/src/ast/mod.rs @@ -583,7 +583,15 @@ impl<'a> Ast<'a> { if name.name == "t" { Self::new_number(0.0) } else { - Self::new_name(name.name, name.indices.clone(), true) + Ast { + kind: AstKind::Name(Name { + name: name.name, + indices: name.indices.clone(), + indice: name.indice.clone(), + is_tangent: true, + }), + span: self.span, + } } } AstKind::Number(_) => Self::new_number(0.0), @@ -883,9 +891,10 @@ impl<'a> Ast<'a> { AstKind::Name(found_name) => { // if the name is indexed by a single indice, don't add that index to the list if let Some(indice) = &found_name.indice { - let indice = indice.as_ref().kind.as_indice().unwrap(); - if indice.sep.is_none() || indice.last.is_none() { - return; + if let Some(indice) = indice.as_ref().kind.as_indice() { + if indice.sep.is_none() || indice.last.is_none() { + return; + } } } indices.extend(found_name.indices.iter().cloned()); diff --git a/diffsl/src/discretise/env.rs b/diffsl/src/discretise/env.rs index 46b5ebe..f0d8500 100644 --- a/diffsl/src/discretise/env.rs +++ b/diffsl/src/discretise/env.rs @@ -57,6 +57,127 @@ pub struct Env { } impl Env { + fn eval_const_integer_expr(expr: &Ast) -> Option { + match &expr.kind { + AstKind::Integer(v) => Some(*v), + AstKind::Number(v) => { + if v.fract() == 0.0 { + Some(*v as i64) + } else { + None + } + } + AstKind::Monop(op) => { + let child = Self::eval_const_integer_expr(op.child.as_ref())?; + match op.op { + '+' => Some(child), + '-' => child.checked_neg(), + _ => None, + } + } + AstKind::Binop(op) => { + let left = Self::eval_const_integer_expr(op.left.as_ref())?; + let right = Self::eval_const_integer_expr(op.right.as_ref())?; + match op.op { + '+' => left.checked_add(right), + '-' => left.checked_sub(right), + '*' => left.checked_mul(right), + '/' => { + if right == 0 { + None + } else { + Some(left / right) + } + } + '%' => { + if right == 0 { + None + } else { + Some(left % right) + } + } + _ => None, + } + } + _ => None, + } + } + + fn eval_integer_expr_with_n(expr: &Ast, n: i64) -> Option { + match &expr.kind { + AstKind::Integer(v) => Some(*v), + AstKind::Number(v) => { + if v.fract() == 0.0 { + Some(*v as i64) + } else { + None + } + } + AstKind::Name(name) => { + if name.name == "N" { + Some(n) + } else { + None + } + } + AstKind::Monop(op) => { + let child = Self::eval_integer_expr_with_n(op.child.as_ref(), n)?; + match op.op { + '+' => Some(child), + '-' => child.checked_neg(), + _ => None, + } + } + AstKind::Binop(op) => { + let left = Self::eval_integer_expr_with_n(op.left.as_ref(), n)?; + let right = Self::eval_integer_expr_with_n(op.right.as_ref(), n)?; + match op.op { + '+' => left.checked_add(right), + '-' => left.checked_sub(right), + '*' => left.checked_mul(right), + '/' => { + if right == 0 { + None + } else { + Some(left / right) + } + } + '%' => { + if right == 0 { + None + } else { + Some(left % right) + } + } + _ => None, + } + } + _ => None, + } + } + + fn eval_constant_range_width(start: &Ast, end: &Ast) -> Option { + if let (Some(first), Some(last)) = ( + Self::eval_const_integer_expr(start), + Self::eval_const_integer_expr(end), + ) { + return last.checked_sub(first); + } + + let mut width = None; + for n in [0_i64, 1, 2, 3, 7, 16] { + let first_n = Self::eval_integer_expr_with_n(start, n)?; + let last_n = Self::eval_integer_expr_with_n(end, n)?; + let width_n = last_n.checked_sub(first_n)?; + match width { + Some(prev) if prev != width_n => return None, + Some(_) => {} + None => width = Some(width_n), + } + } + width + } + pub fn new() -> Self { let mut vars = HashMap::new(); vars.insert( @@ -280,32 +401,79 @@ impl Env { // if the indice is a single integer then the resulting layout is a scalar if indice.sep.is_none() { let mut new_layout = Layout::new_scalar(); - let first = indice.first.kind.as_integer().unwrap(); - let last = first + 1; + let (first, last) = + if let Some(first) = Self::eval_const_integer_expr(indice.first.as_ref()) { + (first, first + 1) + } else { + // Dynamic integer indices (e.g. involving N) cannot be resolved at compile time. + // Conservatively keep dependencies for the full 1D extent. + let axis = layout_permuted + .shape() + .iter() + .position(|&d| d != 1) + .unwrap_or(0); + let dim = *layout_permuted.shape().get(axis).unwrap_or(&1); + (0, i64::try_from(dim).unwrap()) + }; new_layout.filter_deps_from(layout_permuted, first, last); return Some(new_layout); } else { // if the indice is a range then the resulting layout is a dense layout with shape given by the range // along the only non-unit dimension of the variable - let first = indice.first.kind.as_integer().unwrap(); - let last = indice.last.as_ref().unwrap().kind.as_integer().unwrap(); + let end_expr = indice.last.as_ref().unwrap().as_ref(); + let Some(width) = Self::eval_constant_range_width(indice.first.as_ref(), end_expr) + else { + self.errs.push(ValidationError::new( + "range indice width must be an integer constant (independent of N)" + .to_string(), + ast.span, + )); + return None; + }; // make sure the range is valid - if last < first { + if width < 0 { self.errs.push(ValidationError::new( - format!( - "invalid range indice: start {} is greater than end {}", - first, last - ), + format!("invalid range indice: width {} is negative", width), ast.span, )); return None; } - let dim = usize::try_from(last - first).unwrap(); + let dim = usize::try_from(width).unwrap(); let shape = layout_permuted .shape() .map(|&d| if d != 1 { dim } else { 1 }); let mut new_layout = Layout::new_dense(Shape::from(shape)); - new_layout.filter_deps_from(layout_permuted, first, last); + let first_const = Self::eval_const_integer_expr(indice.first.as_ref()); + if let Some(first) = first_const { + let last = first + width; + new_layout.filter_deps_from(layout_permuted, first, last); + } else if dim != 0 { + // For dynamic starts (e.g. N:N+2), union dependencies over all valid windows. + let axis = layout_permuted + .shape() + .iter() + .position(|&d| d != 1) + .unwrap_or(0); + let source_dim = + i64::try_from(*layout_permuted.shape().get(axis).unwrap_or(&1)).unwrap(); + let max_start = source_dim.saturating_sub(width); + let mut merged_layout: Option = None; + for start in 0..=max_start { + let mut window_layout = Layout::new_dense(new_layout.shape().clone()); + window_layout.filter_deps_from( + layout_permuted.clone(), + start, + start + width, + ); + merged_layout = Some(match merged_layout { + Some(accumulated) => accumulated.union(window_layout), + None => window_layout, + }); + } + if let Some(merged_layout) = merged_layout { + new_layout = merged_layout; + } + } return Some(new_layout); } } diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index 32c5461..e66a342 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -410,7 +410,7 @@ impl Compiler { }) } - pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T]) { + pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T], model_index: u32) { if self.number_of_stop == 0 { panic!("Model does not have a stop function"); } @@ -423,13 +423,14 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, stop.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T]) { + pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_data_len(data, "data"); @@ -439,6 +440,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, rr.as_ptr() as *mut T, + model_index, i, dim, ) @@ -486,6 +488,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -502,6 +505,7 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) @@ -518,6 +522,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -538,6 +543,7 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) @@ -568,7 +574,16 @@ impl Compiler { }); } - pub fn rhs_sgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) { + pub fn rhs_sgrad( + &self, + t: T, + yy: &[T], + data: &[T], + ddata: &mut [T], + rr: &[T], + drr: &mut [T], + model_index: u32, + ) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -586,13 +601,23 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn rhs_srgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) { + pub fn rhs_srgrad( + &self, + t: T, + yy: &[T], + data: &[T], + ddata: &mut [T], + rr: &[T], + drr: &mut [T], + model_index: u32, + ) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -610,13 +635,14 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T]) { + pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); self.check_out_len(out, "out"); @@ -626,6 +652,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, out.as_ptr() as *mut T, + model_index, i, dim, ) @@ -642,6 +669,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -658,6 +686,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -674,6 +703,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -694,6 +724,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -708,6 +739,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -726,6 +758,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -740,6 +773,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -758,6 +792,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -994,6 +1029,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice()); } @@ -1041,12 +1077,14 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); compiler.calc_stop( T::zero(), u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), + 0, ); assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap()); assert_eq!(stop.len(), 1); @@ -1054,6 +1092,217 @@ mod tests { generate_tests!(test_out_depends_on_internal_tensor); + generate_tests!(test_model_index_n_depends_on_model_index); + generate_tests!(test_model_index_n_dynamic_index_grad); + generate_tests!(test_model_index_n_dynamic_range_width_const); + + #[allow(dead_code)] + fn test_model_index_n_depends_on_model_index< + M: CodegenModuleCompile + CodegenModuleJit, + T: Scalar + RelativeEq, + >() { + // RED test for issue #112: + // - `N` is a reserved model index + // - `%` is supported in expressions + // - tensor indexing can use expression indices, e.g. amp_i[N % 2] + // + // Expected behavior once implemented: + // - N is taken from model_index. + // - model_index = 0 => N % 2 = 0 + // - model_index = 1 => N % 2 = 1 + let full_text = " + amp_i { 0, 10 } + dur_i { 10, 5 } + u_i { x = 1, tclock = 0 } + F_i { amp_i[N % 2] - x, 1 } + stop_i { dur_i[N % 2] - tclock } + out_i { x, tclock } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero(); 2]; + let mut rr0 = vec![T::zero(); 2]; + let mut rr1 = vec![T::zero(); 2]; + let mut stop0 = vec![T::zero(); 1]; + let mut stop1 = vec![T::zero(); 1]; + let mut data = compiler.get_new_data(); + + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr0.as_mut_slice(), + 0, + ); + compiler.calc_stop( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + stop0.as_mut_slice(), + 0, + ); + + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr1.as_mut_slice(), + 1, + ); + compiler.calc_stop( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + stop1.as_mut_slice(), + 1, + ); + + assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap()); + assert_relative_eq!(u0[1], T::from_f64(0.0).unwrap()); + + assert_relative_eq!(rr0[0], T::from_f64(-1.0).unwrap()); + assert_relative_eq!(rr0[1], T::from_f64(1.0).unwrap()); + assert_relative_eq!(stop0[0], T::from_f64(10.0).unwrap()); + + assert_relative_eq!(rr1[0], T::from_f64(9.0).unwrap()); + assert_relative_eq!(rr1[1], T::from_f64(1.0).unwrap()); + assert_relative_eq!(stop1[0], T::from_f64(5.0).unwrap()); + + assert_ne!(rr0[0], rr1[0]); + assert_ne!(stop0[0], stop1[0]); + } + + #[allow(dead_code)] + fn test_model_index_n_dynamic_index_grad< + M: CodegenModuleCompile + CodegenModuleJit, + T: Scalar + RelativeEq, + >() { + let full_text = " + u_i { y = 3, z = 5 } + F_i { u_i[N % 2], 0 } + out_i { y, z } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero(); 2]; + let mut rr0 = vec![T::zero(); 2]; + let mut rr1 = vec![T::zero(); 2]; + let mut drr0 = vec![T::zero(); 2]; + let mut drr1 = vec![T::zero(); 2]; + let mut data = compiler.get_new_data(); + let mut ddata0 = compiler.get_new_data(); + let mut ddata1 = compiler.get_new_data(); + + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr0.as_mut_slice(), + 0, + ); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr1.as_mut_slice(), + 1, + ); + + assert_relative_eq!(rr0[0], T::from_f64(3.0).unwrap()); + assert_relative_eq!(rr1[0], T::from_f64(5.0).unwrap()); + + let dyy0 = vec![T::one(), T::zero()]; + let dyy1 = vec![T::zero(), T::one()]; + compiler.rhs_grad( + T::zero(), + u0.as_slice(), + dyy0.as_slice(), + data.as_slice(), + ddata0.as_mut_slice(), + rr0.as_slice(), + drr0.as_mut_slice(), + 0, + ); + compiler.rhs_grad( + T::zero(), + u0.as_slice(), + dyy1.as_slice(), + data.as_slice(), + ddata1.as_mut_slice(), + rr1.as_slice(), + drr1.as_mut_slice(), + 1, + ); + + assert_relative_eq!(drr0[0], T::one()); + assert_relative_eq!(drr0[1], T::zero()); + assert_relative_eq!(drr1[0], T::one()); + assert_relative_eq!(drr1[1], T::zero()); + } + + #[allow(dead_code)] + fn test_model_index_n_dynamic_range_width_const< + M: CodegenModuleCompile + CodegenModuleJit, + T: Scalar + RelativeEq, + >() { + let full_text = " + amp_i { 0, 10, 20 } + u_i { x = 1, y = 1 } + F_i { amp_i[N:N+2] - u_i } + out_i { x, y } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero(); 2]; + let mut rr0 = vec![T::zero(); 2]; + let mut rr1 = vec![T::zero(); 2]; + let mut data = compiler.get_new_data(); + + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr0.as_mut_slice(), + 0, + ); + compiler.rhs( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + rr1.as_mut_slice(), + 1, + ); + + assert_relative_eq!(rr0[0], T::from_f64(-1.0).unwrap()); + assert_relative_eq!(rr0[1], T::from_f64(9.0).unwrap()); + assert_relative_eq!(rr1[0], T::from_f64(9.0).unwrap()); + assert_relative_eq!(rr1[1], T::from_f64(19.0).unwrap()); + } + #[allow(dead_code)] fn test_out_depends_on_internal_tensor< M: CodegenModuleCompile + CodegenModuleJit, @@ -1084,6 +1333,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), + 0, ); assert_relative_eq!(out[0], T::from_f64(2.).unwrap()); u0[0] = T::from_f64(2.).unwrap(); @@ -1092,6 +1342,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), + 0, ); assert_relative_eq!(out[0], T::from_f64(4.).unwrap()); let mut stop = vec![T::zero()]; @@ -1100,6 +1351,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), + 0, ); assert_relative_eq!(stop[0], T::from_f64(3.5).unwrap()); u0[0] = T::from_f64(0.5).unwrap(); @@ -1108,6 +1360,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), + 0, ); assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap()); } @@ -1190,12 +1443,14 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); compiler.calc_out( T::zero(), u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), + 0, ); let (tensor_len, tensor_is_constant) = if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, data.as_slice()) { @@ -1240,6 +1495,7 @@ mod tests { ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice(), + 0, ); compiler.calc_out_grad( T::zero(), @@ -1249,6 +1505,7 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, ddata.as_slice()) { results.push(tensor_data.to_vec()); @@ -1277,6 +1534,7 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); compiler.rhs_rgrad( T::zero(), @@ -1286,6 +1544,7 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), + 0, ); compiler.set_u0_rgrad( u0.as_mut_slice(), @@ -1308,6 +1567,7 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), + 0, ); results.push( compiler @@ -1327,6 +1587,7 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); results.push( compiler @@ -1350,6 +1611,7 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), + 0, ); compiler.set_inputs_rgrad( inputs.as_slice(), @@ -1373,6 +1635,7 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); compiler.set_inputs_rgrad( inputs.as_slice(), @@ -1905,6 +2168,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); for _i in 0..3 { @@ -1929,6 +2193,7 @@ mod tests { ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice(), + 0, ); assert_relative_eq!(dres.as_slice(), vec![T::from_f64(8.).unwrap()].as_slice()); } @@ -1990,9 +2255,21 @@ mod tests { let mut data = compiler.get_new_data(); let (_n_states, _n_inputs, _n_outputs, _n_data, _n_stop, _has_mass) = compiler.get_dims(); compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); - compiler.rhs(0.0, u.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + compiler.rhs( + 0.0, + u.as_slice(), + data.as_mut_slice(), + res.as_mut_slice(), + 0, + ); assert_relative_eq!(res.as_slice(), vec![3.0, 0.0, 0.0, 3.0].as_slice()); - compiler.rhs(0.0, u.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + compiler.rhs( + 0.0, + u.as_slice(), + data.as_mut_slice(), + res.as_mut_slice(), + 0, + ); assert_relative_eq!(res.as_slice(), vec![3.0, 0.0, 0.0, 3.0].as_slice()); } @@ -2056,7 +2333,7 @@ mod tests { assert_relative_eq!(u.as_slice(), vec![1., 2.].as_slice()); let mut rr = vec![1., 1.]; - compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice()); + compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice(), 0); assert_relative_eq!(rr.as_slice(), vec![0., 0.].as_slice()); let up = vec![2., 3.]; @@ -2065,7 +2342,7 @@ mod tests { assert_relative_eq!(rr.as_slice(), vec![2., 0.].as_slice()); let mut out = vec![0.; 3]; - compiler.calc_out(0., u.as_slice(), data.as_mut_slice(), out.as_mut_slice()); + compiler.calc_out(0., u.as_slice(), data.as_mut_slice(), out.as_mut_slice(), 0); assert_relative_eq!(out.as_slice(), vec![1., 2., 4.].as_slice()); } } @@ -2184,6 +2461,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice()); }); diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index 61bb1dc..fd38244 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -214,6 +214,7 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; let arg_names = &[ "t", @@ -223,6 +224,7 @@ impl CraneliftModule { "ddata", "out", "dout", + "model_index", "threadId", "threadDim", ]; @@ -250,6 +252,7 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; let arg_names = &[ "t", @@ -259,6 +262,7 @@ impl CraneliftModule { "ddata", "rr", "drr", + "model_index", "threadId", "threadDim", ]; @@ -525,8 +529,17 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, + ]; + let arg_names = &[ + "t", + "u", + "data", + "out", + "model_index", + "threadId", + "threadDim", ]; - let arg_names = &["t", "u", "data", "out", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -568,8 +581,17 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, + ]; + let arg_names = &[ + "t", + "u", + "data", + "root", + "model_index", + "threadId", + "threadDim", ]; - let arg_names = &["t", "u", "data", "root", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -611,8 +633,17 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, + ]; + let arg_names = &[ + "t", + "u", + "data", + "rr", + "model_index", + "threadId", + "threadDim", ]; - let arg_names = &["t", "u", "data", "rr", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -634,7 +665,6 @@ impl CraneliftModule { // F let res = *codegen.variables.get("rr").unwrap(); codegen.jit_compile_tensor(model.rhs(), Some(res), false)?; - codegen.builder.ins().return_(&[]); codegen.builder.finalize(); } @@ -1149,11 +1179,12 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { .position(|x| x == c) .unwrap_or(elmt.indices().len()); // if we are indexing, add the start indice to index[pi] - if let Some(indice) = - iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap()) - { - let start = indice.first.as_ref().kind.as_integer().unwrap(); - let start_intval = self.builder.ins().iconst(self.int_type, start); + if let Some(indice_ast) = iname.indice.as_ref() { + let Some(indice) = indice_ast.kind.as_indice() else { + return Err(anyhow!("invalid index expression '{}'", indice_ast)); + }; + let start_intval = + self.jit_compile_integer_expr(indice.first.as_ref())?; // if we are indexing a single element, the index may be out of bounds let index_pi = if pi >= index.len() { self.builder.ins().iconst(self.int_type, 0) @@ -1163,7 +1194,12 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { let index_pi = self.builder.ins().iadd(start_intval, index_pi); iname_index.push(index_pi); } else { - iname_index.push(index[pi]); + let index_pi = if pi >= index.len() { + self.builder.ins().iconst(self.int_type, 0) + } else { + index[pi] + }; + iname_index.push(index_pi); } no_transform = no_transform && pi == i; } @@ -1317,6 +1353,56 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { } } + fn jit_compile_integer_expr(&mut self, expr: &Ast) -> Result { + match &expr.kind { + AstKind::Integer(value) => Ok(self.builder.ins().iconst(self.int_type, *value)), + AstKind::Number(value) => { + if value.fract() != 0.0 { + return Err(anyhow!( + "non-integer value '{}' in integer expression", + value + )); + } + Ok(self.builder.ins().iconst(self.int_type, *value as i64)) + } + AstKind::Name(iname) => { + if iname.name == "N" { + let var = self + .variables + .get("model_index") + .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?; + Ok(self.builder.use_var(*var)) + } else { + Err(anyhow!( + "unsupported name '{}' in integer expression", + iname.name + )) + } + } + AstKind::Monop(monop) => { + let child = self.jit_compile_integer_expr(monop.child.as_ref())?; + match monop.op { + '+' => Ok(child), + '-' => Ok(self.builder.ins().ineg(child)), + _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)), + } + } + AstKind::Binop(binop) => { + let lhs = self.jit_compile_integer_expr(binop.left.as_ref())?; + let rhs = self.jit_compile_integer_expr(binop.right.as_ref())?; + match binop.op { + '+' => Ok(self.builder.ins().iadd(lhs, rhs)), + '-' => Ok(self.builder.ins().isub(lhs, rhs)), + '*' => Ok(self.builder.ins().imul(lhs, rhs)), + '/' => Ok(self.builder.ins().sdiv(lhs, rhs)), + '%' => Ok(self.builder.ins().srem(lhs, rhs)), + _ => Err(anyhow!("unknown integer binary op '{}'", binop.op)), + } + } + _ => Err(anyhow!("unsupported integer expression '{}'", expr)), + } + } + fn get_function_name(name: &str, is_tangent: bool) -> String { if is_tangent { format!("{name}__tangent__") diff --git a/diffsl/src/execution/external/mod.rs b/diffsl/src/execution/external/mod.rs index e554bba..18dae4b 100644 --- a/diffsl/src/execution/external/mod.rs +++ b/diffsl/src/execution/external/mod.rs @@ -30,6 +30,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, rr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -42,6 +43,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -54,6 +56,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -65,6 +68,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -76,6 +80,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -133,6 +138,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, out: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -145,6 +151,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -157,6 +164,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -168,6 +176,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -179,6 +188,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -188,6 +198,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, root: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); diff --git a/diffsl/src/execution/interface.rs b/diffsl/src/execution/interface.rs index f9b1495..5a3c903 100644 --- a/diffsl/src/execution/interface.rs +++ b/diffsl/src/execution/interface.rs @@ -12,6 +12,7 @@ pub type StopFunc = unsafe extern "C" fn( u: *const T, data: *mut T, root: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -20,6 +21,7 @@ pub type RhsFunc = unsafe extern "C" fn( u: *const T, data: *mut T, rr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -31,6 +33,7 @@ pub type RhsGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -42,6 +45,7 @@ pub type RhsRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -52,6 +56,7 @@ pub type RhsSensGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -62,6 +67,7 @@ pub type RhsSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -115,6 +121,7 @@ pub type CalcOutFunc = unsafe extern "C" fn( u: *const T, data: *mut T, out: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -126,6 +133,7 @@ pub type CalcOutGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -137,6 +145,7 @@ pub type CalcOutRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -147,6 +156,7 @@ pub type CalcOutSensGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -157,6 +167,7 @@ pub type CalcOutSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index eff7b1d..566486c 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -341,6 +341,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Forward, "rhs_grad", @@ -355,6 +356,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Forward, "calc_out_grad", @@ -404,6 +406,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Reverse, "rhs_rgrad", @@ -417,6 +420,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Reverse, "calc_out_rgrad", @@ -441,6 +445,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ForwardSens, "rhs_sgrad", @@ -467,6 +472,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ForwardSens, "calc_out_sgrad", @@ -480,6 +486,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ReverseSens, "calc_out_srgrad", @@ -494,6 +501,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ReverseSens, "rhs_srgrad", @@ -2602,8 +2610,8 @@ impl<'ctx> CodeGen<'ctx> { } AstKind::Number(value) => Ok(self.real_type.const_float(*value)), AstKind::Name(iname) => { - let ptr = self.get_param(iname.name); - let layout = self.layout.get_layout(iname.name).unwrap(); + let ptr = *self.get_param(iname.name); + let layout = self.layout.get_layout(iname.name).unwrap().clone(); let iname_elmt_index = if layout.is_dense() { // permute indices based on the index chars of this tensor let mut no_transform = true; @@ -2617,14 +2625,12 @@ impl<'ctx> CodeGen<'ctx> { .position(|x| x == c) .unwrap_or(elmt.indices().len()); // if we are indexing, add the start indice to index[pi] - if let Some(indice) = - iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap()) - { - let start = indice.first.as_ref().kind.as_integer().unwrap(); - let start_intval = self - .context - .i32_type() - .const_int(start.try_into().unwrap(), false); + if let Some(indice_ast) = iname.indice.as_ref() { + let Some(indice) = indice_ast.kind.as_indice() else { + return Err(anyhow!("invalid index expression '{}'", indice_ast)); + }; + let start_intval = + self.jit_compile_integer_expr(indice.first.as_ref(), name)?; // if we are indexing a single element, the index may be out of bounds let index_pi = if pi >= index.len() { self.context.i32_type().const_int(0, false) @@ -2635,7 +2641,12 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_int_add(index_pi, start_intval, name)?; iname_index.push(index_pi); } else { - iname_index.push(index[pi]); + let index_pi = if pi >= index.len() { + self.context.i32_type().const_int(0, false) + } else { + index[pi] + }; + iname_index.push(index_pi); } no_transform = no_transform && pi == i; } @@ -2677,7 +2688,7 @@ impl<'ctx> CodeGen<'ctx> { } else if layout.is_sparse() || layout.is_diagonal() { let expr_layout = elmt.expr_layout(); - if expr_layout != layout { + if expr_layout != &layout { // get correct index from binary layout map, ie. indices[ binary_layout_index + expr_index ] // if its a -1 then return a 0 // ie. expr_index = binary_layout[expr_index] @@ -2685,10 +2696,10 @@ impl<'ctx> CodeGen<'ctx> { //. otherwise load the value at that index // we are doing an if statement so I think we need to return early here let permutation = - DataLayout::permutation(elmt, iname.indices.as_slice(), layout); + DataLayout::permutation(elmt, iname.indices.as_slice(), &layout); if let Some(base_binary_layout_index) = self.layout - .get_binary_layout_index(layout, expr_layout, permutation) + .get_binary_layout_index(&layout, expr_layout, permutation) { let binary_layout_index = self.builder.build_int_add( self.int_type @@ -2736,7 +2747,7 @@ impl<'ctx> CodeGen<'ctx> { let value_ptr = Self::get_ptr_to_index( &self.builder, self.real_type, - ptr, + &ptr, mapped_index, name, ); @@ -2767,7 +2778,7 @@ impl<'ctx> CodeGen<'ctx> { let value_ptr = Self::get_ptr_to_index( &self.builder, self.real_type, - ptr, + &ptr, iname_elmt_index, name, ); @@ -2789,6 +2800,69 @@ impl<'ctx> CodeGen<'ctx> { } } + fn jit_compile_integer_expr(&mut self, expr: &Ast, name: &str) -> Result> { + match &expr.kind { + AstKind::Integer(value) => Ok(self.int_type.const_int(*value as u64, true)), + AstKind::Number(value) => { + if value.fract() != 0.0 { + return Err(anyhow!( + "non-integer value '{}' in integer expression", + value + )); + } + Ok(self.int_type.const_int(*value as u64, true)) + } + AstKind::Name(iname) => { + if iname.name == "N" { + Ok(self + .build_load(self.int_type, *self.get_param("model_index"), name)? + .into_int_value()) + } else { + Err(anyhow!( + "unsupported name '{}' in integer expression", + iname.name + )) + } + } + AstKind::Monop(monop) => { + let child = self.jit_compile_integer_expr(monop.child.as_ref(), name)?; + match monop.op { + '+' => Ok(child), + '-' => self.builder.build_int_neg(child, name).map_err(Into::into), + _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)), + } + } + AstKind::Binop(binop) => { + let lhs = self.jit_compile_integer_expr(binop.left.as_ref(), name)?; + let rhs = self.jit_compile_integer_expr(binop.right.as_ref(), name)?; + match binop.op { + '+' => self + .builder + .build_int_add(lhs, rhs, name) + .map_err(Into::into), + '-' => self + .builder + .build_int_sub(lhs, rhs, name) + .map_err(Into::into), + '*' => self + .builder + .build_int_mul(lhs, rhs, name) + .map_err(Into::into), + '/' => self + .builder + .build_int_signed_div(lhs, rhs, name) + .map_err(Into::into), + '%' => self + .builder + .build_int_signed_rem(lhs, rhs, name) + .map_err(Into::into), + _ => Err(anyhow!("unknown integer binary op '{}'", binop.op)), + } + } + _ => Err(anyhow!("unsupported integer expression '{}'", expr)), + } + } + fn clear(&mut self) { self.variables.clear(); //self.functions.clear(); @@ -2808,6 +2882,9 @@ impl<'ctx> CodeGen<'ctx> { .into_float_value(); let u = *self.get_param("u"); let data = *self.get_param("data"); + let model_index = self + .build_load(self.int_type, *self.get_param("model_index"), "model_index")? + .into_int_value(); let thread_id = self .build_load(self.int_type, *self.get_param("thread_id"), "thread_id")? .into_int_value(); @@ -2822,6 +2899,7 @@ impl<'ctx> CodeGen<'ctx> { t.into(), u.into(), data.into(), + model_index.into(), thread_id.into(), thread_dim.into(), barrier_start.into(), @@ -2969,7 +3047,15 @@ impl<'ctx> CodeGen<'ctx> { let state_dep_fn = self.ensure_state_dep_fn(model, code)?; let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); - let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "u", + "data", + "out", + "model_index", + "thread_id", + "thread_dim", + ]; let function_name = if include_constants { "calc_out_full" } else { @@ -2985,6 +3071,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -2993,7 +3080,7 @@ impl<'ctx> CodeGen<'ctx> { // add noalias let alias_id = Attribute::get_named_enum_kind_id("noalias"); let noalign = self.context.create_enum_attribute(alias_id, 0); - for i in &[1, 2] { + for i in &[1, 2, 3] { function.add_attribute(AttributeLoc::Param(*i), noalign); } @@ -3079,6 +3166,7 @@ impl<'ctx> CodeGen<'ctx> { "t", "u", "data", + "model_index", "thread_id", "thread_dim", "barrier_start", @@ -3095,6 +3183,7 @@ impl<'ctx> CodeGen<'ctx> { self.int_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3169,7 +3258,15 @@ impl<'ctx> CodeGen<'ctx> { let state_dep_fn = self.ensure_state_dep_fn(model, code)?; let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); - let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "u", + "data", + "root", + "model_index", + "thread_id", + "thread_dim", + ]; let function = self.add_function( "calc_stop", fn_arg_names, @@ -3180,6 +3277,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3254,7 +3352,15 @@ impl<'ctx> CodeGen<'ctx> { let time_dep_fn = self.ensure_time_dep_fn(model, code)?; let state_dep_fn = self.ensure_state_dep_fn(model, code)?; self.clear(); - let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "u", + "data", + "rr", + "model_index", + "thread_id", + "thread_dim", + ]; let function_name = if include_constants { "rhs_full" } else { "rhs" }; let function = self.add_function( function_name, @@ -3266,6 +3372,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3319,8 +3426,8 @@ impl<'ctx> CodeGen<'ctx> { // F let res_ptr = self.get_param("rr"); self.jit_compile_tensor(model.rhs(), Some(*res_ptr), code)?; - let barrier_num = self.int_type.const_int(nbarriers + 1, false); let total_barriers_val = self.int_type.const_int(total_barriers, false); + let barrier_num = self.int_type.const_int(nbarriers + 1, false); self.jit_compile_call_barrier(barrier_num, total_barriers_val); self.builder.build_return(None)?; @@ -3439,7 +3546,8 @@ impl<'ctx> CodeGen<'ctx> { let mut start_param_index: Vec = Vec::new(); let mut ptr_arg_indices: Vec = Vec::new(); for (i, arg) in original_function.get_param_iter().enumerate() { - start_param_index.push(u32::try_from(fn_type.len()).unwrap()); + let start_index = u32::try_from(fn_type.len()).unwrap(); + start_param_index.push(start_index); let arg_type = arg.get_type(); fn_type.push(arg_type.into()); @@ -3448,13 +3556,16 @@ impl<'ctx> CodeGen<'ctx> { enzyme_fn_type.push(arg.get_type().into()); if arg_type.is_pointer_type() { - ptr_arg_indices.push(u32::try_from(i).unwrap()); + ptr_arg_indices.push(start_index); } match args_type[i] { CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => { fn_type.push(arg.get_type().into()); enzyme_fn_type.push(arg.get_type().into()); + if arg_type.is_pointer_type() { + ptr_arg_indices.push(start_index + 1); + } } CompileGradientArgType::Const => {} } diff --git a/diffsl/src/parser/ds_grammar.pest b/diffsl/src/parser/ds_grammar.pest index c6bf354..edf66e5 100644 --- a/diffsl/src/parser/ds_grammar.pest +++ b/diffsl/src/parser/ds_grammar.pest @@ -9,14 +9,20 @@ assignment = { name ~ "=" ~ expression } expression = { term ~ (term_op ~ term)* } term = { factor ~ (factor_op ~ factor)* } factor = { sign? ~ ( call | real | integer | name_ij_index | name_ij | "(" ~ expression ~ ")" ) } +integer_expression = { integer_term ~ (term_op ~ integer_term)* } +integer_term = { integer_factor ~ (integer_factor_op ~ integer_factor)* } +integer_factor = { sign? ~ ( integer | integer_name | "(" ~ integer_expression ~ ")" ) } call = { name ~ "(" ~ call_arg ~ ("," ~ call_arg )* ~ ")" } call_arg = { expression } name_ij = ${ name ~ ("_" ~ name)? } -name_ij_index = ${ name_ij ~ "[" ~ indice ~ "]" } +index_indice = { integer_expression ~ ( range_sep ~ integer_expression )? } +name_ij_index = { name_ij ~ "[" ~ index_indice ~ "]" } range_sep = @{ ".." | ":" } sign = @{ ("-"|"+") } term_op = @{ "-"|"+" } factor_op = @{ "*"|"/" } +integer_factor_op = @{ "*"|"/"|"%" } +integer_name = @{ "N" } name = @{ ( 'a'..'z' | 'A'..'Z' ) ~ ('a'..'z' | 'A'..'Z' | '0'..'9' )* } integer = @{ ('0'..'9')+ } real = @{ ('0'..'9')+ ~ ( "." ~ ('0'..'9')* )? ~ ( "e" ~ sign? ~ integer )? } @@ -26,4 +32,3 @@ COMMENT = _{ "/*" ~ (!"*/" ~ ANY)* ~ "*/" | "//" ~ (!NEWLINE ~ ANY)* } - diff --git a/diffsl/src/parser/ds_parser.rs b/diffsl/src/parser/ds_parser.rs index 8f9c710..6a228ba 100644 --- a/diffsl/src/parser/ds_parser.rs +++ b/diffsl/src/parser/ds_parser.rs @@ -37,8 +37,9 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { pos_end: pair.as_span().end(), }); match pair.as_rule() { - // name = @{ 'a'..'z' ~ ("_" | 'a'..'z' | 'A'..'Z' | '0'..'9')* } - Rule::name => Ast { + // name = @{ ... } + // integer_name = @{ "N" } + Rule::name | Rule::integer_name => Ast { kind: AstKind::Name(ast::Name { name: pair.as_str(), indice: None, @@ -85,12 +86,13 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { } } - //expression = { term ~ (term_op ~ term)* } - Rule::expression => { + // expression = { term ~ (term_op ~ term)* } + // integer_expression = { integer_term ~ (term_op ~ integer_term)* } + Rule::expression | Rule::integer_expression => { let mut inner = pair.into_inner(); let mut head_term = parse_value(inner.next().unwrap()); while inner.peek().is_some() { - //term_op = @{ "-"|"+" } + // term_op = @{ "-"|"+" } let term_op = parse_sign(inner.next().unwrap()); let rhs_term = parse_value(inner.next().unwrap()); let subspan = Some(StringSpan { @@ -109,8 +111,9 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { head_term } - //term = { factor ~ (factor_op ~ factor)* } - Rule::term => { + // term = { factor ~ (factor_op ~ factor)* } + // integer_term = { integer_factor ~ (integer_factor_op ~ integer_factor)* } + Rule::term | Rule::integer_term => { let mut inner = pair.into_inner(); let mut head_factor = parse_value(inner.next().unwrap()); while inner.peek().is_some() { @@ -132,8 +135,9 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { head_factor } - // factor = { sign? ~ (call | name | real | integer | "(" ~ expression ~ ")" ) } - Rule::factor => { + // factor = { sign? ~ (call | real | integer | name_ij_index | name_ij | "(" ~ expression ~ ")" ) } + // integer_factor = { sign? ~ ( integer | name | "(" ~ integer_expression ~ ")" ) } + Rule::factor | Rule::integer_factor => { let mut inner = pair.into_inner(); let sign = if inner.peek().unwrap().as_rule() == Rule::sign { Some(parse_sign(inner.next().unwrap())) @@ -152,7 +156,7 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { } } - // name_ij_index = ${ name_ij ~ "[" ~ indice ~ "]" } + // name_ij_index = ${ name_ij ~ "[" ~ index_indice ~ "]" } Rule::name_ij_index => { let mut inner = pair.into_inner(); let mut name_ij = parse_value(inner.next().unwrap()); @@ -232,7 +236,8 @@ fn parse_value(pair: Pair<'_, Rule>) -> Ast<'_> { } // indice = { integer ~ ( range_sep ~ integer )? } - Rule::indice => { + // index_indice = { integer_expression ~ ( range_sep ~ integer_expression )? } + Rule::indice | Rule::index_indice => { let mut inner = pair.into_inner(); let first = Box::new(parse_value(inner.next().unwrap())); if inner.peek().is_some() { @@ -488,4 +493,64 @@ mod tests { assert_eq!(assignment.name, "y"); assert_eq!(assignment.expr.to_string(), "3"); } + + #[test] + fn index_expression_with_modulo() { + const TEXT: &str = " + amp_i { 0, 10 } + u_i { x = 1, tclock = 0 } + F_i { amp_i[((N + 4 / 2) * 3 - 1) % 2] - x, 1 } + stop_i { 10 - tclock } + "; + let model = parse_string(TEXT).unwrap(); + assert_eq!(model.tensors.len(), 4); + let f_tensor = model.tensors[2].kind.as_tensor().unwrap(); + let f_expr = f_tensor.elmts()[0] + .kind + .as_tensor_elmt() + .unwrap() + .expr + .to_string(); + assert!(f_expr.contains("%")); + } + + #[test] + fn float_not_allowed_in_integer_expression_index() { + const TEXT: &str = " + amp_i { 0, 10 } + u_i { x = 1, tclock = 0 } + F_i { amp_i[N % 1.1] - x, 1 } + stop_i { 10 - tclock } + "; + assert!(parse_string(TEXT).is_err()); + } + + #[test] + fn non_n_name_not_allowed_in_integer_expression_index() { + const TEXT: &str = " + amp_i { 0, 10 } + u_i { x = 1, tclock = 0 } + F_i { amp_i[x % 2] - x, 1 } + stop_i { 10 - tclock } + "; + assert!(parse_string(TEXT).is_err()); + } + + #[test] + fn slice_index_still_parses() { + const TEXT: &str = " + a_i { 0.0, 1.0, 2.0, 3.0 } + r_i { a_i[1:3] } + "; + let model = parse_string(TEXT).unwrap(); + assert_eq!(model.tensors.len(), 2); + let r_tensor = model.tensors[1].kind.as_tensor().unwrap(); + let expr = r_tensor.elmts()[0] + .kind + .as_tensor_elmt() + .unwrap() + .expr + .to_string(); + assert!(expr.contains("[1:3]")); + } } diff --git a/diffsl/tests/pybamm_dfn.rs b/diffsl/tests/pybamm_dfn.rs index 1d57331..1b65bb1 100644 --- a/diffsl/tests/pybamm_dfn.rs +++ b/diffsl/tests/pybamm_dfn.rs @@ -39,12 +39,12 @@ fn test_dfn_model_initialization() { compiler.set_u0(&mut u, &mut data); let mut rr = vec![0.0; n_states]; let t = 0.0; - compiler.rhs(t, &u, &mut data, &mut rr); + compiler.rhs(t, &u, &mut data, &mut rr, 0); let v = vec![1.; n_states]; let mut drr = vec![0.0; n_states]; let mut ddata = compiler.get_new_data(); println!("Computing rhs grad..."); // flush stdout to ensure the print appears before any potential panic std::io::stdout().flush().unwrap(); - compiler.rhs_grad(t, &u, &v, &data, &mut ddata, &rr, &mut drr); + compiler.rhs_grad(t, &u, &v, &data, &mut ddata, &rr, &mut drr, 0); } diff --git a/diffsl/tests/support/external_test_macros.rs b/diffsl/tests/support/external_test_macros.rs index 0807ed9..4b4abbb 100644 --- a/diffsl/tests/support/external_test_macros.rs +++ b/diffsl/tests/support/external_test_macros.rs @@ -33,6 +33,7 @@ macro_rules! define_external_test { u: *const $ty, data: *mut $ty, rr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -53,6 +54,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -75,6 +77,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -95,6 +98,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -114,6 +118,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -197,6 +202,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, out: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -215,6 +221,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -234,6 +241,7 @@ macro_rules! define_external_test { _ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -251,6 +259,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -270,6 +279,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -287,6 +297,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, root: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -401,15 +412,15 @@ macro_rules! define_external_test { assert_eq!(u[0], 1.0 as $ty); let mut out = vec![-3.0 as $ty; n_outputs]; - compiler.calc_out(0.0 as $ty, &u, &mut data, &mut out); + compiler.calc_out(0.0 as $ty, &u, &mut data, &mut out, 0); assert_eq!(out[0], u[0]); let mut rr = vec![-4.0 as $ty; n_states]; - compiler.rhs(0.0 as $ty, &u, &mut data, &mut rr); + compiler.rhs(0.0 as $ty, &u, &mut data, &mut rr, 0); assert_eq!(rr[0], 0.0 as $ty); let mut stop = vec![-5.0 as $ty; n_stop]; - compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop); + compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop, 0); assert_eq!(stop[0], 0.5 as $ty); let mut mv = vec![-6.0 as $ty; n_states]; @@ -423,12 +434,12 @@ macro_rules! define_external_test { let du = vec![1.0 as $ty; n_states]; let mut ddata = vec![-8.0 as $ty; n_data]; let mut drr = vec![-9.0 as $ty; n_states]; - compiler.rhs_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &rr, &mut drr); + compiler.rhs_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &rr, &mut drr, 0); assert_eq!(drr[0], -1.0 as $ty); assert_eq!(ddata[0], 0.0 as $ty); let mut dout = vec![-10.0 as $ty; n_outputs]; - compiler.calc_out_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &out, &mut dout); + compiler.calc_out_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &out, &mut dout, 0); assert_eq!(dout[0], 1.0 as $ty); assert_eq!(ddata[0], 0.0 as $ty); @@ -439,15 +450,7 @@ macro_rules! define_external_test { let mut du_rev = vec![-11.0 as $ty; n_states]; let mut ddata_rev = vec![-12.0 as $ty; n_data]; let mut drr_rev = vec![1.0 as $ty; n_states]; - compiler.rhs_rgrad( - 0.0 as $ty, - &u, - &mut du_rev, - &data, - &mut ddata_rev, - &rr, - &mut drr_rev, - ); + compiler.rhs_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &rr, &mut drr_rev,, 0); assert_eq!(du_rev[0], -12.0 as $ty); assert_eq!(ddata_rev[0], -12.0 as $ty); @@ -457,15 +460,7 @@ macro_rules! define_external_test { assert_eq!(dv[0], -12.0 as $ty); let mut dout_rev = vec![1.0 as $ty; n_outputs]; - compiler.calc_out_rgrad( - 0.0 as $ty, - &u, - &mut du_rev, - &data, - &mut ddata_rev, - &out, - &mut dout_rev, - ); + compiler.calc_out_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &out, &mut dout_rev,, 0); assert_eq!(du_rev[0], -11.0 as $ty); compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev); @@ -473,22 +468,22 @@ macro_rules! define_external_test { let mut ddata_s = vec![-14.0 as $ty; n_data]; let mut drr_s = vec![-15.0 as $ty; n_states]; - compiler.rhs_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &rr, &mut drr_s); + compiler.rhs_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &rr, &mut drr_s, 0); assert_eq!(drr_s[0], 0.0 as $ty); assert_eq!(ddata_s[0], 0.0 as $ty); let mut dout_s = vec![-16.0 as $ty; n_outputs]; - compiler.calc_out_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &out, &mut dout_s); + compiler.calc_out_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &out, &mut dout_s, 0); assert_eq!(dout_s[0], 0.0 as $ty); let mut ddata_sr = vec![-17.0 as $ty; n_data]; let mut drr_sr = vec![-18.0 as $ty; n_states]; - compiler.rhs_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &rr, &mut drr_sr); + compiler.rhs_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &rr, &mut drr_sr, 0); assert_eq!(drr_sr[0], 0.0 as $ty); assert_eq!(ddata_sr[0], 0.0 as $ty); let mut dout_sr = vec![-19.0 as $ty; n_outputs]; - compiler.calc_out_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &out, &mut dout_sr); + compiler.calc_out_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &out, &mut dout_sr, 0); assert_eq!(dout_sr[0], 0.0 as $ty); } }; From f55575e10b09f990f6cfd81dd5f5323689233078 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sun, 8 Mar 2026 01:19:05 +0000 Subject: [PATCH 2/3] feat: N can be used in general expressions --- diffsl/src/ast/mod.rs | 2 +- diffsl/src/discretise/discrete_model.rs | 5 +++++ diffsl/src/discretise/env.rs | 13 +++++++++++++ diffsl/src/execution/compiler.rs | 2 ++ diffsl/src/execution/cranelift/codegen.rs | 14 ++++++++++++++ diffsl/src/execution/data_layout.rs | 6 ++++++ diffsl/src/execution/llvm/codegen.rs | 14 ++++++++++++++ 7 files changed, 55 insertions(+), 1 deletion(-) diff --git a/diffsl/src/ast/mod.rs b/diffsl/src/ast/mod.rs index 667fd7f..e369b4f 100644 --- a/diffsl/src/ast/mod.rs +++ b/diffsl/src/ast/mod.rs @@ -580,7 +580,7 @@ impl<'a> Ast<'a> { } AstKind::CallArg(arg) => Self::new_call_arg(arg.name, arg.expression.tangent()), AstKind::Name(name) => { - if name.name == "t" { + if name.name == "t" || name.name == "N" { Self::new_number(0.0) } else { Ast { diff --git a/diffsl/src/discretise/discrete_model.rs b/diffsl/src/discretise/discrete_model.rs index 5249d70..41c2253 100644 --- a/diffsl/src/discretise/discrete_model.rs +++ b/diffsl/src/discretise/discrete_model.rs @@ -132,6 +132,7 @@ impl<'s> DiscreteModel<'s> { let reserved_names = [ "u0", "t", + "N", "data", "root", "thread_id", @@ -1401,6 +1402,7 @@ mod tests { tensor_fail_tests!( error_scalar: "r {1, 2}" errors ["cannot have more than one element in a scalar",], + error_reserved_name_n: "N { 1 }" errors ["N is a reserved name",], error_cannot_find: "r { k }" errors ["cannot find variable k",], error_different_shape: "a_i { 1, 2 } b_i { 1, 2, 3 } c_i { a_i + b_i }" errors ["cannot broadcast shapes: [2], [3]",], too_many_indices: "A_i { 1, 2 } B_i { (0:2): A_ij }" errors ["too many permutation indices",], @@ -1447,6 +1449,9 @@ mod tests { index: "A_i { 0.0, 1.0, 2.0 } B { A_i[1] }" expect "B" = "B { (): A_i[1] }", index2: "A_i { 0.0, 1.0, 2.0 } B_i { A_i[1:3] }" expect "B" = "B_i(2) { (0)(2):A_i[1:3](2) }", index3: "A_ij { (0:2, 0:2): 1 } g_i { 0, 1, 2 } b_i { A_ij * g_j[0:2] }" expect "b" = "b_i (2) { (0)(2): A_ij * g_j[0:2] (2, 2) }", + index_n_scalar: "A_i { 0.0, 1.0, 2.0 } B { A_i[N % 2] }" expect "B" = "B { (): A_i[N % 2] }", + index_n_range: "A_i { 0.0, 1.0, 2.0, 3.0 } B_i { A_i[N:N+2] }" expect "B" = "B_i(2) { (0)(2):A_i[N:N + 2](2) }", + index_n_range_contraction: "A_ij { (0:2, 0:2): 1 } g_i { 0, 1, 2 } b_i { A_ij * g_j[N:N+2] }" expect "b" = "b_i (2) { (0)(2): A_ij * g_j[N:N + 2] (2, 2) }", prefix_minus: "A { 1.0 / -2.0 }" expect "A" = "A { (): 1 / (-2) }", time: "A_i { t }" expect "A" = "A_i (1) { (0)(1): t }", named_blk: "A_i { (0:3): y = 1, 2 }" expect "A" = "A_i (4) { (0)(3): y = 1, (3)(1): 2 }", diff --git a/diffsl/src/discretise/env.rs b/diffsl/src/discretise/env.rs index f0d8500..d50d67f 100644 --- a/diffsl/src/discretise/env.rs +++ b/diffsl/src/discretise/env.rs @@ -191,6 +191,19 @@ impl Env { is_algebraic: true, }, ); + vars.insert( + "N".to_string(), + EnvVar { + layout: ArcLayout::new(Layout::new_scalar()), + // `N` varies per-model at runtime (from `model_index`), so definitions that + // depend on it must not be treated as compile-time constants. + is_time_dependent: true, + is_state_dependent: false, + is_dstatedt_dependent: false, + is_input_dependent: false, + is_algebraic: true, + }, + ); Env { errs: ValidationErrors::default(), vars, diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index e66a342..f4c880a 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -1867,6 +1867,8 @@ mod tests { abs_function: "r { abs(-2) }" expect "r" vec![f64::abs(-2.0)], min_function: "r { min(2, 3) }" expect "r" vec![2.0], max_function: "r { max(2, 3) }" expect "r" vec![3.0], + n_expression_scalar: "r { 2.0 * N + 1.0 }" expect "r" vec![1.0], + n_expression_vector: "r_i { 2.0 * N + 1.0, 3.0 * N + 2.0 }" expect "r" vec![1.0, 2.0], scalar: "r {2}" expect "r" vec![2.0,], constant: "r_i {2, 3}" expect "r" vec![2., 3.], derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.], diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index fd38244..9faa29d 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -1149,6 +1149,20 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { } AstKind::Number(value) => Ok(self.fconst(*value)), AstKind::Name(iname) => { + if iname.name == "N" { + if iname.is_tangent { + return Ok(self.fconst(0.0)); + } + let var = self + .variables + .get("model_index") + .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?; + let model_index = self.builder.use_var(*var); + return Ok(self + .builder + .ins() + .fcvt_from_sint(self.real_type, model_index)); + } let ptr = if iname.is_tangent { // tangent of a constant is zero if self.layout.is_constant(iname.name) { diff --git a/diffsl/src/execution/data_layout.rs b/diffsl/src/execution/data_layout.rs index 1df40f8..be470c4 100644 --- a/diffsl/src/execution/data_layout.rs +++ b/diffsl/src/execution/data_layout.rs @@ -44,6 +44,12 @@ impl DataLayout { // add layout info for "t" let t_layout = ArcLayout::new(Layout::new_scalar()); layout_map.insert("t".to_string(), t_layout); + is_constant_map.insert("t".to_string(), false); + + // add layout info for model index "N" + let n_layout = ArcLayout::new(Layout::new_scalar()); + layout_map.insert("N".to_string(), n_layout); + is_constant_map.insert("N".to_string(), false); let mut add_tensor = |tensor: &Tensor, in_data: bool, in_constants: bool| { // insert the data (non-zeros) for each tensor diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 566486c..e18cc97 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -2610,6 +2610,20 @@ impl<'ctx> CodeGen<'ctx> { } AstKind::Number(value) => Ok(self.real_type.const_float(*value)), AstKind::Name(iname) => { + if iname.name == "N" { + if iname.is_tangent { + return Ok(self.real_type.const_float(0.0)); + } + let model_index = self + .build_load(self.int_type, *self.get_param("model_index"), "model_index")? + .into_int_value(); + let n_value = self.builder.build_signed_int_to_float( + model_index, + self.real_type, + "n_as_real", + )?; + return Ok(n_value); + } let ptr = *self.get_param(iname.name); let layout = self.layout.get_layout(iname.name).unwrap().clone(); let iname_elmt_index = if layout.is_dense() { From ec8dddcccff74396820afc46bc3cf452e6b379dc Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sun, 8 Mar 2026 15:38:59 +0000 Subject: [PATCH 3/3] refactor: move model_index arg to set_inputs only, this is stored and used in rhs etc --- diffsl/benches/evaluation.rs | 4 +- diffsl/benches/pybamm_dfn.rs | 3 +- diffsl/src/discretise/discrete_model.rs | 6 +- diffsl/src/discretise/env.rs | 26 +++- diffsl/src/execution/compiler.rs | 150 +++++++------------ diffsl/src/execution/cranelift/codegen.rs | 96 +++++++----- diffsl/src/execution/external/mod.rs | 15 +- diffsl/src/execution/interface.rs | 32 ++-- diffsl/src/execution/llvm/codegen.rs | 95 ++++++------ diffsl/tests/pybamm_dfn.rs | 6 +- diffsl/tests/support/external_test_macros.rs | 43 +++--- 11 files changed, 223 insertions(+), 253 deletions(-) diff --git a/diffsl/benches/evaluation.rs b/diffsl/benches/evaluation.rs index e38d848..380282a 100644 --- a/diffsl/benches/evaluation.rs +++ b/diffsl/benches/evaluation.rs @@ -55,14 +55,14 @@ fn execute( let n = N; let compiler = setup::(n, f_text, "execute"); let mut data = compiler.get_new_data(); - compiler.set_inputs(&[], data.as_mut_slice()); + compiler.set_inputs(&[], data.as_mut_slice(), 0); let mut u = vec![1.0; n]; compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); let mut rr = vec![0.0; n]; let t = 0.0; bencher.bench_local(|| { - compiler.rhs(t, &u, &mut data, &mut rr, 0); + compiler.rhs(t, &u, &mut data, &mut rr); }); } diff --git a/diffsl/benches/pybamm_dfn.rs b/diffsl/benches/pybamm_dfn.rs index e7cd228..7dd9af6 100644 --- a/diffsl/benches/pybamm_dfn.rs +++ b/diffsl/benches/pybamm_dfn.rs @@ -35,7 +35,6 @@ fn pybamm_dfn_execute_rhs_grad(bench ddata.as_mut_slice(), rr.as_mut_slice(), drr.as_mut_slice(), - 0, ); }); } @@ -68,7 +67,7 @@ fn pybamm_dfn_execute_rhs(bencher: B let mut data = compiler.get_new_data(); let mut rr = vec![0.0; n_states]; bencher.bench_local(move || { - compiler.rhs(t, y.as_slice(), data.as_mut_slice(), rr.as_mut_slice(), 0); + compiler.rhs(t, y.as_slice(), data.as_mut_slice(), rr.as_mut_slice()); }); } diff --git a/diffsl/src/discretise/discrete_model.rs b/diffsl/src/discretise/discrete_model.rs index 41c2253..a7a95c8 100644 --- a/diffsl/src/discretise/discrete_model.rs +++ b/diffsl/src/discretise/discrete_model.rs @@ -503,7 +503,11 @@ impl<'s> DiscreteModel<'s> { let dependent_on_time = env_entry.is_time_dependent(); let dependent_on_dudt = env_entry.is_dstatedt_dependent(); let dependent_on_input = env_entry.is_input_dependent(); - if !dependent_on_time && !dependent_on_input { + let dependent_on_model = env_entry.is_model_dependent(); + if !dependent_on_time + && !dependent_on_input + && !dependent_on_model + { ret.constant_defns.push(built); } else if !dependent_on_time { ret.input_dep_defns.push(built); diff --git a/diffsl/src/discretise/env.rs b/diffsl/src/discretise/env.rs index d50d67f..f6c08bb 100644 --- a/diffsl/src/discretise/env.rs +++ b/diffsl/src/discretise/env.rs @@ -19,6 +19,7 @@ pub struct EnvVar { is_state_dependent: bool, is_dstatedt_dependent: bool, is_input_dependent: bool, + is_model_dependent: bool, is_algebraic: bool, } @@ -43,6 +44,10 @@ impl EnvVar { self.is_input_dependent } + pub fn is_model_dependent(&self) -> bool { + self.is_model_dependent + } + pub fn layout(&self) -> &Layout { self.layout.as_ref() } @@ -188,6 +193,7 @@ impl Env { is_state_dependent: false, is_dstatedt_dependent: false, is_input_dependent: false, + is_model_dependent: false, is_algebraic: true, }, ); @@ -195,12 +201,11 @@ impl Env { "N".to_string(), EnvVar { layout: ArcLayout::new(Layout::new_scalar()), - // `N` varies per-model at runtime (from `model_index`), so definitions that - // depend on it must not be treated as compile-time constants. - is_time_dependent: true, + is_time_dependent: false, is_state_dependent: false, is_dstatedt_dependent: false, is_input_dependent: false, + is_model_dependent: true, is_algebraic: true, }, ); @@ -243,6 +248,19 @@ impl Env { self.is_tensor_dependent_on(tensor, "in") } + pub fn is_tensor_model_dependent(&self, tensor: &Tensor) -> bool { + if tensor.name() == "N" { + return true; + } + tensor.elmts().iter().any(|block| { + block + .expr() + .get_dependents() + .iter() + .any(|&dep| dep == "N" || self.vars[dep].is_model_dependent()) + }) + } + pub fn is_tensor_dstatedt_dependent(&self, tensor: &Tensor) -> bool { self.is_tensor_dependent_on(tensor, "dudt") } @@ -275,6 +293,7 @@ impl Env { is_state_dependent: self.is_tensor_state_dependent(var), is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var), is_input_dependent: self.is_tensor_input_dependent(var), + is_model_dependent: self.is_tensor_model_dependent(var), }, ); } @@ -289,6 +308,7 @@ impl Env { is_state_dependent: self.is_tensor_state_dependent(var), is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var), is_input_dependent: self.is_tensor_input_dependent(var), + is_model_dependent: self.is_tensor_model_dependent(var), }, ); } diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index f4c880a..bf76dd0 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -410,7 +410,7 @@ impl Compiler { }) } - pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T], model_index: u32) { + pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T]) { if self.number_of_stop == 0 { panic!("Model does not have a stop function"); } @@ -423,14 +423,13 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, stop.as_ptr() as *mut T, - model_index, i, dim, ) }); } - pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T], model_index: u32) { + pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T]) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_data_len(data, "data"); @@ -440,7 +439,6 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, rr.as_ptr() as *mut T, - model_index, i, dim, ) @@ -488,7 +486,6 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], - model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -505,7 +502,6 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, i, dim, ) @@ -522,7 +518,6 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], - model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -543,7 +538,6 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, i, dim, ) @@ -574,16 +568,7 @@ impl Compiler { }); } - pub fn rhs_sgrad( - &self, - t: T, - yy: &[T], - data: &[T], - ddata: &mut [T], - rr: &[T], - drr: &mut [T], - model_index: u32, - ) { + pub fn rhs_sgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -601,23 +586,13 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, i, dim, ) }); } - pub fn rhs_srgrad( - &self, - t: T, - yy: &[T], - data: &[T], - ddata: &mut [T], - rr: &[T], - drr: &mut [T], - model_index: u32, - ) { + pub fn rhs_srgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -635,14 +610,13 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, i, dim, ) }); } - pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T], model_index: u32) { + pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T]) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); self.check_out_len(out, "out"); @@ -652,7 +626,6 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, out.as_ptr() as *mut T, - model_index, i, dim, ) @@ -669,7 +642,6 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -686,7 +658,6 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, i, dim, ) @@ -703,7 +674,6 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -724,7 +694,6 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, i, dim, ) @@ -739,7 +708,6 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -758,7 +726,6 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, i, dim, ) @@ -773,7 +740,6 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -792,7 +758,6 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, i, dim, ) @@ -831,10 +796,10 @@ impl Compiler { ) } - pub fn set_inputs(&self, inputs: &[T], data: &mut [T]) { + pub fn set_inputs(&self, inputs: &[T], data: &mut [T], model_index: u32) { self.check_inputs_len(inputs, "inputs"); self.check_data_len(data, "data"); - unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr()) }; + unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr(), model_index) }; } pub fn get_inputs(&self, inputs: &mut [T], data: &[T]) { @@ -843,7 +808,14 @@ impl Compiler { unsafe { (self.jit_functions.get_inputs)(inputs.as_mut_ptr(), data.as_ptr()) }; } - pub fn set_inputs_grad(&self, inputs: &[T], dinputs: &[T], data: &[T], ddata: &mut [T]) { + pub fn set_inputs_grad( + &self, + inputs: &[T], + dinputs: &[T], + data: &[T], + ddata: &mut [T], + model_index: u32, + ) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); self.check_data_len(data, "data"); @@ -854,11 +826,19 @@ impl Compiler { dinputs.as_ptr(), data.as_ptr(), ddata.as_mut_ptr(), + model_index, ) }; } - pub fn set_inputs_rgrad(&self, inputs: &[T], dinputs: &mut [T], data: &[T], ddata: &mut [T]) { + pub fn set_inputs_rgrad( + &self, + inputs: &[T], + dinputs: &mut [T], + data: &[T], + ddata: &mut [T], + model_index: u32, + ) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); self.check_data_len(data, "data"); @@ -873,6 +853,7 @@ impl Compiler { dinputs.as_mut_ptr(), data.as_ptr(), ddata.as_mut_ptr(), + model_index, ) }; } @@ -982,7 +963,7 @@ mod tests { assert_relative_eq!(a2[0], T::zero()); // set the inputs and u0 let inputs = vec![T::one()]; - compiler.set_inputs(&inputs, data.as_mut_slice()); + compiler.set_inputs(&inputs, data.as_mut_slice(), 0); let mut u0 = vec![T::zero()]; compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); // now a and a2 should be set @@ -1029,7 +1010,6 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), - 0, ); assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice()); } @@ -1077,14 +1057,12 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), - 0, ); compiler.calc_stop( T::zero(), u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), - 0, ); assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap()); assert_eq!(stop.len(), 1); @@ -1134,35 +1112,33 @@ mod tests { let mut stop1 = vec![T::zero(); 1]; let mut data = compiler.get_new_data(); + compiler.set_inputs(&[], data.as_mut_slice(), 0); compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr0.as_mut_slice(), - 0, ); compiler.calc_stop( T::zero(), u0.as_slice(), data.as_mut_slice(), stop0.as_mut_slice(), - 0, ); + compiler.set_inputs(&[], data.as_mut_slice(), 1); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr1.as_mut_slice(), - 1, ); compiler.calc_stop( T::zero(), u0.as_slice(), data.as_mut_slice(), stop1.as_mut_slice(), - 1, ); assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap()); @@ -1208,20 +1184,20 @@ mod tests { let mut ddata0 = compiler.get_new_data(); let mut ddata1 = compiler.get_new_data(); + compiler.set_inputs(&[], data.as_mut_slice(), 0); compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr0.as_mut_slice(), - 0, ); + compiler.set_inputs(&[], data.as_mut_slice(), 1); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr1.as_mut_slice(), - 1, ); assert_relative_eq!(rr0[0], T::from_f64(3.0).unwrap()); @@ -1229,6 +1205,7 @@ mod tests { let dyy0 = vec![T::one(), T::zero()]; let dyy1 = vec![T::zero(), T::one()]; + compiler.set_inputs(&[], data.as_mut_slice(), 0); compiler.rhs_grad( T::zero(), u0.as_slice(), @@ -1237,8 +1214,8 @@ mod tests { ddata0.as_mut_slice(), rr0.as_slice(), drr0.as_mut_slice(), - 0, ); + compiler.set_inputs(&[], data.as_mut_slice(), 1); compiler.rhs_grad( T::zero(), u0.as_slice(), @@ -1247,7 +1224,6 @@ mod tests { ddata1.as_mut_slice(), rr1.as_slice(), drr1.as_mut_slice(), - 1, ); assert_relative_eq!(drr0[0], T::one()); @@ -1281,20 +1257,20 @@ mod tests { let mut rr1 = vec![T::zero(); 2]; let mut data = compiler.get_new_data(); + compiler.set_inputs(&[], data.as_mut_slice(), 0); compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr0.as_mut_slice(), - 0, ); + compiler.set_inputs(&[], data.as_mut_slice(), 1); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr1.as_mut_slice(), - 1, ); assert_relative_eq!(rr0[0], T::from_f64(-1.0).unwrap()); @@ -1333,7 +1309,6 @@ mod tests { u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), - 0, ); assert_relative_eq!(out[0], T::from_f64(2.).unwrap()); u0[0] = T::from_f64(2.).unwrap(); @@ -1342,7 +1317,6 @@ mod tests { u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), - 0, ); assert_relative_eq!(out[0], T::from_f64(4.).unwrap()); let mut stop = vec![T::zero()]; @@ -1351,7 +1325,6 @@ mod tests { u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), - 0, ); assert_relative_eq!(stop[0], T::from_f64(3.5).unwrap()); u0[0] = T::from_f64(0.5).unwrap(); @@ -1360,7 +1333,6 @@ mod tests { u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), - 0, ); assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap()); } @@ -1436,21 +1408,19 @@ mod tests { let mut results = Vec::new(); let inputs = vec![T::one(); n_inputs]; let mut out = vec![T::zero(); n_outputs]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), - 0, ); compiler.calc_out( T::zero(), u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), - 0, ); let (tensor_len, tensor_is_constant) = if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, data.as_slice()) { @@ -1480,6 +1450,7 @@ mod tests { dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), + 0, ); compiler.set_u0_grad( u0.as_mut_slice(), @@ -1495,7 +1466,6 @@ mod tests { ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice(), - 0, ); compiler.calc_out_grad( T::zero(), @@ -1505,7 +1475,6 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), - 0, ); if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, ddata.as_slice()) { results.push(tensor_data.to_vec()); @@ -1534,7 +1503,6 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), - 0, ); compiler.rhs_rgrad( T::zero(), @@ -1544,7 +1512,6 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), - 0, ); compiler.set_u0_rgrad( u0.as_mut_slice(), @@ -1559,7 +1526,7 @@ mod tests { let mut ddata = compiler.get_new_data(); let mut dres = vec![T::zero(); n_states]; let dinputs = vec![T::one(); n_inputs]; - compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0); compiler.rhs_sgrad( T::zero(), u0.as_slice(), @@ -1567,7 +1534,6 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), - 0, ); results.push( compiler @@ -1579,7 +1545,7 @@ mod tests { // forward mode sens (calc_out) let mut ddata = compiler.get_new_data(); let dinputs = vec![T::one(); n_inputs]; - compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0); compiler.calc_out_sgrad( T::zero(), u0.as_slice(), @@ -1587,7 +1553,6 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), - 0, ); results.push( compiler @@ -1611,13 +1576,13 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), - 0, ); compiler.set_inputs_rgrad( inputs.as_slice(), dinputs.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); results.push(dinputs.to_vec()); @@ -1635,13 +1600,13 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), - 0, ); compiler.set_inputs_rgrad( inputs.as_slice(), dinputs.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); results.push(dinputs.to_vec()); } else { @@ -2163,14 +2128,13 @@ mod tests { let mut ddata = compiler.get_new_data(); let (_n_states, n_inputs, _n_outputs, _n_data, _n_stop, _has_mass) = compiler.get_dims(); let inputs = vec![T::from_f64(2.).unwrap(); n_inputs]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), - 0, ); for _i in 0..3 { @@ -2180,6 +2144,7 @@ mod tests { dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), + 0, ); compiler.set_u0_grad( u0.as_mut_slice(), @@ -2195,7 +2160,6 @@ mod tests { ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice(), - 0, ); assert_relative_eq!(dres.as_slice(), vec![T::from_f64(8.).unwrap()].as_slice()); } @@ -2257,21 +2221,9 @@ mod tests { let mut data = compiler.get_new_data(); let (_n_states, _n_inputs, _n_outputs, _n_data, _n_stop, _has_mass) = compiler.get_dims(); compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); - compiler.rhs( - 0.0, - u.as_slice(), - data.as_mut_slice(), - res.as_mut_slice(), - 0, - ); + compiler.rhs(0.0, u.as_slice(), data.as_mut_slice(), res.as_mut_slice()); assert_relative_eq!(res.as_slice(), vec![3.0, 0.0, 0.0, 3.0].as_slice()); - compiler.rhs( - 0.0, - u.as_slice(), - data.as_mut_slice(), - res.as_mut_slice(), - 0, - ); + compiler.rhs(0.0, u.as_slice(), data.as_mut_slice(), res.as_mut_slice()); assert_relative_eq!(res.as_slice(), vec![3.0, 0.0, 0.0, 3.0].as_slice()); } @@ -2321,7 +2273,7 @@ mod tests { let mut data = compiler.get_new_data(); let inputs = vec![1.1]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.1].as_slice()); @@ -2335,7 +2287,7 @@ mod tests { assert_relative_eq!(u.as_slice(), vec![1., 2.].as_slice()); let mut rr = vec![1., 1.]; - compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice(), 0); + compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice()); assert_relative_eq!(rr.as_slice(), vec![0., 0.].as_slice()); let up = vec![2., 3.]; @@ -2344,7 +2296,7 @@ mod tests { assert_relative_eq!(rr.as_slice(), vec![2., 0.].as_slice()); let mut out = vec![0.; 3]; - compiler.calc_out(0., u.as_slice(), data.as_mut_slice(), out.as_mut_slice(), 0); + compiler.calc_out(0., u.as_slice(), data.as_mut_slice(), out.as_mut_slice()); assert_relative_eq!(out.as_slice(), vec![1., 2., 4.].as_slice()); } } @@ -2372,7 +2324,7 @@ mod tests { .unwrap(); let mut data = compiler.get_new_data(); let inputs = vec![1.0, 2.0, 3.0, 4.0]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.0, 2.0, 3.0, 4.0].as_slice()); @@ -2388,7 +2340,7 @@ mod tests { .unwrap(); let mut data = compiler.get_new_data(); let inputs = vec![1.0, 2.0, 3.0, 4.0]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.0, 2.0, 3.0, 4.0].as_slice()); @@ -2463,7 +2415,6 @@ mod tests { u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), - 0, ); assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice()); }); @@ -2497,12 +2448,13 @@ mod tests { let mut ddata = compiler.get_new_data(); let a = vec![T::from_f64(0.6).unwrap()]; let da = vec![T::one()]; - compiler.set_inputs(a.as_slice(), data.as_mut_slice()); + compiler.set_inputs(a.as_slice(), data.as_mut_slice(), 0); compiler.set_inputs_grad( a.as_slice(), da.as_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); let mut u0 = vec![T::zero()]; let mut du0 = vec![T::zero()]; diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index 9faa29d..b0d3f02 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -35,6 +35,7 @@ pub struct CraneliftModule { indices_id: DataId, constants_id: DataId, + model_index_id: DataId, thread_counter: Option, //triple: Triple, @@ -214,7 +215,6 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, - self.int_type, ]; let arg_names = &[ "t", @@ -224,7 +224,6 @@ impl CraneliftModule { "ddata", "out", "dout", - "model_index", "threadId", "threadDim", ]; @@ -252,7 +251,6 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, - self.int_type, ]; let arg_names = &[ "t", @@ -262,7 +260,6 @@ impl CraneliftModule { "ddata", "rr", "drr", - "model_index", "threadId", "threadDim", ]; @@ -304,11 +301,24 @@ impl CraneliftModule { self.real_ptr_type, self.real_ptr_type, self.real_ptr_type, + self.int_type, ]; - let arg_names = &["inputs", "dinputs", "data", "ddata"]; + let arg_names = &["inputs", "dinputs", "data", "ddata", "model_index"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + let model_index_ptr = codegen + .builder + .ins() + .global_value(codegen.int_ptr_type, codegen.model_index_global); + let model_index = codegen + .builder + .use_var(*codegen.variables.get("model_index").unwrap()); + codegen + .builder + .ins() + .store(codegen.mem_flags, model_index, model_index_ptr, 0); + let base_data_ptr = codegen.variables.get("ddata").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); codegen.jit_compile_inputs(model, base_data_ptr, true, false); @@ -430,6 +440,11 @@ impl CraneliftModule { let indices_id = module.declare_data("indices", Linkage::Local, false, false)?; module.define_data(indices_id, &data_description)?; + let mut data_description = DataDescription::new(); + data_description.define_zeroinit(int_type.bytes().try_into().unwrap()); + let model_index_id = module.declare_data("model_index", Linkage::Local, true, false)?; + module.define_data(model_index_id, &data_description)?; + let mut thread_counter = None; if threaded { let mut data_description = DataDescription::new(); @@ -446,6 +461,7 @@ impl CraneliftModule { module: Mutex::new(module), indices_id, constants_id, + model_index_id, int_type, real_type: real_type_cranelift, real_ptr_type: ptr_type, @@ -529,17 +545,8 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, - self.int_type, - ]; - let arg_names = &[ - "t", - "u", - "data", - "out", - "model_index", - "threadId", - "threadDim", ]; + let arg_names = &["t", "u", "data", "out", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -581,17 +588,8 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, - self.int_type, - ]; - let arg_names = &[ - "t", - "u", - "data", - "root", - "model_index", - "threadId", - "threadDim", ]; + let arg_names = &["t", "u", "data", "root", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -633,17 +631,8 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, - self.int_type, - ]; - let arg_names = &[ - "t", - "u", - "data", - "rr", - "model_index", - "threadId", - "threadDim", ]; + let arg_names = &["t", "u", "data", "rr", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -814,11 +803,23 @@ impl CraneliftModule { } fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result { - let arg_types = &[self.real_ptr_type, self.real_ptr_type]; - let arg_names = &["inputs", "data"]; + let arg_types = &[self.real_ptr_type, self.real_ptr_type, self.int_type]; + let arg_names = &["inputs", "data", "model_index"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + let model_index_ptr = codegen + .builder + .ins() + .global_value(codegen.int_ptr_type, codegen.model_index_global); + let model_index = codegen + .builder + .use_var(*codegen.variables.get("model_index").unwrap()); + codegen + .builder + .ins() + .store(codegen.mem_flags, model_index, model_index_ptr, 0); + let base_data_ptr = codegen.variables.get("data").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); codegen.jit_compile_inputs(model, base_data_ptr, false, false); @@ -1055,6 +1056,7 @@ struct CraneliftCodeGen<'a, M: Module> { layout: &'a DataLayout, indices: GlobalValue, constants: GlobalValue, + model_index_global: GlobalValue, threaded: bool, } @@ -2354,6 +2356,12 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { .unwrap() .declare_data_in_func(module.constants_id, builder.func); + let model_index_global = module + .module + .lock() + .unwrap() + .declare_data_in_func(module.model_index_id, builder.func); + // Create the entry block, to start emitting code in. let entry_block = builder.create_block(); @@ -2386,6 +2394,7 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { functions: HashMap::new(), layout: &module.layout, threaded: module.threaded, + model_index_global, }; // insert arg vars @@ -2394,6 +2403,19 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { codegen.declare_variable(*arg_type, arg_name, val); } + if !codegen.variables.contains_key("model_index") { + let model_index_ptr = codegen + .builder + .ins() + .global_value(codegen.int_ptr_type, codegen.model_index_global); + let model_index = + codegen + .builder + .ins() + .load(codegen.int_type, codegen.mem_flags, model_index_ptr, 0); + codegen.declare_variable(codegen.int_type, "model_index", model_index); + } + // insert u if it exists in args if let Some(u) = codegen.variables.get("u") { let u_ptr = codegen.builder.use_var(*u); diff --git a/diffsl/src/execution/external/mod.rs b/diffsl/src/execution/external/mod.rs index 18dae4b..9a2ee0c 100644 --- a/diffsl/src/execution/external/mod.rs +++ b/diffsl/src/execution/external/mod.rs @@ -30,7 +30,6 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, rr: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -43,7 +42,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -56,7 +54,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -68,7 +65,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -80,7 +76,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -138,7 +133,6 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, out: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -151,7 +145,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -164,7 +157,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -176,7 +168,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -188,7 +179,6 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -198,7 +188,6 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, root: *mut $ty, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -214,7 +203,7 @@ macro_rules! define_symbol_module { has_mass: *mut UIntType, ); #[link_name = "set_inputs"] - pub fn set_inputs(inputs: *const $ty, data: *mut $ty); + pub fn set_inputs(inputs: *const $ty, data: *mut $ty, model_index: UIntType); #[link_name = "get_inputs"] pub fn get_inputs(inputs: *mut $ty, data: *const $ty); #[link_name = "set_inputs_grad"] @@ -223,6 +212,7 @@ macro_rules! define_symbol_module { dinputs: *const $ty, data: *const $ty, ddata: *mut $ty, + model_index: UIntType, ); #[link_name = "set_inputs_rgrad"] pub fn set_inputs_rgrad( @@ -230,6 +220,7 @@ macro_rules! define_symbol_module { dinputs: *mut $ty, data: *const $ty, ddata: *mut $ty, + model_index: UIntType, ); } } diff --git a/diffsl/src/execution/interface.rs b/diffsl/src/execution/interface.rs index 5a3c903..9ae9eeb 100644 --- a/diffsl/src/execution/interface.rs +++ b/diffsl/src/execution/interface.rs @@ -12,7 +12,6 @@ pub type StopFunc = unsafe extern "C" fn( u: *const T, data: *mut T, root: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -21,7 +20,6 @@ pub type RhsFunc = unsafe extern "C" fn( u: *const T, data: *mut T, rr: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -33,7 +31,6 @@ pub type RhsGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -45,7 +42,6 @@ pub type RhsRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -56,7 +52,6 @@ pub type RhsSensGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -67,7 +62,6 @@ pub type RhsSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -121,7 +115,6 @@ pub type CalcOutFunc = unsafe extern "C" fn( u: *const T, data: *mut T, out: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -133,7 +126,6 @@ pub type CalcOutGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -145,7 +137,6 @@ pub type CalcOutRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -156,7 +147,6 @@ pub type CalcOutSensGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -167,7 +157,6 @@ pub type CalcOutSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -179,12 +168,23 @@ pub type GetDimsFunc = unsafe extern "C" fn( stop: *mut UIntType, has_mass: *mut UIntType, ); -pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const T, data: *mut T); +pub type SetInputsFunc = + unsafe extern "C" fn(inputs: *const T, data: *mut T, model_index: UIntType); pub type GetInputsFunc = unsafe extern "C" fn(inputs: *mut T, data: *const T); -pub type SetInputsGradFunc = - unsafe extern "C" fn(inputs: *const T, dinputs: *const T, data: *const T, ddata: *mut T); -pub type SetInputsRevGradFunc = - unsafe extern "C" fn(inputs: *const T, dinputs: *mut T, data: *const T, ddata: *mut T); +pub type SetInputsGradFunc = unsafe extern "C" fn( + inputs: *const T, + dinputs: *const T, + data: *const T, + ddata: *mut T, + model_index: UIntType, +); +pub type SetInputsRevGradFunc = unsafe extern "C" fn( + inputs: *const T, + dinputs: *mut T, + data: *const T, + ddata: *mut T, + model_index: UIntType, +); pub type SetIdFunc = unsafe extern "C" fn(id: *mut T); pub type GetTensorFunc = unsafe extern "C" fn(data: *const T, tensor_data: *mut *mut T, tensor_size: *mut UIntType); diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index e18cc97..650f197 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -341,7 +341,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::Forward, "rhs_grad", @@ -356,7 +355,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::Forward, "calc_out_grad", @@ -366,6 +364,7 @@ impl CodegenModuleCompile for LlvmModule { &[ CompileGradientArgType::DupNoNeed, CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, ], CompileMode::Forward, "set_inputs_grad", @@ -406,7 +405,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::Reverse, "rhs_rgrad", @@ -420,7 +418,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::Reverse, "calc_out_rgrad", @@ -431,6 +428,7 @@ impl CodegenModuleCompile for LlvmModule { &[ CompileGradientArgType::DupNoNeed, CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, ], CompileMode::Reverse, "set_inputs_rgrad", @@ -445,7 +443,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::ForwardSens, "rhs_sgrad", @@ -472,7 +469,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::ForwardSens, "calc_out_sgrad", @@ -486,7 +482,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::ReverseSens, "calc_out_srgrad", @@ -501,7 +496,6 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, - CompileGradientArgType::Const, ], CompileMode::ReverseSens, "rhs_srgrad", @@ -555,6 +549,7 @@ struct Globals<'ctx> { indices: Option>, constants: Option>, thread_counter: Option>, + model_index: GlobalValue<'ctx>, } impl<'ctx> Globals<'ctx> { @@ -620,10 +615,19 @@ impl<'ctx> Globals<'ctx> { indices.set_initializer(&indices_value); Some(indices) }; + let model_index = module.add_global( + int_type, + Some(AddressSpace::default()), + "enzyme_const_model_index", + ); + model_index.set_visibility(GlobalVisibility::Hidden); + model_index.set_constant(false); + model_index.set_initializer(&int_type.const_zero()); Self { indices, thread_counter, constants, + model_index, } } } @@ -1311,6 +1315,7 @@ impl<'ctx> CodeGen<'ctx> { } fn insert_data(&mut self, model: &DiscreteModel) { + self.insert_model_index(); self.insert_constants(model); if let Some(input) = model.input() { @@ -1351,6 +1356,10 @@ impl<'ctx> CodeGen<'ctx> { } } + fn insert_model_index(&mut self) { + self.insert_param("model_index", self.globals.model_index.as_pointer_value()); + } + fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) { self.variables.insert(name.to_owned(), value); } @@ -2896,9 +2905,6 @@ impl<'ctx> CodeGen<'ctx> { .into_float_value(); let u = *self.get_param("u"); let data = *self.get_param("data"); - let model_index = self - .build_load(self.int_type, *self.get_param("model_index"), "model_index")? - .into_int_value(); let thread_id = self .build_load(self.int_type, *self.get_param("thread_id"), "thread_id")? .into_int_value(); @@ -2913,7 +2919,6 @@ impl<'ctx> CodeGen<'ctx> { t.into(), u.into(), data.into(), - model_index.into(), thread_id.into(), thread_dim.into(), barrier_start.into(), @@ -3061,15 +3066,7 @@ impl<'ctx> CodeGen<'ctx> { let state_dep_fn = self.ensure_state_dep_fn(model, code)?; let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); - let fn_arg_names = &[ - "t", - "u", - "data", - "out", - "model_index", - "thread_id", - "thread_dim", - ]; + let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"]; let function_name = if include_constants { "calc_out_full" } else { @@ -3085,7 +3082,6 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), - self.int_type.into(), ], None, false, @@ -3180,7 +3176,6 @@ impl<'ctx> CodeGen<'ctx> { "t", "u", "data", - "model_index", "thread_id", "thread_dim", "barrier_start", @@ -3197,7 +3192,6 @@ impl<'ctx> CodeGen<'ctx> { self.int_type.into(), self.int_type.into(), self.int_type.into(), - self.int_type.into(), ], None, false, @@ -3272,15 +3266,7 @@ impl<'ctx> CodeGen<'ctx> { let state_dep_fn = self.ensure_state_dep_fn(model, code)?; let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); - let fn_arg_names = &[ - "t", - "u", - "data", - "root", - "model_index", - "thread_id", - "thread_dim", - ]; + let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"]; let function = self.add_function( "calc_stop", fn_arg_names, @@ -3291,7 +3277,6 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), - self.int_type.into(), ], None, false, @@ -3366,15 +3351,7 @@ impl<'ctx> CodeGen<'ctx> { let time_dep_fn = self.ensure_time_dep_fn(model, code)?; let state_dep_fn = self.ensure_state_dep_fn(model, code)?; self.clear(); - let fn_arg_names = &[ - "t", - "u", - "data", - "rr", - "model_index", - "thread_id", - "thread_dim", - ]; + let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"]; let function_name = if include_constants { "rhs_full" } else { "rhs" }; let function = self.add_function( function_name, @@ -3386,7 +3363,6 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), - self.int_type.into(), ], None, false, @@ -3993,14 +3969,21 @@ impl<'ctx> CodeGen<'ctx> { ) -> Result> { self.clear(); let function_name = if is_get { "get_inputs" } else { "set_inputs" }; - let fn_arg_names = &["inputs", "data"]; - let function = self.add_function( - function_name, - fn_arg_names, - &[self.real_ptr_type.into(), self.real_ptr_type.into()], - None, - false, - ); + let fn_arg_names: &[&str] = if is_get { + &["inputs", "data"] + } else { + &["inputs", "data", "model_index"] + }; + let fn_arg_types: &[BasicMetadataTypeEnum<'ctx>] = if is_get { + &[self.real_ptr_type.into(), self.real_ptr_type.into()] + } else { + &[ + self.real_ptr_type.into(), + self.real_ptr_type.into(), + self.int_type.into(), + ] + }; + let function = self.add_function(function_name, fn_arg_names, fn_arg_types, None, false); let block = self.start_function(function, None); for (i, arg) in function.get_param_iter().enumerate() { @@ -4009,6 +3992,14 @@ impl<'ctx> CodeGen<'ctx> { self.insert_param(name, alloca); } + if !is_get { + let model_index = self + .build_load(self.int_type, *self.get_param("model_index"), "model_index")? + .into_int_value(); + self.builder + .build_store(self.globals.model_index.as_pointer_value(), model_index)?; + } + if let Some(input) = model.input() { let name = input.name(); self.insert_tensor(input, false); diff --git a/diffsl/tests/pybamm_dfn.rs b/diffsl/tests/pybamm_dfn.rs index 1b65bb1..94e62a3 100644 --- a/diffsl/tests/pybamm_dfn.rs +++ b/diffsl/tests/pybamm_dfn.rs @@ -34,17 +34,17 @@ fn test_dfn_model_initialization() { let mut data = compiler.get_new_data(); let (n_states, n_inputs, _, _, _, _) = compiler.get_dims(); let inputs = vec![1.0; n_inputs]; - compiler.set_inputs(&inputs, &mut data); + compiler.set_inputs(&inputs, &mut data, 0); let mut u = vec![0.0; n_states]; compiler.set_u0(&mut u, &mut data); let mut rr = vec![0.0; n_states]; let t = 0.0; - compiler.rhs(t, &u, &mut data, &mut rr, 0); + compiler.rhs(t, &u, &mut data, &mut rr); let v = vec![1.; n_states]; let mut drr = vec![0.0; n_states]; let mut ddata = compiler.get_new_data(); println!("Computing rhs grad..."); // flush stdout to ensure the print appears before any potential panic std::io::stdout().flush().unwrap(); - compiler.rhs_grad(t, &u, &v, &data, &mut ddata, &rr, &mut drr, 0); + compiler.rhs_grad(t, &u, &v, &data, &mut ddata, &rr, &mut drr); } diff --git a/diffsl/tests/support/external_test_macros.rs b/diffsl/tests/support/external_test_macros.rs index 4b4abbb..0054f02 100644 --- a/diffsl/tests/support/external_test_macros.rs +++ b/diffsl/tests/support/external_test_macros.rs @@ -33,7 +33,6 @@ macro_rules! define_external_test { u: *const $ty, data: *mut $ty, rr: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -54,7 +53,6 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -77,7 +75,6 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -98,7 +95,6 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -118,7 +114,6 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -202,7 +197,6 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, out: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -221,7 +215,6 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -241,7 +234,6 @@ macro_rules! define_external_test { _ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -259,7 +251,6 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -279,7 +270,6 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -297,7 +287,6 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, root: *mut $ty, - _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -344,7 +333,7 @@ macro_rules! define_external_test { } #[no_mangle] - pub unsafe extern "C" fn set_inputs(inputs: *const $ty, data: *mut $ty) { + pub unsafe extern "C" fn set_inputs(inputs: *const $ty, data: *mut $ty, _model_index: u32) { if inputs.is_null() || data.is_null() { return; } @@ -365,6 +354,7 @@ macro_rules! define_external_test { dinputs: *const $ty, _data: *const $ty, ddata: *mut $ty, + _model_index: u32, ) { if dinputs.is_null() || ddata.is_null() { return; @@ -378,6 +368,7 @@ macro_rules! define_external_test { dinputs: *mut $ty, _data: *const $ty, ddata: *mut $ty, + _model_index: u32, ) { if dinputs.is_null() || ddata.is_null() { return; @@ -401,7 +392,7 @@ macro_rules! define_external_test { let mut data = vec![-1.0 as $ty; n_data]; let inputs = vec![1.0 as $ty; n_inputs]; - compiler.set_inputs(&inputs, &mut data); + compiler.set_inputs(&inputs, &mut data, 0); let mut inputs_out = vec![-2.0 as $ty; n_inputs]; compiler.get_inputs(&mut inputs_out, &data); @@ -412,15 +403,15 @@ macro_rules! define_external_test { assert_eq!(u[0], 1.0 as $ty); let mut out = vec![-3.0 as $ty; n_outputs]; - compiler.calc_out(0.0 as $ty, &u, &mut data, &mut out, 0); + compiler.calc_out(0.0 as $ty, &u, &mut data, &mut out); assert_eq!(out[0], u[0]); let mut rr = vec![-4.0 as $ty; n_states]; - compiler.rhs(0.0 as $ty, &u, &mut data, &mut rr, 0); + compiler.rhs(0.0 as $ty, &u, &mut data, &mut rr); assert_eq!(rr[0], 0.0 as $ty); let mut stop = vec![-5.0 as $ty; n_stop]; - compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop, 0); + compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop); assert_eq!(stop[0], 0.5 as $ty); let mut mv = vec![-6.0 as $ty; n_states]; @@ -434,23 +425,23 @@ macro_rules! define_external_test { let du = vec![1.0 as $ty; n_states]; let mut ddata = vec![-8.0 as $ty; n_data]; let mut drr = vec![-9.0 as $ty; n_states]; - compiler.rhs_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &rr, &mut drr, 0); + compiler.rhs_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &rr, &mut drr); assert_eq!(drr[0], -1.0 as $ty); assert_eq!(ddata[0], 0.0 as $ty); let mut dout = vec![-10.0 as $ty; n_outputs]; - compiler.calc_out_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &out, &mut dout, 0); + compiler.calc_out_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &out, &mut dout); assert_eq!(dout[0], 1.0 as $ty); assert_eq!(ddata[0], 0.0 as $ty); let mut dinputs = vec![1.0 as $ty; n_inputs]; - compiler.set_inputs_grad(&inputs, &dinputs, &data, &mut ddata); + compiler.set_inputs_grad(&inputs, &dinputs, &data, &mut ddata, 0); assert_eq!(ddata[0], 1.0 as $ty); let mut du_rev = vec![-11.0 as $ty; n_states]; let mut ddata_rev = vec![-12.0 as $ty; n_data]; let mut drr_rev = vec![1.0 as $ty; n_states]; - compiler.rhs_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &rr, &mut drr_rev,, 0); + compiler.rhs_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &rr, &mut drr_rev); assert_eq!(du_rev[0], -12.0 as $ty); assert_eq!(ddata_rev[0], -12.0 as $ty); @@ -460,30 +451,30 @@ macro_rules! define_external_test { assert_eq!(dv[0], -12.0 as $ty); let mut dout_rev = vec![1.0 as $ty; n_outputs]; - compiler.calc_out_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &out, &mut dout_rev,, 0); + compiler.calc_out_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &out, &mut dout_rev); assert_eq!(du_rev[0], -11.0 as $ty); - compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev); + compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev, 0); assert_eq!(dinputs[0], -11.0 as $ty); let mut ddata_s = vec![-14.0 as $ty; n_data]; let mut drr_s = vec![-15.0 as $ty; n_states]; - compiler.rhs_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &rr, &mut drr_s, 0); + compiler.rhs_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &rr, &mut drr_s); assert_eq!(drr_s[0], 0.0 as $ty); assert_eq!(ddata_s[0], 0.0 as $ty); let mut dout_s = vec![-16.0 as $ty; n_outputs]; - compiler.calc_out_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &out, &mut dout_s, 0); + compiler.calc_out_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &out, &mut dout_s); assert_eq!(dout_s[0], 0.0 as $ty); let mut ddata_sr = vec![-17.0 as $ty; n_data]; let mut drr_sr = vec![-18.0 as $ty; n_states]; - compiler.rhs_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &rr, &mut drr_sr, 0); + compiler.rhs_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &rr, &mut drr_sr); assert_eq!(drr_sr[0], 0.0 as $ty); assert_eq!(ddata_sr[0], 0.0 as $ty); let mut dout_sr = vec![-19.0 as $ty; n_outputs]; - compiler.calc_out_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &out, &mut dout_sr, 0); + compiler.calc_out_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &out, &mut dout_sr); assert_eq!(dout_sr[0], 0.0 as $ty); } };