Skip to content

Commit 711c01b

Browse files
committed
wip simplifying scan (cumsum test borken)
1 parent 3e91496 commit 711c01b

File tree

4 files changed

+68
-20
lines changed

4 files changed

+68
-20
lines changed

core/src/ops/array/dyn_slice.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pub struct DynSlice {
66
pub axis: usize,
77
pub start_input: bool,
88
pub end_input: bool,
9-
pub symbol: Symbol,
9+
pub len: TDim,
1010
}
1111

1212
impl DynSlice {
@@ -63,8 +63,8 @@ impl EvalOp for DynSlice {
6363

6464
impl TypedOp for DynSlice {
6565
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
66-
let mut fact = inputs[0].clone();
67-
fact.shape.set(self.axis, self.symbol.clone().into());
66+
let mut fact = inputs[0].without_value();
67+
fact.shape.set(self.axis, self.len.clone().into());
6868
Ok(tvec!(fact))
6969
}
7070

hir/src/ops/array/strided_slice.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,10 @@ impl Expansion for StridedSlice {
297297
AxisOp::Rm(0),
298298
&right,
299299
)?[0];
300-
let sym = target.symbol_table.new_with_prefix("l");
300+
let len = target.symbol_table.new_with_prefix("len").to_dim();
301301
wire = target.wire_node(
302302
format!("{prefix}.slice-axis-{axis}"),
303-
tract_core::ops::array::DynSlice::new(axis, true, true, sym),
303+
tract_core::ops::array::DynSlice::new(axis, true, true, len),
304304
&[wire, left, right],
305305
)?[0];
306306
}

onnx/src/ops/cumsum.rs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use tract_hir::internal::*;
2+
use tract_hir::tract_core::ops::array::DynSlice;
23
use tract_hir::tract_core::ops::scan::ScanInfo;
34

45
use crate::model::{OnnxOpRegister, ParsingContext};
@@ -54,11 +55,17 @@ impl Expansion for CumSum {
5455
)?[0];
5556
let chunk = if self.reverse { -1 } else { 1 };
5657
let input_mapping =
57-
vec![scan::InputMapping::Scan(ScanInfo { axis, chunk }), scan::InputMapping::State];
58+
vec![scan::InputMapping::Full, scan::InputMapping::State, scan::InputMapping::State];
5859
// outputs will be
5960
// acc + x (!exclusive)
6061
// acc input (exclusive)
6162
let output_mapping = vec![
63+
scan::OutputMapping {
64+
scan: None,
65+
full_dim_hint: None,
66+
last_value_slot: None,
67+
state: true,
68+
},
6269
scan::OutputMapping {
6370
scan: Some((0, ScanInfo { axis, chunk })),
6471
full_dim_hint: None,
@@ -74,12 +81,32 @@ impl Expansion for CumSum {
7481
];
7582
let mut body = TypedModel::default();
7683
let var_fact = data.datum_type.fact(var_shape);
77-
let x = body.add_source("scan_input", var_fact.clone())?;
84+
let x = body.add_source("scan_input", data)?;
85+
86+
let i = body.add_source("i", i64::scalar_fact())?;
87+
let one = body.add_const("one", tensor0(1i64))?;
88+
let i_plus_one = body.wire_node("inc_i", tract_core::ops::math::add(), &[i, one])?[0];
89+
let x_slice = body.wire_node(
90+
"x",
91+
DynSlice {
92+
axis,
93+
start_input: true,
94+
end_input: true,
95+
len: 1.to_dim(),
96+
},
97+
&[x, i, i_plus_one],
98+
)?[0];
99+
78100
let acc = body.add_source("acc_input", var_fact)?;
79-
let sum = body.wire_node("add", tract_core::ops::math::add(), &[x, acc])?[0];
80-
body.set_output_outlets(&[sum, acc])?;
101+
dbg!(axis);
102+
dbg!(body.outlet_fact(x));
103+
dbg!(body.outlet_fact(x_slice));
104+
dbg!(body.outlet_fact(acc));
105+
let sum = body.wire_node("add", tract_core::ops::math::add(), &[x_slice, acc])?[0];
106+
body.set_output_outlets(&[i_plus_one, sum, acc])?;
81107
let scan = scan::Scan::new(body, input_mapping, output_mapping, 0, iters)?;
82-
let wires = model.wire_node(prefix, scan, &[inputs[0], init])?;
108+
let zero = model.add_const(format!("{prefix}.zero"), tensor0(0i64))?;
109+
let wires = model.wire_node(prefix, scan, &[inputs[0], zero, init])?;
83110
let output = wires[self.exclusive as usize];
84111
Ok(tvec![output])
85112
}

onnx/src/ops/rec/common.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::fmt::Debug;
33
use crate::pb::*;
44
use tract_hir::internal::*;
55
use tract_hir::tract_core::dyn_clone::{clone_trait_object, DynClone};
6+
use tract_hir::tract_core::ops::array::DynSlice;
67
use tract_hir::tract_core::ops::scan::ScanInfo;
78

89
pub trait WireBody: Debug + DynClone + Send + Sync {
@@ -117,12 +118,21 @@ impl CommonRec {
117118
// scann inner interface: [chunk=1, batch_size, input_size]
118119
// onnx inner interface: [batch_size, input_size]
119120
outer_inputs.push(x_batch_first);
120-
input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk }));
121-
let mut x_source_fact = target.outlet_fact(x_batch_first)?.without_value();
121+
// input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk }));
122+
input_mapping.push(scan::InputMapping::Full);
123+
let x_source_fact = target.outlet_fact(x_batch_first)?.without_value();
122124
let iters = x_source_fact.shape[1].clone();
123-
x_source_fact.shape.set(1, 1.to_dim());
124125
let x_source = body.add_source("x_source", x_source_fact)?;
125-
wire!(Xt = AxisOp::Rm(1), x_source);
126+
127+
input_mapping.push(scan::InputMapping::State);
128+
let zero = target.add_const(format!("{prefix}.zero"), tensor0(0i64))?;
129+
outer_inputs.push(zero);
130+
let i = body.add_source("i", i64::scalar_fact())?;
131+
let one = body.add_const("one", tensor0(1i64))?;
132+
wire!(i_plus_one = tract_core::ops::math::add(), i, one);
133+
let dyn_slice = DynSlice { axis: 1, start_input: true, end_input: true, len: 1.to_dim() };
134+
wire!(x_slice = dyn_slice, x_source, i, i_plus_one);
135+
wire!(Xt = AxisOp::Rm(1), x_slice);
126136

127137
// W: onnx interface: [num_directions, 3*hidden_size, input_size]
128138
// scan interfaces: [3*hidden_size, input_size]
@@ -229,13 +239,24 @@ impl CommonRec {
229239
};
230240

231241
self.body.wire_body(prefix, &mut body).context("Wiring body")?;
242+
let mut outputs = body.outputs.clone();
243+
outputs.insert(0, i_plus_one);
244+
body.set_output_outlets(&*outputs)?;
232245

233-
let mut output_mapping = vec![scan::OutputMapping {
234-
state: true,
235-
full_dim_hint: None,
236-
last_value_slot: self.optional_y_h_output,
237-
scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })),
238-
}];
246+
let mut output_mapping = vec![
247+
scan::OutputMapping {
248+
state: true,
249+
full_dim_hint: None,
250+
last_value_slot: None,
251+
scan: None,
252+
},
253+
scan::OutputMapping {
254+
state: true,
255+
full_dim_hint: None,
256+
last_value_slot: self.optional_y_h_output,
257+
scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })),
258+
},
259+
];
239260
if self.body.have_extra_c_state() {
240261
output_mapping.push(scan::OutputMapping {
241262
state: true,

0 commit comments

Comments
 (0)