Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diffsl/benches/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn execute<const N: usize, M: CodegenModuleCompile + CodegenModuleJit>(
let n = N;
let compiler = setup::<M>(n, f_text, "execute");
let mut data = compiler.get_new_data();
compiler.set_inputs(&[], data.as_mut_slice());
compiler.set_inputs(&[], data.as_mut_slice(), 0);
let mut u = vec![1.0; n];
compiler.set_u0(u.as_mut_slice(), data.as_mut_slice());
let mut rr = vec![0.0; n];
Expand Down
19 changes: 14 additions & 5 deletions diffsl/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,18 @@ impl<'a> Ast<'a> {
}
AstKind::CallArg(arg) => Self::new_call_arg(arg.name, arg.expression.tangent()),
AstKind::Name(name) => {
if name.name == "t" {
if name.name == "t" || name.name == "N" {
Self::new_number(0.0)
} else {
Self::new_name(name.name, name.indices.clone(), true)
Ast {
kind: AstKind::Name(Name {
name: name.name,
indices: name.indices.clone(),
indice: name.indice.clone(),
is_tangent: true,
}),
span: self.span,
}
}
}
AstKind::Number(_) => Self::new_number(0.0),
Expand Down Expand Up @@ -883,9 +891,10 @@ impl<'a> Ast<'a> {
AstKind::Name(found_name) => {
// if the name is indexed by a single indice, don't add that index to the list
if let Some(indice) = &found_name.indice {
let indice = indice.as_ref().kind.as_indice().unwrap();
if indice.sep.is_none() || indice.last.is_none() {
return;
if let Some(indice) = indice.as_ref().kind.as_indice() {
if indice.sep.is_none() || indice.last.is_none() {
return;
}
}
}
indices.extend(found_name.indices.iter().cloned());
Expand Down
11 changes: 10 additions & 1 deletion diffsl/src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ impl<'s> DiscreteModel<'s> {
let reserved_names = [
"u0",
"t",
"N",
"data",
"root",
"thread_id",
Expand Down Expand Up @@ -502,7 +503,11 @@ impl<'s> DiscreteModel<'s> {
let dependent_on_time = env_entry.is_time_dependent();
let dependent_on_dudt = env_entry.is_dstatedt_dependent();
let dependent_on_input = env_entry.is_input_dependent();
if !dependent_on_time && !dependent_on_input {
let dependent_on_model = env_entry.is_model_dependent();
if !dependent_on_time
&& !dependent_on_input
&& !dependent_on_model
{
ret.constant_defns.push(built);
} else if !dependent_on_time {
ret.input_dep_defns.push(built);
Expand Down Expand Up @@ -1401,6 +1406,7 @@ mod tests {

tensor_fail_tests!(
error_scalar: "r {1, 2}" errors ["cannot have more than one element in a scalar",],
error_reserved_name_n: "N { 1 }" errors ["N is a reserved name",],
error_cannot_find: "r { k }" errors ["cannot find variable k",],
error_different_shape: "a_i { 1, 2 } b_i { 1, 2, 3 } c_i { a_i + b_i }" errors ["cannot broadcast shapes: [2], [3]",],
too_many_indices: "A_i { 1, 2 } B_i { (0:2): A_ij }" errors ["too many permutation indices",],
Expand Down Expand Up @@ -1447,6 +1453,9 @@ mod tests {
index: "A_i { 0.0, 1.0, 2.0 } B { A_i[1] }" expect "B" = "B { (): A_i[1] }",
index2: "A_i { 0.0, 1.0, 2.0 } B_i { A_i[1:3] }" expect "B" = "B_i(2) { (0)(2):A_i[1:3](2) }",
index3: "A_ij { (0:2, 0:2): 1 } g_i { 0, 1, 2 } b_i { A_ij * g_j[0:2] }" expect "b" = "b_i (2) { (0)(2): A_ij * g_j[0:2] (2, 2) }",
index_n_scalar: "A_i { 0.0, 1.0, 2.0 } B { A_i[N % 2] }" expect "B" = "B { (): A_i[N % 2] }",
index_n_range: "A_i { 0.0, 1.0, 2.0, 3.0 } B_i { A_i[N:N+2] }" expect "B" = "B_i(2) { (0)(2):A_i[N:N + 2](2) }",
index_n_range_contraction: "A_ij { (0:2, 0:2): 1 } g_i { 0, 1, 2 } b_i { A_ij * g_j[N:N+2] }" expect "b" = "b_i (2) { (0)(2): A_ij * g_j[N:N + 2] (2, 2) }",
prefix_minus: "A { 1.0 / -2.0 }" expect "A" = "A { (): 1 / (-2) }",
time: "A_i { t }" expect "A" = "A_i (1) { (0)(1): t }",
named_blk: "A_i { (0:3): y = 1, 2 }" expect "A" = "A_i (4) { (0)(3): y = 1, (3)(1): 2 }",
Expand Down
223 changes: 212 additions & 11 deletions diffsl/src/discretise/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct EnvVar {
is_state_dependent: bool,
is_dstatedt_dependent: bool,
is_input_dependent: bool,
is_model_dependent: bool,
is_algebraic: bool,
}

Expand All @@ -43,6 +44,10 @@ impl EnvVar {
self.is_input_dependent
}

/// True when this variable depends, directly or transitively, on the
/// reserved name `N` (see `Env::is_tensor_model_dependent`, which sets
/// this flag when any element expression references `N` or another
/// model-dependent variable).
pub fn is_model_dependent(&self) -> bool {
self.is_model_dependent
}

pub fn layout(&self) -> &Layout {
self.layout.as_ref()
}
Expand All @@ -57,6 +62,127 @@ pub struct Env {
}

impl Env {
/// Evaluate `expr` as a compile-time integer constant.
///
/// Returns `None` when the expression contains anything other than integer
/// literals, whole-valued number literals, unary `+`/`-`, or the binary
/// operators `+ - * / %` — or when evaluation would overflow or divide by
/// zero. Unlike `eval_integer_expr_with_n`, the reserved name `N` is NOT
/// accepted here.
fn eval_const_integer_expr(expr: &Ast) -> Option<i64> {
    match &expr.kind {
        AstKind::Integer(v) => Some(*v),
        AstKind::Number(v) => {
            // Only accept floats that are whole AND inside the i64 range:
            // `as i64` saturates silently for out-of-range values, which
            // would turn e.g. 1e30 into i64::MAX instead of failing.
            // Note -(i64::MIN as f64) == 2^63 exactly, so `<` excludes it.
            if v.fract() == 0.0 && *v >= i64::MIN as f64 && *v < -(i64::MIN as f64) {
                Some(*v as i64)
            } else {
                None
            }
        }
        AstKind::Monop(op) => {
            let child = Self::eval_const_integer_expr(op.child.as_ref())?;
            match op.op {
                '+' => Some(child),
                '-' => child.checked_neg(),
                _ => None,
            }
        }
        AstKind::Binop(op) => {
            let left = Self::eval_const_integer_expr(op.left.as_ref())?;
            let right = Self::eval_const_integer_expr(op.right.as_ref())?;
            match op.op {
                '+' => left.checked_add(right),
                '-' => left.checked_sub(right),
                '*' => left.checked_mul(right),
                // checked_div / checked_rem return None both for a zero
                // divisor and for the i64::MIN / -1 overflow case, which a
                // bare `/` or `%` would panic on (or wrap in release mode).
                '/' => left.checked_div(right),
                '%' => left.checked_rem(right),
                _ => None,
            }
        }
        _ => None,
    }
}

/// Evaluate `expr` as an integer, substituting the value `n` for any
/// occurrence of the reserved name `N`.
///
/// Accepts the same expression forms as `eval_const_integer_expr` plus the
/// name `N`; returns `None` for any other name, any non-integer construct,
/// overflow, or division by zero.
fn eval_integer_expr_with_n(expr: &Ast, n: i64) -> Option<i64> {
    match &expr.kind {
        AstKind::Integer(v) => Some(*v),
        AstKind::Number(v) => {
            // Reject whole-valued floats outside the i64 range: `as i64`
            // saturates silently instead of signalling failure.
            // -(i64::MIN as f64) is exactly 2^63, so `<` excludes it.
            if v.fract() == 0.0 && *v >= i64::MIN as f64 && *v < -(i64::MIN as f64) {
                Some(*v as i64)
            } else {
                None
            }
        }
        AstKind::Name(name) => {
            // Only `N` has a known value here; any other name makes the
            // expression non-evaluable.
            if name.name == "N" {
                Some(n)
            } else {
                None
            }
        }
        AstKind::Monop(op) => {
            let child = Self::eval_integer_expr_with_n(op.child.as_ref(), n)?;
            match op.op {
                '+' => Some(child),
                '-' => child.checked_neg(),
                _ => None,
            }
        }
        AstKind::Binop(op) => {
            let left = Self::eval_integer_expr_with_n(op.left.as_ref(), n)?;
            let right = Self::eval_integer_expr_with_n(op.right.as_ref(), n)?;
            match op.op {
                '+' => left.checked_add(right),
                '-' => left.checked_sub(right),
                '*' => left.checked_mul(right),
                // checked_div / checked_rem cover both the zero divisor and
                // the i64::MIN / -1 overflow, either of which would panic
                // with the bare operators.
                '/' => left.checked_div(right),
                '%' => left.checked_rem(right),
                _ => None,
            }
        }
        _ => None,
    }
}

/// Compute the width `end - start` of a range indice, provided that width is
/// a constant independent of `N`.
///
/// Fast path: when both bounds are plain integer constants, the width is
/// their (checked) difference. Otherwise both bounds are probed at several
/// sample values of `N`; the width is accepted only if every probe evaluates
/// and all probes agree, and `None` is returned otherwise.
fn eval_constant_range_width(start: &Ast, end: &Ast) -> Option<i64> {
    // Fully constant bounds: no probing needed.
    if let Some(first) = Self::eval_const_integer_expr(start) {
        if let Some(last) = Self::eval_const_integer_expr(end) {
            return last.checked_sub(first);
        }
    }

    // Probe a spread of N values; a width that is genuinely N-dependent
    // cannot be constant across all of them.
    const PROBES: [i64; 6] = [0, 1, 2, 3, 7, 16];
    let mut sampled = PROBES.iter().map(|&n| {
        let lo = Self::eval_integer_expr_with_n(start, n)?;
        let hi = Self::eval_integer_expr_with_n(end, n)?;
        hi.checked_sub(lo)
    });
    let reference = sampled.next()??;
    for width in sampled {
        if width? != reference {
            return None;
        }
    }
    Some(reference)
}

pub fn new() -> Self {
let mut vars = HashMap::new();
vars.insert(
Expand All @@ -67,6 +193,19 @@ impl Env {
is_state_dependent: false,
is_dstatedt_dependent: false,
is_input_dependent: false,
is_model_dependent: false,
is_algebraic: true,
},
);
vars.insert(
"N".to_string(),
EnvVar {
layout: ArcLayout::new(Layout::new_scalar()),
is_time_dependent: false,
is_state_dependent: false,
is_dstatedt_dependent: false,
is_input_dependent: false,
is_model_dependent: true,
is_algebraic: true,
},
);
Expand Down Expand Up @@ -109,6 +248,19 @@ impl Env {
self.is_tensor_dependent_on(tensor, "in")
}

/// Returns true when `tensor` is the reserved `N` tensor itself, or when any
/// of its element expressions references `N` — directly, or through a
/// variable already marked model-dependent in the environment.
pub fn is_tensor_model_dependent(&self, tensor: &Tensor) -> bool {
    if tensor.name() == "N" {
        return true;
    }
    for block in tensor.elmts().iter() {
        let deps = block.expr().get_dependents();
        let references_n = deps
            .iter()
            .any(|&dep| dep == "N" || self.vars[dep].is_model_dependent());
        if references_n {
            return true;
        }
    }
    false
}

pub fn is_tensor_dstatedt_dependent(&self, tensor: &Tensor) -> bool {
self.is_tensor_dependent_on(tensor, "dudt")
}
Expand Down Expand Up @@ -141,6 +293,7 @@ impl Env {
is_state_dependent: self.is_tensor_state_dependent(var),
is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var),
is_input_dependent: self.is_tensor_input_dependent(var),
is_model_dependent: self.is_tensor_model_dependent(var),
},
);
}
Expand All @@ -155,6 +308,7 @@ impl Env {
is_state_dependent: self.is_tensor_state_dependent(var),
is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var),
is_input_dependent: self.is_tensor_input_dependent(var),
is_model_dependent: self.is_tensor_model_dependent(var),
},
);
}
Expand Down Expand Up @@ -280,32 +434,79 @@ impl Env {
// if the indice is a single integer then the resulting layout is a scalar
if indice.sep.is_none() {
let mut new_layout = Layout::new_scalar();
let first = indice.first.kind.as_integer().unwrap();
let last = first + 1;
let (first, last) =
if let Some(first) = Self::eval_const_integer_expr(indice.first.as_ref()) {
(first, first + 1)
} else {
// Dynamic integer indices (e.g. involving N) cannot be resolved at compile time.
// Conservatively keep dependencies for the full 1D extent.
let axis = layout_permuted
.shape()
.iter()
.position(|&d| d != 1)
.unwrap_or(0);
let dim = *layout_permuted.shape().get(axis).unwrap_or(&1);
(0, i64::try_from(dim).unwrap())
};
new_layout.filter_deps_from(layout_permuted, first, last);
return Some(new_layout);
} else {
// if the indice is a range then the resulting layout is a dense layout with shape given by the range
// along the only non-unit dimension of the variable
let first = indice.first.kind.as_integer().unwrap();
let last = indice.last.as_ref().unwrap().kind.as_integer().unwrap();
let end_expr = indice.last.as_ref().unwrap().as_ref();
let Some(width) = Self::eval_constant_range_width(indice.first.as_ref(), end_expr)
else {
self.errs.push(ValidationError::new(
"range indice width must be an integer constant (independent of N)"
.to_string(),
ast.span,
));
return None;
};
// make sure the range is valid
if last < first {
if width < 0 {
self.errs.push(ValidationError::new(
format!(
"invalid range indice: start {} is greater than end {}",
first, last
),
format!("invalid range indice: width {} is negative", width),
ast.span,
));
return None;
}
let dim = usize::try_from(last - first).unwrap();
let dim = usize::try_from(width).unwrap();
let shape = layout_permuted
.shape()
.map(|&d| if d != 1 { dim } else { 1 });
let mut new_layout = Layout::new_dense(Shape::from(shape));
new_layout.filter_deps_from(layout_permuted, first, last);
let first_const = Self::eval_const_integer_expr(indice.first.as_ref());
if let Some(first) = first_const {
let last = first + width;
new_layout.filter_deps_from(layout_permuted, first, last);
} else if dim != 0 {
// For dynamic starts (e.g. N:N+2), union dependencies over all valid windows.
let axis = layout_permuted
.shape()
.iter()
.position(|&d| d != 1)
.unwrap_or(0);
let source_dim =
i64::try_from(*layout_permuted.shape().get(axis).unwrap_or(&1)).unwrap();
let max_start = source_dim.saturating_sub(width);
let mut merged_layout: Option<Layout> = None;
for start in 0..=max_start {
let mut window_layout = Layout::new_dense(new_layout.shape().clone());
window_layout.filter_deps_from(
layout_permuted.clone(),
start,
start + width,
);
merged_layout = Some(match merged_layout {
Some(accumulated) => accumulated.union(window_layout),
None => window_layout,
});
}
if let Some(merged_layout) = merged_layout {
new_layout = merged_layout;
}
}
return Some(new_layout);
}
}
Expand Down
Loading
Loading