@@ -1252,14 +1252,14 @@ pub trait Tensordot<Rhs: ?Sized>
12521252 /// - The computed product of reshaped dimensions does not equal the
12531253 /// array’s total element count (which would indicate internal logic error).
12541254 #[ track_caller]
1255- fn tensordot ( & self , rhs : & Rhs , axes : AxisSpec ) -> Self :: Output ;
1255+ fn tensordot ( & self , rhs : & Rhs , axes : & AxisSpec ) -> Self :: Output ;
12561256}
12571257
12581258/// Perform a tensor contraction along specified axes.
12591259///
12601260/// See [`Tensordot::tensordot`] for more details.
12611261#[ track_caller]
1262- pub fn tensordot < T , Sa , Sb , Da , Db > ( a : & ArrayBase < Sa , Da > , b : & ArrayBase < Sb , Db > , axes : AxisSpec ) -> ArrayD < T >
1262+ pub fn tensordot < T , Sa , Sb , Da , Db > ( a : & ArrayBase < Sa , Da > , b : & ArrayBase < Sb , Db > , axes : & AxisSpec ) -> ArrayD < T >
12631263where
12641264 T : LinalgScalar ,
12651265 Sa : Data < Elem = T > ,
@@ -1272,7 +1272,11 @@ where
12721272
12731273/// Performs the full contraction given resolved axis specification.
12741274#[ track_caller]
1275- fn tensordot_impl < T , Sa , Sb , Da , Db > ( a : & ArrayBase < Sa , Da > , b : & ArrayBase < Sb , Db > , axes : AxisSpec ) -> ArrayD < T >
1275+ fn tensordot_impl < T , Sa , Sb , Da , Db > (
1276+ a : & ArrayBase < Sa , Da > ,
1277+ b : & ArrayBase < Sb , Db > ,
1278+ axes : & AxisSpec ,
1279+ ) -> ArrayD < T >
12761280where
12771281 T : LinalgScalar ,
12781282 Sa : Data < Elem = T > ,
@@ -1283,18 +1287,32 @@ where
12831287 let nda = a. ndim ( ) as isize ;
12841288 let ndb = b. ndim ( ) as isize ;
12851289
1286- // Resolve axes
1287- let ( mut axes_a, mut axes_b) : ( Vec < isize > , Vec < isize > ) = match axes {
1288- AxisSpec :: Num ( n) => {
1289- let n = n as isize ;
1290+ // Precompute shapes for reuse
1291+ let ashape = a. shape ( ) ;
1292+ let bshape = b. shape ( ) ;
1293+
1294+ // Resolve and normalise contracted axes (into owned Vecs, no cloning of input)
1295+ let ( axes_a, axes_b) : ( Vec < isize > , Vec < isize > ) = match axes {
1296+ AxisSpec :: Num ( n_raw) => {
1297+ let n = * n_raw as isize ;
12901298 assert ! (
12911299 n <= nda && n <= ndb,
12921300 "tensordot: cannot contract over {} axes; a.ndim()={}, b.ndim()={}" ,
12931301 n,
12941302 nda,
12951303 ndb
12961304 ) ;
1297- ( ( nda - n) ..nda) . zip ( 0 ..n) . map ( |( ia, ib) | ( ia, ib) ) . unzip ( )
1305+
1306+ let mut axes_a = Vec :: with_capacity ( n as usize ) ;
1307+ let mut axes_b = Vec :: with_capacity ( n as usize ) ;
1308+
1309+ // last n axes of a, first n axes of b
1310+ for i in 0 ..n {
1311+ axes_a. push ( nda - n + i) ;
1312+ axes_b. push ( i) ;
1313+ }
1314+
1315+ ( axes_a, axes_b)
12981316 }
12991317 AxisSpec :: Pair ( aa, bb) => {
13001318 assert_eq ! (
@@ -1304,23 +1322,27 @@ where
13041322 aa. len( ) ,
13051323 bb. len( )
13061324 ) ;
1307- ( aa, bb)
1308- }
1309- } ;
13101325
1311- // Normalise negative indices
1312- for ax in & mut axes_a {
1313- if * ax < 0 {
1314- * ax += nda;
1315- }
1316- }
1317- for ax in & mut axes_b {
1318- if * ax < 0 {
1319- * ax += ndb;
1326+ let mut axes_a = Vec :: with_capacity ( aa. len ( ) ) ;
1327+ let mut axes_b = Vec :: with_capacity ( bb. len ( ) ) ;
1328+
1329+ // Normalise negatives for a
1330+ for & ax in aa {
1331+ let ax_norm = if ax < 0 { ax + nda } else { ax } ;
1332+ axes_a. push ( ax_norm) ;
1333+ }
1334+
1335+ // Normalise negatives for b
1336+ for & ax in bb {
1337+ let ax_norm = if ax < 0 { ax + ndb } else { ax } ;
1338+ axes_b. push ( ax_norm) ;
1339+ }
1340+
1341+ ( axes_a, axes_b)
13201342 }
1321- }
1343+ } ;
13221344
1323- // Validate
1345+ // Validate bounds
13241346 for & ax in & axes_a {
13251347 assert ! (
13261348 ( 0 ..nda) . contains( & ax) ,
@@ -1338,36 +1360,83 @@ where
13381360 ) ;
13391361 }
13401362
1341- // Shape checks
1363+ // Shape checks on contracted axes
13421364 for ( ia, ib) in axes_a. iter ( ) . zip ( & axes_b) {
1343- let da = a . shape ( ) [ * ia as usize ] ;
1344- let db = b . shape ( ) [ * ib as usize ] ;
1365+ let da = ashape [ * ia as usize ] ;
1366+ let db = bshape [ * ib as usize ] ;
13451367 assert_eq ! (
13461368 da, db,
13471369 "tensordot: shape mismatch along contraction axis: a[{}]={} vs b[{}]={}" ,
13481370 ia, da, ib, db
13491371 ) ;
13501372 }
13511373
1352- // Determine non-contracted axes
1353- let notin_a: Vec < usize > = ( 0 ..nda as usize )
1354- . filter ( |k| !axes_a. iter ( ) . any ( |& ax| ax as usize == * k) )
1355- . collect ( ) ;
1356- let notin_b: Vec < usize > = ( 0 ..ndb as usize )
1357- . filter ( |k| !axes_b. iter ( ) . any ( |& ax| ax as usize == * k) )
1358- . collect ( ) ;
1359-
1360- // Reorder axes
1361- let mut newaxes_a = notin_a. clone ( ) ;
1362- newaxes_a. extend ( axes_a. iter ( ) . map ( |& x| x as usize ) ) ;
1363- let mut newaxes_b = axes_b. iter ( ) . map ( |& x| x as usize ) . collect :: < Vec < _ > > ( ) ;
1364- newaxes_b. extend ( notin_b. iter ( ) . copied ( ) ) ;
1365-
1366- // Matrix shapes
1367- let m = notin_a. iter ( ) . fold ( 1 , |p, & ax| p * a. shape ( ) [ ax] ) ;
1368- let k = axes_a. iter ( ) . fold ( 1 , |p, & ax| p * a. shape ( ) [ ax as usize ] ) ;
1369- let n = notin_b. iter ( ) . fold ( 1 , |p, & ax| p * b. shape ( ) [ ax] ) ;
1374+ // Membership maps for contracted axes (O(ndim) setup, O(1) lookup)
1375+ let mut is_contracted_a = vec ! [ false ; nda as usize ] ;
1376+ let mut is_contracted_b = vec ! [ false ; ndb as usize ] ;
13701377
1378+ for & ax in & axes_a {
1379+ is_contracted_a[ ax as usize ] = true ;
1380+ }
1381+ for & ax in & axes_b {
1382+ is_contracted_b[ ax as usize ] = true ;
1383+ }
1384+
1385+ let contracted_a = axes_a. len ( ) ;
1386+ let contracted_b = axes_b. len ( ) ;
1387+ debug_assert_eq ! ( contracted_a, contracted_b) ;
1388+ let free_a = nda as usize - contracted_a;
1389+ let free_b = ndb as usize - contracted_b;
1390+
1391+ // Permutation axes for a: [non-contracted..., contracted...]
1392+ let mut newaxes_a = Vec :: with_capacity ( nda as usize ) ;
1393+ for i in 0 ..nda as usize {
1394+ if !is_contracted_a[ i] {
1395+ newaxes_a. push ( i) ;
1396+ }
1397+ }
1398+ for & ax in & axes_a {
1399+ newaxes_a. push ( ax as usize ) ;
1400+ }
1401+
1402+ // non-contracted axes for b (indices)
1403+ let mut notin_b = Vec :: with_capacity ( free_b) ;
1404+ for i in 0 ..ndb as usize {
1405+ if !is_contracted_b[ i] {
1406+ notin_b. push ( i) ;
1407+ }
1408+ }
1409+
1410+ // Permutation axes for b: [contracted..., non-contracted...]
1411+ let mut newaxes_b = Vec :: with_capacity ( ndb as usize ) ;
1412+ for & ax in & axes_b {
1413+ newaxes_b. push ( ax as usize ) ;
1414+ }
1415+ newaxes_b. extend ( & notin_b) ;
1416+
1417+ // Output shape: a(non-contracted) ⧺ b(non-contracted)
1418+ let mut out_shape = Vec :: with_capacity ( free_a + free_b) ;
1419+ for i in 0 ..nda as usize {
1420+ if !is_contracted_a[ i] {
1421+ out_shape. push ( ashape[ i] ) ;
1422+ }
1423+ }
1424+ for & ax in & notin_b {
1425+ out_shape. push ( bshape[ ax] ) ;
1426+ }
1427+
1428+ // Matrix dims
1429+ let m = newaxes_a[ ..free_a]
1430+ . iter ( )
1431+ . fold ( 1 , |p, & ax| p * ashape[ ax] ) ;
1432+ let k = axes_a
1433+ . iter ( )
1434+ . fold ( 1 , |p, & ax| p * ashape[ ax as usize ] ) ;
1435+ let n = notin_b
1436+ . iter ( )
1437+ . fold ( 1 , |p, & ax| p * bshape[ ax] ) ;
1438+
1439+ // Permute + standard layout (keep temporaries named to satisfy lifetimes)
13711440 let a_dyn = a. view ( ) . into_dimensionality :: < IxDyn > ( ) . unwrap ( ) ;
13721441 let b_dyn = b. view ( ) . into_dimensionality :: < IxDyn > ( ) . unwrap ( ) ;
13731442
@@ -1377,6 +1446,7 @@ where
13771446 let b_perm = b_dyn. permuted_axes ( IxDyn ( & newaxes_b) ) ;
13781447 let b_std = b_perm. as_standard_layout ( ) ;
13791448
1449+ // Reshape to 2D, multiply, and reshape back
13801450 let a2 = a_std
13811451 . into_shape_with_order ( Ix2 ( m, k) )
13821452 . expect ( "reshaping a to 2D" ) ;
@@ -1386,9 +1456,6 @@ where
13861456
13871457 let c2 = a2. dot ( & b2) ;
13881458
1389- let mut out_shape: Vec < usize > = notin_a. iter ( ) . map ( |& ax| a. shape ( ) [ ax] ) . collect ( ) ;
1390- out_shape. extend ( notin_b. iter ( ) . map ( |& ax| b. shape ( ) [ ax] ) ) ;
1391-
13921459 c2. into_shape_with_order ( IxDyn ( & out_shape) ) . unwrap ( )
13931460}
13941461
@@ -1404,7 +1471,7 @@ where
14041471 type Output = ArrayD < A > ;
14051472
14061473 #[ track_caller]
1407- fn tensordot ( & self , rhs : & ArrayBase < S2 , D2 > , axes : AxisSpec ) -> Self :: Output
1474+ fn tensordot ( & self , rhs : & ArrayBase < S2 , D2 > , axes : & AxisSpec ) -> Self :: Output
14081475 {
14091476 tensordot_impl :: < A , S , S2 , D1 , D2 > ( self , rhs, axes)
14101477 }
@@ -1421,7 +1488,7 @@ where
14211488 type Output = ArrayD < A > ;
14221489
14231490 #[ track_caller]
1424- fn tensordot ( & self , rhs : & ArrayRef < A , D2 > , axes : AxisSpec ) -> Self :: Output
1491+ fn tensordot ( & self , rhs : & ArrayRef < A , D2 > , axes : & AxisSpec ) -> Self :: Output
14251492 {
14261493 let rhs_view: ArrayBase < ViewRepr < & A > , D2 > = rhs. view ( ) ;
14271494 tensordot_impl :: < A , S , ViewRepr < & A > , D1 , D2 > ( self , & rhs_view, axes)
@@ -1439,7 +1506,7 @@ where
14391506 type Output = ArrayD < A > ;
14401507
14411508 #[ track_caller]
1442- fn tensordot ( & self , rhs : & ArrayBase < S , D2 > , axes : AxisSpec ) -> Self :: Output
1509+ fn tensordot ( & self , rhs : & ArrayBase < S , D2 > , axes : & AxisSpec ) -> Self :: Output
14431510 {
14441511 let self_view: ArrayBase < ViewRepr < & A > , D1 > = self . view ( ) ;
14451512 tensordot_impl :: < A , ViewRepr < & A > , S , D1 , D2 > ( & self_view, rhs, axes)
@@ -1459,7 +1526,7 @@ mod tensordot_tests
14591526 let a = ArrayD :: from_shape_vec ( IxDyn ( & [ 3 , 4 , 5 ] ) , ( 0 ..60 ) . collect :: < Vec < _ > > ( ) ) . unwrap ( ) ;
14601527 let b = ArrayD :: from_shape_vec ( IxDyn ( & [ 4 , 3 , 2 ] ) , ( 0 ..24 ) . collect :: < Vec < _ > > ( ) ) . unwrap ( ) ;
14611528
1462- let c: ArrayD < i32 > = tensordot ( & a, & b, AxisSpec :: Pair ( vec ! [ 1 , 0 ] , vec ! [ 0 , 1 ] ) ) ;
1529+ let c: ArrayD < i32 > = tensordot ( & a, & b, & AxisSpec :: Pair ( vec ! [ 1 , 0 ] , vec ! [ 0 , 1 ] ) ) ;
14631530
14641531 // Expected shape: [5, 2]
14651532 assert_eq ! (
@@ -1494,7 +1561,7 @@ mod tensordot_tests
14941561 let b = ArrayD :: from_shape_vec ( IxDyn ( & [ 2 , 2 ] ) , vec ! [ 10 , 20 , 30 , 40 ] ) . unwrap ( ) ;
14951562
14961563 // Contract over 2 axes
1497- let c: ArrayD < i32 > = tensordot ( & a, & b, AxisSpec :: Num ( 2 ) ) ;
1564+ let c: ArrayD < i32 > = tensordot ( & a, & b, & AxisSpec :: Num ( 2 ) ) ;
14981565
14991566 assert_eq ! (
15001567 c. shape( ) ,
0 commit comments