Skip to content

Commit 7662183

Browse files
committed
&[usize] -> TVec<usize> all the shapes
1 parent 9068b53 commit 7662183

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+325
-342
lines changed

api/rs/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ impl ValueInterface for Value {
364364
let dt = to_internal_dt(dt);
365365
let len = shape.iter().product::<usize>() * dt.size_of();
366366
anyhow::ensure!(len == data.len());
367-
let tensor = unsafe { Tensor::from_raw_dt(dt, shape, data)? };
367+
let tensor = unsafe { Tensor::from_raw_dt(dt, shape.into(), data)? };
368368
Ok(Value(tensor.into_tvalue()))
369369
}
370370

core/src/ops/array/broadcast.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl EvalOp for MultiBroadcastTo {
2424
inputs: TVec<TValue>,
2525
) -> TractResult<TVec<TValue>> {
2626
let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
27-
Ok(tvec!(inputs[0].broadcast_to_shape(&shape)?.into_tvalue()))
27+
Ok(tvec!(inputs[0].broadcast_to_shape(shape.into_owned())?.into_tvalue()))
2828
}
2929
}
3030

core/src/ops/array/gather.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ impl Gather {
2929
unsafe fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<TValue> {
3030
let data_view = data.to_array_view_unchecked::<T>();
3131
let indices = indices.to_array_view::<i64>()?;
32-
let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
32+
let output_shape = self.compute_output_shape(data.shape(), indices.shape())?;
3333
let mut output = Tensor::uninitialized::<T>(output_shape)?;
3434
let mut output_view = output.to_array_view_mut::<T>()?;
35-
for coords in tract_ndarray::indices(output_shape) {
35+
for coords in tract_ndarray::indices(output_view.shape()) {
3636
let ocoords = coords.as_array_view();
3737
let ocoords = ocoords.as_slice().unwrap();
3838
let mut icoords: TVec<usize> = ocoords[0..self.axis].into();

core/src/ops/array/gather_nd.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl EvalOp for GatherNd {
9191
let indices = indices.cast_to::<i32>()?;
9292
let indices = indices.to_array_view::<i32>()?;
9393
unsafe {
94-
let mut output = Tensor::uninitialized_dt(data.datum_type(), &shape)?;
94+
let mut output = Tensor::uninitialized_dt(data.datum_type(), shape)?;
9595
dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
9696
self,
9797
&mut output,

core/src/ops/array/one_hot.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ impl EvalOp for OneHot {
5757
let mut shape: TVec<usize> = input.shape().into();
5858
shape.insert(self.axis, self.dim);
5959
unsafe {
60-
let mut output = self.off.broadcast_scalar_to_shape(&shape)?;
60+
let mut output = self.off.broadcast_scalar_to_shape(shape)?;
6161
dispatch_datum_by_size!(Self::eval_t(self.off.datum_type())(
6262
self,
6363
&input,

core/src/ops/array/range.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ impl Range {
3737
len: usize,
3838
) -> TractResult<Tensor> {
3939
unsafe {
40-
let mut result = Tensor::uninitialized::<T>(&[len])?;
40+
let mut result = Tensor::uninitialized::<T>(tvec!(len))?;
4141
let mut v = start.to_scalar::<T>()?.clone();
4242
let step = step.to_scalar::<T>()?;
4343
for i in 0..len {

core/src/ops/array/reshape.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ impl Op for FiniteReshape {
1818
op_as_typed_op!();
1919
}
2020

21-
22-
2321
impl EvalOp for FiniteReshape {
2422
fn is_stateless(&self) -> bool {
2523
true
@@ -29,7 +27,7 @@ impl EvalOp for FiniteReshape {
2927
let input = args_1!(inputs);
3028
let mut tensor = input.into_tensor();
3129
unsafe {
32-
tensor.set_shape_unchecked(&self.shape);
30+
tensor.set_shape_unchecked(self.shape.clone());
3331
}
3432
Ok(tvec!(tensor.into_tvalue()))
3533
}

core/src/ops/array/slice.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ fn eval_slice(input: &Tensor, axis: usize, start: usize, end: usize) -> TractRes
8686
unsafe {
8787
let mut shape: TVec<_> = input.shape().into();
8888
shape[axis] = end - start;
89-
let mut tensor = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
89+
let mut tensor = Tensor::uninitialized_dt(input.datum_type(), shape)?;
9090
tensor.assign_slice_unchecked(.., input, start..end, axis);
9191
Ok(tvec!(tensor.into_tvalue()))
9292
}

core/src/ops/array/topk.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ impl EvalOp for Topk {
3131
let k = k.cast_to_scalar::<i64>()? as usize;
3232
output_shape[self.axis] = k;
3333
let dt = input.datum_type();
34-
let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
35-
let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
36-
let mut iterating_shape = output_shape.clone();
34+
let mut output_values = Tensor::zero_dt(dt, output_shape.clone())?;
35+
let mut output_indices = Tensor::zero::<i64>(output_shape.clone())?;
36+
let mut iterating_shape = output_shape;
3737
iterating_shape[self.axis] = 1;
3838
let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
3939
for coords in tract_ndarray::indices(&*iterating_shape) {

core/src/ops/binary.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static +
101101
self.eval_in_a(&mut a, &b)?;
102102
Ok(a)
103103
} else {
104-
let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
104+
let mut c = unsafe { Tensor::uninitialized_dt(c_dt, c_shape)? };
105105
self.eval_out_of_place(&mut c, &a, &b)?;
106106
Ok(c)
107107
}
@@ -584,7 +584,7 @@ macro_rules! bin_to_super_type {
584584
let b = b.to_array_view::<u8>()?;
585585
let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])
586586
.context("no broadcast solution")?;
587-
let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
587+
let mut c = Tensor::zero_dt(*c_dt, c_shape)?;
588588
let view = c.to_array_view_mut::<u8>()?;
589589
$crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
590590
*c = (scale_by($q_op_on_f32(
@@ -613,7 +613,7 @@ macro_rules! bin_to_super_type {
613613
let b = b.cast_to_dt(accumulator_dt)?.into_owned();
614614
let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])
615615
.context("no broadcast solution")?;
616-
let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
616+
let mut c = Tensor::zero_dt(accumulator_dt, c_shape)?;
617617
match accumulator_dt {
618618
DatumType::F32 => {
619619
let view = c.to_array_view_mut::<f32>()?;

0 commit comments

Comments
 (0)