Skip to content

Commit eb7524f

Browse files
committed
Tensordot implementation refinement
Updates ``tensordot_impl`` by removing one vector allocation, eliminating repeated axis-membership scans, and reducing shape-index lookups. Also changes the ``axes`` parameter to a borrowed input (``&AxisSpec``) so the caller can reuse it without moving. Replaces ``notin_a`` + ``clone`` with direct construction of ``out_shape``, removing one allocation and one clone. Allocation count is now: - ``is_contracted_a`` → ``O(ndim(a))`` - ``is_contracted_b`` → ``O(ndim(b))`` - ``newaxes_a`` → ``O(ndim(a))`` - ``newaxes_b`` → ``O(ndim(b))`` - ``out_shape`` → ``O(ndim(a) + ndim(b) - contracted)`` All sizes are exactly determined by the axis mask and do not depend on runtime data beyond shape rank. Precomputes boolean membership arrays for the contracted axes, replacing multiple ``iter().any()`` scans with O(1) lookups. Caches the ``a.shape()`` and ``b.shape()`` slices to avoid repeated method calls when indexing.
1 parent 634c3d4 commit eb7524f

File tree

1 file changed

+118
-51
lines changed

1 file changed

+118
-51
lines changed

src/linalg/impl_linalg.rs

Lines changed: 118 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,14 +1252,14 @@ pub trait Tensordot<Rhs: ?Sized>
12521252
/// - The computed product of reshaped dimensions does not equal the
12531253
/// array’s total element count (which would indicate internal logic error).
12541254
#[track_caller]
1255-
fn tensordot(&self, rhs: &Rhs, axes: AxisSpec) -> Self::Output;
1255+
fn tensordot(&self, rhs: &Rhs, axes: &AxisSpec) -> Self::Output;
12561256
}
12571257

12581258
/// Perform a tensor contraction along specified axes.
12591259
///
12601260
/// See [`Tensordot::tensordot`] for more details.
12611261
#[track_caller]
1262-
pub fn tensordot<T, Sa, Sb, Da, Db>(a: &ArrayBase<Sa, Da>, b: &ArrayBase<Sb, Db>, axes: AxisSpec) -> ArrayD<T>
1262+
pub fn tensordot<T, Sa, Sb, Da, Db>(a: &ArrayBase<Sa, Da>, b: &ArrayBase<Sb, Db>, axes: &AxisSpec) -> ArrayD<T>
12631263
where
12641264
T: LinalgScalar,
12651265
Sa: Data<Elem = T>,
@@ -1272,7 +1272,11 @@ where
12721272

12731273
/// Performs the full contraction given resolved axis specification.
12741274
#[track_caller]
1275-
fn tensordot_impl<T, Sa, Sb, Da, Db>(a: &ArrayBase<Sa, Da>, b: &ArrayBase<Sb, Db>, axes: AxisSpec) -> ArrayD<T>
1275+
fn tensordot_impl<T, Sa, Sb, Da, Db>(
1276+
a: &ArrayBase<Sa, Da>,
1277+
b: &ArrayBase<Sb, Db>,
1278+
axes: &AxisSpec,
1279+
) -> ArrayD<T>
12761280
where
12771281
T: LinalgScalar,
12781282
Sa: Data<Elem = T>,
@@ -1283,18 +1287,32 @@ where
12831287
let nda = a.ndim() as isize;
12841288
let ndb = b.ndim() as isize;
12851289

1286-
// Resolve axes
1287-
let (mut axes_a, mut axes_b): (Vec<isize>, Vec<isize>) = match axes {
1288-
AxisSpec::Num(n) => {
1289-
let n = n as isize;
1290+
// Precompute shapes for reuse
1291+
let ashape = a.shape();
1292+
let bshape = b.shape();
1293+
1294+
// Resolve and normalise contracted axes (into owned Vecs, no cloning of input)
1295+
let (axes_a, axes_b): (Vec<isize>, Vec<isize>) = match axes {
1296+
AxisSpec::Num(n_raw) => {
1297+
let n = *n_raw as isize;
12901298
assert!(
12911299
n <= nda && n <= ndb,
12921300
"tensordot: cannot contract over {} axes; a.ndim()={}, b.ndim()={}",
12931301
n,
12941302
nda,
12951303
ndb
12961304
);
1297-
((nda - n)..nda).zip(0..n).map(|(ia, ib)| (ia, ib)).unzip()
1305+
1306+
let mut axes_a = Vec::with_capacity(n as usize);
1307+
let mut axes_b = Vec::with_capacity(n as usize);
1308+
1309+
// last n axes of a, first n axes of b
1310+
for i in 0..n {
1311+
axes_a.push(nda - n + i);
1312+
axes_b.push(i);
1313+
}
1314+
1315+
(axes_a, axes_b)
12981316
}
12991317
AxisSpec::Pair(aa, bb) => {
13001318
assert_eq!(
@@ -1304,23 +1322,27 @@ where
13041322
aa.len(),
13051323
bb.len()
13061324
);
1307-
(aa, bb)
1308-
}
1309-
};
13101325

1311-
// Normalise negative indices
1312-
for ax in &mut axes_a {
1313-
if *ax < 0 {
1314-
*ax += nda;
1315-
}
1316-
}
1317-
for ax in &mut axes_b {
1318-
if *ax < 0 {
1319-
*ax += ndb;
1326+
let mut axes_a = Vec::with_capacity(aa.len());
1327+
let mut axes_b = Vec::with_capacity(bb.len());
1328+
1329+
// Normalise negatives for a
1330+
for &ax in aa {
1331+
let ax_norm = if ax < 0 { ax + nda } else { ax };
1332+
axes_a.push(ax_norm);
1333+
}
1334+
1335+
// Normalise negatives for b
1336+
for &ax in bb {
1337+
let ax_norm = if ax < 0 { ax + ndb } else { ax };
1338+
axes_b.push(ax_norm);
1339+
}
1340+
1341+
(axes_a, axes_b)
13201342
}
1321-
}
1343+
};
13221344

1323-
// Validate
1345+
// Validate bounds
13241346
for &ax in &axes_a {
13251347
assert!(
13261348
(0..nda).contains(&ax),
@@ -1338,36 +1360,83 @@ where
13381360
);
13391361
}
13401362

1341-
// Shape checks
1363+
// Shape checks on contracted axes
13421364
for (ia, ib) in axes_a.iter().zip(&axes_b) {
1343-
let da = a.shape()[*ia as usize];
1344-
let db = b.shape()[*ib as usize];
1365+
let da = ashape[*ia as usize];
1366+
let db = bshape[*ib as usize];
13451367
assert_eq!(
13461368
da, db,
13471369
"tensordot: shape mismatch along contraction axis: a[{}]={} vs b[{}]={}",
13481370
ia, da, ib, db
13491371
);
13501372
}
13511373

1352-
// Determine non-contracted axes
1353-
let notin_a: Vec<usize> = (0..nda as usize)
1354-
.filter(|k| !axes_a.iter().any(|&ax| ax as usize == *k))
1355-
.collect();
1356-
let notin_b: Vec<usize> = (0..ndb as usize)
1357-
.filter(|k| !axes_b.iter().any(|&ax| ax as usize == *k))
1358-
.collect();
1359-
1360-
// Reorder axes
1361-
let mut newaxes_a = notin_a.clone();
1362-
newaxes_a.extend(axes_a.iter().map(|&x| x as usize));
1363-
let mut newaxes_b = axes_b.iter().map(|&x| x as usize).collect::<Vec<_>>();
1364-
newaxes_b.extend(notin_b.iter().copied());
1365-
1366-
// Matrix shapes
1367-
let m = notin_a.iter().fold(1, |p, &ax| p * a.shape()[ax]);
1368-
let k = axes_a.iter().fold(1, |p, &ax| p * a.shape()[ax as usize]);
1369-
let n = notin_b.iter().fold(1, |p, &ax| p * b.shape()[ax]);
1374+
// Membership maps for contracted axes (O(ndim) setup, O(1) lookup)
1375+
let mut is_contracted_a = vec![false; nda as usize];
1376+
let mut is_contracted_b = vec![false; ndb as usize];
13701377

1378+
for &ax in &axes_a {
1379+
is_contracted_a[ax as usize] = true;
1380+
}
1381+
for &ax in &axes_b {
1382+
is_contracted_b[ax as usize] = true;
1383+
}
1384+
1385+
let contracted_a = axes_a.len();
1386+
let contracted_b = axes_b.len();
1387+
debug_assert_eq!(contracted_a, contracted_b);
1388+
let free_a = nda as usize - contracted_a;
1389+
let free_b = ndb as usize - contracted_b;
1390+
1391+
// Permutation axes for a: [non-contracted..., contracted...]
1392+
let mut newaxes_a = Vec::with_capacity(nda as usize);
1393+
for i in 0..nda as usize {
1394+
if !is_contracted_a[i] {
1395+
newaxes_a.push(i);
1396+
}
1397+
}
1398+
for &ax in &axes_a {
1399+
newaxes_a.push(ax as usize);
1400+
}
1401+
1402+
// non-contracted axes for b (indices)
1403+
let mut notin_b = Vec::with_capacity(free_b);
1404+
for i in 0..ndb as usize {
1405+
if !is_contracted_b[i] {
1406+
notin_b.push(i);
1407+
}
1408+
}
1409+
1410+
// Permutation axes for b: [contracted..., non-contracted...]
1411+
let mut newaxes_b = Vec::with_capacity(ndb as usize);
1412+
for &ax in &axes_b {
1413+
newaxes_b.push(ax as usize);
1414+
}
1415+
newaxes_b.extend(&notin_b);
1416+
1417+
// Output shape: a(non-contracted) ⧺ b(non-contracted)
1418+
let mut out_shape = Vec::with_capacity(free_a + free_b);
1419+
for i in 0..nda as usize {
1420+
if !is_contracted_a[i] {
1421+
out_shape.push(ashape[i]);
1422+
}
1423+
}
1424+
for &ax in &notin_b {
1425+
out_shape.push(bshape[ax]);
1426+
}
1427+
1428+
// Matrix dims
1429+
let m = newaxes_a[..free_a]
1430+
.iter()
1431+
.fold(1, |p, &ax| p * ashape[ax]);
1432+
let k = axes_a
1433+
.iter()
1434+
.fold(1, |p, &ax| p * ashape[ax as usize]);
1435+
let n = notin_b
1436+
.iter()
1437+
.fold(1, |p, &ax| p * bshape[ax]);
1438+
1439+
// Permute + standard layout (keep temporaries named to satisfy lifetimes)
13711440
let a_dyn = a.view().into_dimensionality::<IxDyn>().unwrap();
13721441
let b_dyn = b.view().into_dimensionality::<IxDyn>().unwrap();
13731442

@@ -1377,6 +1446,7 @@ where
13771446
let b_perm = b_dyn.permuted_axes(IxDyn(&newaxes_b));
13781447
let b_std = b_perm.as_standard_layout();
13791448

1449+
// Reshape to 2D, multiply, and reshape back
13801450
let a2 = a_std
13811451
.into_shape_with_order(Ix2(m, k))
13821452
.expect("reshaping a to 2D");
@@ -1386,9 +1456,6 @@ where
13861456

13871457
let c2 = a2.dot(&b2);
13881458

1389-
let mut out_shape: Vec<usize> = notin_a.iter().map(|&ax| a.shape()[ax]).collect();
1390-
out_shape.extend(notin_b.iter().map(|&ax| b.shape()[ax]));
1391-
13921459
c2.into_shape_with_order(IxDyn(&out_shape)).unwrap()
13931460
}
13941461

@@ -1404,7 +1471,7 @@ where
14041471
type Output = ArrayD<A>;
14051472

14061473
#[track_caller]
1407-
fn tensordot(&self, rhs: &ArrayBase<S2, D2>, axes: AxisSpec) -> Self::Output
1474+
fn tensordot(&self, rhs: &ArrayBase<S2, D2>, axes: &AxisSpec) -> Self::Output
14081475
{
14091476
tensordot_impl::<A, S, S2, D1, D2>(self, rhs, axes)
14101477
}
@@ -1421,7 +1488,7 @@ where
14211488
type Output = ArrayD<A>;
14221489

14231490
#[track_caller]
1424-
fn tensordot(&self, rhs: &ArrayRef<A, D2>, axes: AxisSpec) -> Self::Output
1491+
fn tensordot(&self, rhs: &ArrayRef<A, D2>, axes: &AxisSpec) -> Self::Output
14251492
{
14261493
let rhs_view: ArrayBase<ViewRepr<&A>, D2> = rhs.view();
14271494
tensordot_impl::<A, S, ViewRepr<&A>, D1, D2>(self, &rhs_view, axes)
@@ -1439,7 +1506,7 @@ where
14391506
type Output = ArrayD<A>;
14401507

14411508
#[track_caller]
1442-
fn tensordot(&self, rhs: &ArrayBase<S, D2>, axes: AxisSpec) -> Self::Output
1509+
fn tensordot(&self, rhs: &ArrayBase<S, D2>, axes: &AxisSpec) -> Self::Output
14431510
{
14441511
let self_view: ArrayBase<ViewRepr<&A>, D1> = self.view();
14451512
tensordot_impl::<A, ViewRepr<&A>, S, D1, D2>(&self_view, rhs, axes)
@@ -1459,7 +1526,7 @@ mod tensordot_tests
14591526
let a = ArrayD::from_shape_vec(IxDyn(&[3, 4, 5]), (0..60).collect::<Vec<_>>()).unwrap();
14601527
let b = ArrayD::from_shape_vec(IxDyn(&[4, 3, 2]), (0..24).collect::<Vec<_>>()).unwrap();
14611528

1462-
let c: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Pair(vec![1, 0], vec![0, 1]));
1529+
let c: ArrayD<i32> = tensordot(&a, &b, &AxisSpec::Pair(vec![1, 0], vec![0, 1]));
14631530

14641531
// Expected shape: [5, 2]
14651532
assert_eq!(
@@ -1494,7 +1561,7 @@ mod tensordot_tests
14941561
let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![10, 20, 30, 40]).unwrap();
14951562

14961563
// Contract over 2 axes
1497-
let c: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Num(2));
1564+
let c: ArrayD<i32> = tensordot(&a, &b, &AxisSpec::Num(2));
14981565

14991566
assert_eq!(
15001567
c.shape(),

0 commit comments

Comments
 (0)