diff --git a/diffsl/src/discretise/discrete_model.rs b/diffsl/src/discretise/discrete_model.rs index a7a95c8..7dc9d3c 100644 --- a/diffsl/src/discretise/discrete_model.rs +++ b/diffsl/src/discretise/discrete_model.rs @@ -43,6 +43,7 @@ pub struct DiscreteModel<'s> { state: Tensor<'s>, state_dot: Option>, is_algebraic: Vec, + reset: Option>, stop: Option>, state0_input_deps: Vec<(usize, usize)>, dstate0_input_deps: Vec<(usize, usize)>, @@ -82,6 +83,9 @@ impl fmt::Display for DiscreteModel<'_> { for defn in &self.state_dep_post_f_defns { writeln!(f, "{defn}")?; } + if let Some(reset) = &self.reset { + writeln!(f, "{reset}")?; + } if let Some(stop) = &self.stop { writeln!(f, "{stop}")?; } @@ -111,6 +115,7 @@ impl<'s> DiscreteModel<'s> { state: Tensor::new_empty("u"), state_dot: None, is_algebraic: Vec::new(), + reset: None, stop: None, state0_input_deps: Vec::new(), dstate0_input_deps: Vec::new(), @@ -317,6 +322,7 @@ impl<'s> DiscreteModel<'s> { let mut read_state = false; let mut span_f = None; let mut span_m = None; + let mut span_reset = None; let mut seen_f = false; for tensor_ast in model.tensors.iter() { env.set_current_span(tensor_ast.span); @@ -424,6 +430,29 @@ impl<'s> DiscreteModel<'s> { } } } + "reset" => { + if let Some(built) = + Self::build_array(tensor, &mut env, TensorType::Other) + { + if !built.is_dense() { + env.errs_mut().push(ValidationError::new( + "reset must have a dense layout".to_string(), + span, + )); + } + span_reset = Some(span); + ret.reset = Some(built); + } + + if let Some(reset) = env.get("reset") { + if reset.is_dstatedt_dependent() { + env.errs_mut().push(ValidationError::new( + "reset must not be dependent on dudt".to_string(), + tensor_ast.span, + )); + } + } + } "stop" => { if let Some(built) = Self::build_array(tensor, &mut env, TensorType::Other) @@ -581,6 +610,11 @@ impl<'s> DiscreteModel<'s> { if let Some(span) = span_m { Self::check_match(ret.lhs.as_ref().unwrap(), &ret.state, span, &mut env); } + if let Some(span) = span_reset { + if let Some(reset) = ret.reset.as_ref() { + Self::check_match(reset, &ret.state, span, &mut env); + } + } let map_dep = |deps: &Vec| -> Vec<(usize, usize)> { deps.iter() @@ -820,6 +854,7 @@ impl<'s> DiscreteModel<'s> { let lhs = Tensor::new_no_layout("M", m_elmts, vec!['i']); let rhs = Tensor::new_no_layout("F", f_elmts, vec!['i']); let name = model.name; + let reset = None; let stop = None; let dstate_dep_defns = Vec::new(); let state_dep_post_f_defns = Vec::new(); @@ -838,6 +873,7 @@ impl<'s> DiscreteModel<'s> { state_dep_post_f_defns, dstate_dep_defns, is_algebraic, + reset, stop, state0_input_deps: Vec::new(), dstate0_input_deps: Vec::new(), @@ -908,6 +944,10 @@ impl<'s> DiscreteModel<'s> { self.stop.as_ref() } + pub fn reset(&self) -> Option<&Tensor<'_>> { + self.reset.as_ref() + } + pub fn take_state0_input_deps(&mut self) -> Vec<(usize, usize)> { std::mem::take(&mut self.state0_input_deps) } @@ -1635,6 +1675,53 @@ mod tests { assert_eq!(model.stop().unwrap().elmts().len(), 1); } + #[test] + fn test_reset_requires_same_shape_as_state() { + let text = " + u_i { + y = 1, + z = 2, + } + F_i { + y, + z, + } + reset_i { + 2 * y, + } + "; + let model_ds = parse_ds_string(text).unwrap(); + let model = DiscreteModel::build("$name", &model_ds); + assert!( + model.is_err(), + "reset_i should be required to match the shape of u_i and F_i" + ); + } + + #[test] + fn test_reset_must_not_depend_on_dudt() { + let text = " + u_i { + y = 1, + } + dudt_i { + dydt = 0, + } + F_i { + y, + } + reset_i { + dydt, + } + "; + let model_ds = parse_ds_string(text).unwrap(); + let model = DiscreteModel::build("$name", &model_ds); + assert!( + model.is_err(), + "reset_i should be state-dependent and must not depend on dudt_i" + ); + } + #[test] fn test_no_out() { let text = " diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index bf76dd0..e56d337 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -429,6 +429,26 @@ impl Compiler { }); } + pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T]) { + if reset.is_empty() { + return; + } + + self.check_state_len(yy, "yy"); + self.check_state_len(reset, "reset"); + self.check_data_len(data, "data"); + self.with_threading(|i, dim| unsafe { + (self.jit_functions.reset)( + t, + yy.as_ptr(), + data.as_ptr() as *mut T, + reset.as_ptr() as *mut T, + i, + dim, + ) + }); + } + pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T]) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); @@ -1016,6 +1036,8 @@ mod tests { } generate_tests!(test_stop); + generate_tests!(test_reset); + generate_tests!(test_reset_without_reset_tensor_is_noop); #[allow(dead_code)] fn test_stop() { @@ -1068,6 +1090,103 @@ mod tests { assert_eq!(stop.len(), 1); } + #[allow(dead_code)] + fn test_reset() { + let full_text = " + u_i { + y = 1, + z = 2, + } + dudt_i { + dydt = 0, + dzdt = 0, + } + M_i { + dydt, + dzdt, + } + F_i { + y, + z, + } + reset_i { + 2 * y, + z + 10, + } + stop_i { + y - 0.5, + } + out_i { + y, + z, + } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero(), T::zero()]; + let mut reset = vec![T::zero(), T::zero()]; + let mut data = compiler.get_new_data(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.reset( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + reset.as_mut_slice(), + ); + + assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap()); + assert_relative_eq!(u0[1], T::from_f64(2.0).unwrap()); + assert_relative_eq!(reset[0], T::from_f64(2.0).unwrap()); + assert_relative_eq!(reset[1], T::from_f64(12.0).unwrap()); + assert_eq!(reset.len(), 2); + } + + #[allow(dead_code)] + fn test_reset_without_reset_tensor_is_noop< + M: CodegenModuleCompile + CodegenModuleJit, + T: Scalar + RelativeEq, + >() { + let full_text = " + u_i { + y = 1, + } + F_i { + y, + } + out_i { + y, + } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![T::zero()]; + let mut reset: Vec = vec![]; + let mut data = compiler.get_new_data(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + + compiler.reset( + T::zero(), + u0.as_slice(), + data.as_mut_slice(), + reset.as_mut_slice(), + ); + assert_eq!(reset.len(), 0); + } + generate_tests!(test_out_depends_on_internal_tensor); generate_tests!(test_model_index_n_depends_on_model_index); diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index b0d3f02..ef0ac6a 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -477,6 +477,7 @@ impl CraneliftModule { let set_u0 = ret.compile_set_u0(model)?; let _calc_stop = ret.compile_calc_stop(model)?; + let _reset = ret.compile_reset(model)?; let rhs = ret.compile_rhs(model)?; let _mass = ret.compile_mass(model)?; let calc_out = ret.compile_calc_out(model)?; @@ -623,6 +624,47 @@ impl CraneliftModule { self.declare_function("calc_stop") } + fn compile_reset(&mut self, model: &DiscreteModel) -> Result { + let arg_types = &[ + self.real_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.int_type, + self.int_type, + ]; + let arg_names = &["t", "u", "data", "reset", "threadId", "threadDim"]; + { + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + if let Some(reset) = model.reset() { + let mut nbarrier = 0; + for tensor in model.time_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + + for tensor in model.state_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.state_dep_post_f_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + + let reset_ptr = *codegen.variables.get("reset").unwrap(); + codegen.jit_compile_tensor(reset, Some(reset_ptr), false)?; + } + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + } + self.declare_function("reset") + } + fn compile_rhs(&mut self, model: &DiscreteModel) -> Result { let arg_types = &[ self.real_type, diff --git a/diffsl/src/execution/external/mod.rs b/diffsl/src/execution/external/mod.rs index 9a2ee0c..62b3041 100644 --- a/diffsl/src/execution/external/mod.rs +++ b/diffsl/src/execution/external/mod.rs @@ -24,6 +24,15 @@ macro_rules! define_symbol_module { thread_id: UIntType, thread_dim: UIntType, ); + #[link_name = "reset"] + pub fn reset( + time: $ty, + u: *const $ty, + data: *mut $ty, + reset: *mut $ty, + thread_id: UIntType, + thread_dim: UIntType, + ); #[link_name = "rhs"] pub fn rhs( time: $ty, @@ -283,6 +292,7 @@ impl_extern_symbols!(f64, f64_symbols, { "barrier_init" => barrier_init, "set_constants" => set_constants, "set_u0" => set_u0, + "reset" => reset, "rhs" => rhs, "rhs_grad" => rhs_grad, "rhs_rgrad" => rhs_rgrad, @@ -312,6 +322,7 @@ impl_extern_symbols!(f32, f32_symbols, { "barrier_init" => barrier_init, "set_constants" => set_constants, "set_u0" => set_u0, + "reset" => reset, "rhs" => rhs, "rhs_grad" => rhs_grad, "rhs_rgrad" => rhs_rgrad, diff --git a/diffsl/src/execution/interface.rs b/diffsl/src/execution/interface.rs index 9ae9eeb..0450f53 100644 --- a/diffsl/src/execution/interface.rs +++ b/diffsl/src/execution/interface.rs @@ -15,6 +15,14 @@ pub type StopFunc = unsafe extern "C" fn( thread_id: UIntType, thread_dim: UIntType, ); +pub type ResetFunc = unsafe extern "C" fn( + time: T, + u: *const T, + data: *mut T, + reset: *mut T, + thread_id: UIntType, + thread_dim: UIntType, +); pub type RhsFunc = unsafe extern "C" fn( time: T, u: *const T, @@ -193,6 +201,7 @@ pub type GetConstantFunc = pub(crate) struct JitFunctions { pub(crate) set_u0: U0Func, + pub(crate) reset: ResetFunc, pub(crate) rhs: RhsFunc, pub(crate) mass: MassFunc, pub(crate) calc_out: CalcOutFunc, @@ -211,6 +220,7 @@ impl JitFunctions { // check if all required symbols are present let required_symbols = [ "set_u0", + "reset", "rhs", "mass", "calc_out", @@ -227,6 +237,7 @@ impl JitFunctions { } } let set_u0 = unsafe { std::mem::transmute::<*const u8, U0Func>(symbol_map["set_u0"]) }; + let reset = unsafe { std::mem::transmute::<*const u8, ResetFunc>(symbol_map["reset"]) }; let rhs = unsafe { std::mem::transmute::<*const u8, RhsFunc>(symbol_map["rhs"]) }; let mass = unsafe { std::mem::transmute::<*const u8, MassFunc>(symbol_map["mass"]) }; let calc_out = @@ -250,6 +261,7 @@ impl JitFunctions { Ok(Self { set_u0, + reset, rhs, mass, calc_out, diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 650f197..aaad97d 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -290,6 +290,7 @@ impl CodegenModuleCompile for LlvmModule { let set_u0 = module.codegen_mut().compile_set_u0(model, code)?; let _calc_stop = module.codegen_mut().compile_calc_stop(model, code)?; + let _reset = module.codegen_mut().compile_reset(model, code)?; let rhs = module.codegen_mut().compile_rhs(model, false, code)?; let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?; let mass = module.codegen_mut().compile_mass(model, code)?; @@ -3342,6 +3343,88 @@ impl<'ctx> CodeGen<'ctx> { } } + pub fn compile_reset<'m>( + &mut self, + model: &'m DiscreteModel, + code: Option<&str>, + ) -> Result> { + let time_dep_fn = self.ensure_time_dep_fn(model, code)?; + let state_dep_fn = self.ensure_state_dep_fn(model, code)?; + let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; + self.clear(); + let fn_arg_names = &["t", "u", "data", "reset", "thread_id", "thread_dim"]; + let function = self.add_function( + "reset", + fn_arg_names, + &[ + self.real_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), + self.real_ptr_type.into(), + self.int_type.into(), + self.int_type.into(), + ], + None, + false, + ); + + let alias_id = Attribute::get_named_enum_kind_id("noalias"); + let noalign = self.context.create_enum_attribute(alias_id, 0); + for i in &[1, 2, 3] { + function.add_attribute(AttributeLoc::Param(*i), noalign); + } + + let _basic_block = self.start_function(function, code); + + for (i, arg) in function.get_param_iter().enumerate() { + let name = fn_arg_names[i]; + let alloca = self.function_arg_alloca(name, arg); + self.insert_param(name, alloca); + } + + self.insert_state(model.state()); + self.insert_data(model); + self.insert_indices(); + + if let Some(reset) = model.reset() { + let mut nbarriers = 0; + let total_barriers = (model.time_dep_defns().len() + + model.state_dep_defns().len() + + model.state_dep_post_f_defns().len() + + 1) as u64; + let total_barriers_val = self.int_type.const_int(total_barriers, false); + if !model.time_dep_defns().is_empty() { + self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?; + nbarriers += model.time_dep_defns().len() as u64; + } + + if !model.state_dep_defns().is_empty() { + self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?; + nbarriers += model.state_dep_defns().len() as u64; + } + if !model.state_dep_post_f_defns().is_empty() { + self.build_dep_call(state_dep_post_f_fn, "state_dep", nbarriers, total_barriers)?; + nbarriers += model.state_dep_post_f_defns().len() as u64; + } + + let res_ptr = self.get_param("reset"); + self.jit_compile_tensor(reset, Some(*res_ptr), code)?; + let barrier_num = self.int_type.const_int(nbarriers + 1, false); + self.jit_compile_call_barrier(barrier_num, total_barriers_val); + } + self.builder.build_return(None)?; + + if function.verify(true) { + Ok(function) + } else { + function.print_to_stderr(); + unsafe { + function.delete(); + } + Err(anyhow!("Invalid generated function.")) + } + } + pub fn compile_rhs<'m>( &mut self, model: &'m DiscreteModel, diff --git a/diffsl/tests/support/external_test_macros.rs b/diffsl/tests/support/external_test_macros.rs index 0054f02..bde1184 100644 --- a/diffsl/tests/support/external_test_macros.rs +++ b/diffsl/tests/support/external_test_macros.rs @@ -296,6 +296,21 @@ macro_rules! define_external_test { *root = *u - (0.5 as $ty); } + #[no_mangle] + pub unsafe extern "C" fn reset( + _time: $ty, + u: *const $ty, + _data: *mut $ty, + reset: *mut $ty, + _thread_id: u32, + _thread_dim: u32, + ) { + if u.is_null() || reset.is_null() { + return; + } + *reset = (2.0 as $ty) * *u; + } + #[no_mangle] pub unsafe extern "C" fn set_id(id: *mut $ty) { if !id.is_null() { @@ -414,6 +429,10 @@ macro_rules! define_external_test { compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop); assert_eq!(stop[0], 0.5 as $ty); + let mut reset = vec![-5.5 as $ty; n_states]; + compiler.reset(0.0 as $ty, &u, &mut data, &mut reset); + assert_eq!(reset[0], 2.0 as $ty); + let mut mv = vec![-6.0 as $ty; n_states]; compiler.mass(0.0 as $ty, &u, &mut data, &mut mv); assert_eq!(mv[0], 1.0 as $ty);