@@ -3,6 +3,7 @@ use std::fmt::Debug;
33use crate :: pb:: * ;
44use tract_hir:: internal:: * ;
55use tract_hir:: tract_core:: dyn_clone:: { clone_trait_object, DynClone } ;
6+ use tract_hir:: tract_core:: ops:: array:: DynSlice ;
67use tract_hir:: tract_core:: ops:: scan:: ScanInfo ;
78
89pub trait WireBody : Debug + DynClone + Send + Sync {
@@ -117,12 +118,21 @@ impl CommonRec {
117118 // scann inner interface: [chunk=1, batch_size, input_size]
118119 // onnx inner interface: [batch_size, input_size]
119120 outer_inputs. push ( x_batch_first) ;
120- input_mapping. push ( scan:: InputMapping :: Scan ( ScanInfo { axis : 1 , chunk } ) ) ;
121- let mut x_source_fact = target. outlet_fact ( x_batch_first) ?. without_value ( ) ;
121+ // input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk }));
122+ input_mapping. push ( scan:: InputMapping :: Full ) ;
123+ let x_source_fact = target. outlet_fact ( x_batch_first) ?. without_value ( ) ;
122124 let iters = x_source_fact. shape [ 1 ] . clone ( ) ;
123- x_source_fact. shape . set ( 1 , 1 . to_dim ( ) ) ;
124125 let x_source = body. add_source ( "x_source" , x_source_fact) ?;
125- wire ! ( Xt = AxisOp :: Rm ( 1 ) , x_source) ;
126+
127+ input_mapping. push ( scan:: InputMapping :: State ) ;
128+ let zero = target. add_const ( format ! ( "{prefix}.zero" ) , tensor0 ( 0i64 ) ) ?;
129+ outer_inputs. push ( zero) ;
130+ let i = body. add_source ( "i" , i64:: scalar_fact ( ) ) ?;
131+ let one = body. add_const ( "one" , tensor0 ( 1i64 ) ) ?;
132+ wire ! ( i_plus_one = tract_core:: ops:: math:: add( ) , i, one) ;
133+ let dyn_slice = DynSlice { axis : 1 , start_input : true , end_input : true , len : 1 . to_dim ( ) } ;
134+ wire ! ( x_slice = dyn_slice, x_source, i, i_plus_one) ;
135+ wire ! ( Xt = AxisOp :: Rm ( 1 ) , x_slice) ;
126136
127137 // W: onnx interface: [num_directions, 3*hidden_size, input_size]
128138 // scan interfaces: [3*hidden_size, input_size]
@@ -229,13 +239,24 @@ impl CommonRec {
229239 } ;
230240
231241 self . body . wire_body ( prefix, & mut body) . context ( "Wiring body" ) ?;
242+ let mut outputs = body. outputs . clone ( ) ;
243+ outputs. insert ( 0 , i_plus_one) ;
244+ body. set_output_outlets ( & * outputs) ?;
232245
233- let mut output_mapping = vec ! [ scan:: OutputMapping {
234- state: true ,
235- full_dim_hint: None ,
236- last_value_slot: self . optional_y_h_output,
237- scan: self . optional_y_output. map( |slot| ( slot, ScanInfo { axis: 1 , chunk } ) ) ,
238- } ] ;
246+ let mut output_mapping = vec ! [
247+ scan:: OutputMapping {
248+ state: true ,
249+ full_dim_hint: None ,
250+ last_value_slot: None ,
251+ scan: None ,
252+ } ,
253+ scan:: OutputMapping {
254+ state: true ,
255+ full_dim_hint: None ,
256+ last_value_slot: self . optional_y_h_output,
257+ scan: self . optional_y_output. map( |slot| ( slot, ScanInfo { axis: 1 , chunk } ) ) ,
258+ } ,
259+ ] ;
239260 if self . body . have_extra_c_state ( ) {
240261 output_mapping. push ( scan:: OutputMapping {
241262 state : true ,
0 commit comments