Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions diffsl/src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub struct DiscreteModel<'s> {
state: Tensor<'s>,
state_dot: Option<Tensor<'s>>,
is_algebraic: Vec<bool>,
reset: Option<Tensor<'s>>,
stop: Option<Tensor<'s>>,
state0_input_deps: Vec<(usize, usize)>,
dstate0_input_deps: Vec<(usize, usize)>,
Expand Down Expand Up @@ -82,6 +83,9 @@ impl fmt::Display for DiscreteModel<'_> {
for defn in &self.state_dep_post_f_defns {
writeln!(f, "{defn}")?;
}
if let Some(reset) = &self.reset {
writeln!(f, "{reset}")?;
}
if let Some(stop) = &self.stop {
writeln!(f, "{stop}")?;
}
Expand Down Expand Up @@ -111,6 +115,7 @@ impl<'s> DiscreteModel<'s> {
state: Tensor::new_empty("u"),
state_dot: None,
is_algebraic: Vec::new(),
reset: None,
stop: None,
state0_input_deps: Vec::new(),
dstate0_input_deps: Vec::new(),
Expand Down Expand Up @@ -317,6 +322,7 @@ impl<'s> DiscreteModel<'s> {
let mut read_state = false;
let mut span_f = None;
let mut span_m = None;
let mut span_reset = None;
let mut seen_f = false;
for tensor_ast in model.tensors.iter() {
env.set_current_span(tensor_ast.span);
Expand Down Expand Up @@ -424,6 +430,29 @@ impl<'s> DiscreteModel<'s> {
}
}
}
"reset" => {
if let Some(built) =
Self::build_array(tensor, &mut env, TensorType::Other)
{
if !built.is_dense() {
env.errs_mut().push(ValidationError::new(
"reset must have a dense layout".to_string(),
span,
));
}
span_reset = Some(span);
ret.reset = Some(built);
}

if let Some(reset) = env.get("reset") {
if reset.is_dstatedt_dependent() {
env.errs_mut().push(ValidationError::new(
"reset must not be dependent on dudt".to_string(),
tensor_ast.span,
));
}
}
}
"stop" => {
if let Some(built) =
Self::build_array(tensor, &mut env, TensorType::Other)
Expand Down Expand Up @@ -581,6 +610,11 @@ impl<'s> DiscreteModel<'s> {
if let Some(span) = span_m {
Self::check_match(ret.lhs.as_ref().unwrap(), &ret.state, span, &mut env);
}
if let Some(span) = span_reset {
if let Some(reset) = ret.reset.as_ref() {
Self::check_match(reset, &ret.state, span, &mut env);
}
}

let map_dep = |deps: &Vec<NonZero>| -> Vec<(usize, usize)> {
deps.iter()
Expand Down Expand Up @@ -820,6 +854,7 @@ impl<'s> DiscreteModel<'s> {
let lhs = Tensor::new_no_layout("M", m_elmts, vec!['i']);
let rhs = Tensor::new_no_layout("F", f_elmts, vec!['i']);
let name = model.name;
let reset = None;
let stop = None;
let dstate_dep_defns = Vec::new();
let state_dep_post_f_defns = Vec::new();
Expand All @@ -838,6 +873,7 @@ impl<'s> DiscreteModel<'s> {
state_dep_post_f_defns,
dstate_dep_defns,
is_algebraic,
reset,
stop,
state0_input_deps: Vec::new(),
dstate0_input_deps: Vec::new(),
Expand Down Expand Up @@ -908,6 +944,10 @@ impl<'s> DiscreteModel<'s> {
self.stop.as_ref()
}

/// The model's optional `reset` tensor, if one was declared.
pub fn reset(&self) -> Option<&Tensor<'_>> {
    match self.reset {
        Some(ref tensor) => Some(tensor),
        None => None,
    }
}

pub fn take_state0_input_deps(&mut self) -> Vec<(usize, usize)> {
std::mem::take(&mut self.state0_input_deps)
}
Expand Down Expand Up @@ -1635,6 +1675,53 @@ mod tests {
assert_eq!(model.stop().unwrap().elmts().len(), 1);
}

#[test]
fn test_reset_requires_same_shape_as_state() {
    // A reset_i with one element cannot match a two-element state u_i,
    // so building the model must fail.
    let text = "
u_i {
y = 1,
z = 2,
}
F_i {
y,
z,
}
reset_i {
2 * y,
}
";
    let parsed = parse_ds_string(text).unwrap();
    let built = DiscreteModel::build("$name", &parsed);
    assert!(
        built.is_err(),
        "reset_i should be required to match the shape of u_i and F_i"
    );
}

#[test]
fn test_reset_must_not_depend_on_dudt() {
    // reset_i referencing a dudt_i component must be rejected at build time.
    let text = "
u_i {
y = 1,
}
dudt_i {
dydt = 0,
}
F_i {
y,
}
reset_i {
dydt,
}
";
    let parsed = parse_ds_string(text).unwrap();
    let built = DiscreteModel::build("$name", &parsed);
    assert!(
        built.is_err(),
        "reset_i should be state-dependent and must not depend on dudt_i"
    );
}

#[test]
fn test_no_out() {
let text = "
Expand Down
119 changes: 119 additions & 0 deletions diffsl/src/execution/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,26 @@ impl<M: CodegenModule, T: Scalar> Compiler<M, T> {
});
}

/// Evaluates the model's `reset` tensor at time `t` for state `yy`,
/// writing the result into `reset` (one value per state component).
///
/// An empty `reset` buffer makes this a no-op — callers for models
/// without a `reset_i` tensor pass an empty slice.
pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T]) {
    // Model declares no reset tensor — nothing to compute.
    if reset.is_empty() {
        return;
    }

    // State input and reset output must both match the state length;
    // `data` must match the compiled data layout.
    self.check_state_len(yy, "yy");
    self.check_state_len(reset, "reset");
    self.check_data_len(data, "data");
    // Dispatch the JIT-compiled `reset` function, possibly across threads.
    self.with_threading(|i, dim| unsafe {
        (self.jit_functions.reset)(
            t,
            yy.as_ptr(),
            // NOTE(review): `as_ptr() as *mut T` derives a mutable pointer
            // from a shared borrow of a `&mut` slice — presumably to keep the
            // closure `Fn` for the threading dispatch, matching the sibling
            // wrappers (e.g. `rhs`). Consider `as_mut_ptr()` hoisted outside
            // the closure if the threading model allows — TODO confirm.
            data.as_ptr() as *mut T,
            reset.as_ptr() as *mut T,
            i,
            dim,
        )
    });
}

pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T]) {
self.check_state_len(yy, "yy");
self.check_state_len(rr, "rr");
Expand Down Expand Up @@ -1016,6 +1036,8 @@ mod tests {
}

generate_tests!(test_stop);
generate_tests!(test_reset);
generate_tests!(test_reset_without_reset_tensor_is_noop);

#[allow(dead_code)]
fn test_stop<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
Expand Down Expand Up @@ -1068,6 +1090,103 @@ mod tests {
assert_eq!(stop.len(), 1);
}

#[allow(dead_code)]
fn test_reset<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
    // Model with a two-component state and a reset tensor that maps
    // (y, z) -> (2*y, z + 10); at u0 = (1, 2) the reset is (2, 12).
    let full_text = "
u_i {
y = 1,
z = 2,
}
dudt_i {
dydt = 0,
dzdt = 0,
}
M_i {
dydt,
dzdt,
}
F_i {
y,
z,
}
reset_i {
2 * y,
z + 10,
}
stop_i {
y - 0.5,
}
out_i {
y,
z,
}
";
    let parsed = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &parsed).unwrap();
    let compiler =
        Compiler::<M, T>::from_discrete_model(&discrete_model, Default::default(), Some(full_text))
            .unwrap();

    let mut data = compiler.get_new_data();
    let mut u0 = vec![T::zero(); 2];
    compiler.set_u0(&mut u0, &mut data);

    let mut reset = vec![T::zero(); 2];
    compiler.reset(T::zero(), &u0, &mut data, &mut reset);

    assert_eq!(reset.len(), 2);
    assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap());
    assert_relative_eq!(u0[1], T::from_f64(2.0).unwrap());
    assert_relative_eq!(reset[0], T::from_f64(2.0).unwrap());
    assert_relative_eq!(reset[1], T::from_f64(12.0).unwrap());
}

#[allow(dead_code)]
fn test_reset_without_reset_tensor_is_noop<
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq,
>() {
    // A model with no reset_i tensor: calling reset with an empty buffer
    // must return without touching anything.
    let full_text = "
u_i {
y = 1,
}
F_i {
y,
}
out_i {
y,
}
";
    let parsed = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &parsed).unwrap();
    let compiler =
        Compiler::<M, T>::from_discrete_model(&discrete_model, Default::default(), Some(full_text))
            .unwrap();

    let mut data = compiler.get_new_data();
    let mut u0 = vec![T::zero()];
    compiler.set_u0(&mut u0, &mut data);

    let mut reset: Vec<T> = Vec::new();
    compiler.reset(T::zero(), &u0, &mut data, &mut reset);
    assert_eq!(reset.len(), 0);
}

generate_tests!(test_out_depends_on_internal_tensor);

generate_tests!(test_model_index_n_depends_on_model_index);
Expand Down
42 changes: 42 additions & 0 deletions diffsl/src/execution/cranelift/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ impl<M: Module> CraneliftModule<M> {

let set_u0 = ret.compile_set_u0(model)?;
let _calc_stop = ret.compile_calc_stop(model)?;
let _reset = ret.compile_reset(model)?;
let rhs = ret.compile_rhs(model)?;
let _mass = ret.compile_mass(model)?;
let calc_out = ret.compile_calc_out(model)?;
Expand Down Expand Up @@ -623,6 +624,47 @@ impl<M: Module> CraneliftModule<M> {
self.declare_function("calc_stop")
}

/// Compiles the JIT `reset` entry point:
/// `reset(t, u, data, reset, threadId, threadDim)`.
fn compile_reset(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["t", "u", "data", "reset", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);

        // Models without a reset tensor get a function that returns
        // immediately; otherwise recompute every definition the reset
        // expression may read, with a barrier after each so threads
        // stay in step.
        if let Some(reset) = model.reset() {
            let mut nbarrier = 0;
            let defns = model
                .time_dep_defns()
                .iter()
                .chain(model.state_dep_defns().iter())
                .chain(model.state_dep_post_f_defns().iter());
            for tensor in defns {
                codegen.jit_compile_tensor(tensor, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }

            // Write the reset tensor directly into the output argument.
            let out_ptr = *codegen.variables.get("reset").unwrap();
            codegen.jit_compile_tensor(reset, Some(out_ptr), false)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("reset")
}

fn compile_rhs(&mut self, model: &DiscreteModel) -> Result<FuncId> {
let arg_types = &[
self.real_type,
Expand Down
11 changes: 11 additions & 0 deletions diffsl/src/execution/external/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ macro_rules! define_symbol_module {
thread_id: UIntType,
thread_dim: UIntType,
);
#[link_name = "reset"]
pub fn reset(
time: $ty,
u: *const $ty,
data: *mut $ty,
reset: *mut $ty,
thread_id: UIntType,
thread_dim: UIntType,
);
#[link_name = "rhs"]
pub fn rhs(
time: $ty,
Expand Down Expand Up @@ -283,6 +292,7 @@ impl_extern_symbols!(f64, f64_symbols, {
"barrier_init" => barrier_init,
"set_constants" => set_constants,
"set_u0" => set_u0,
"reset" => reset,
"rhs" => rhs,
"rhs_grad" => rhs_grad,
"rhs_rgrad" => rhs_rgrad,
Expand Down Expand Up @@ -312,6 +322,7 @@ impl_extern_symbols!(f32, f32_symbols, {
"barrier_init" => barrier_init,
"set_constants" => set_constants,
"set_u0" => set_u0,
"reset" => reset,
"rhs" => rhs,
"rhs_grad" => rhs_grad,
"rhs_rgrad" => rhs_rgrad,
Expand Down
Loading
Loading