diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index 5d09abe..6322ede 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -988,6 +988,20 @@ object TensorOps: val indices = ev.indices Tensor(Jax.jnp.transpose(tensor.jaxValue, indices.toPythonProxy)) + /** Splits the tensor along the specified axis at the given indices, returning a sequence of tensors corresponding to the splits. + * + * @param unstackAxis the axis to split, specified as an Axis (e.g. Axis[Ax1]) + * @return a sequence of tensors resulting from the split, each with the specified axis removed + */ + def unstack[L: Label, R <: Tuple](unstackAxis: Axis[L])(using + labels: Labels[T], + ev: AxisRemover[T, L, R], + labelR: Labels[R] + ): Seq[Tensor[R, V]] = + val axisIdx = ev.index + val unstacked = Jax.jnp.split(tensor.jaxValue, tensor.shape.dimensions(axisIdx), axis = axisIdx).as[Seq[Jax.PyDynamic]] + unstacked.map(x => Tensor[R, V](x)) + def chunk[splitL: Label](splitAxis: Axis[splitL], chunkSize: Int)(using labels: Labels[T], axisIndex: AxisIndex[T, splitL] diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala index fbf0c4c..8bd7dd5 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala @@ -440,3 +440,111 @@ class TensorOpsStructureSuite extends DimwitTest: val t = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))) t.vapply(Axis[B])(_.roll(shift = 1)) shouldEqual (Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(3.0f, 1.0f, 2.0f), Array(6.0f, 4.0f, 5.0f)))) t.vapply(Axis[A])(_.roll(shift = 1)) shouldEqual (Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(4.0f, 5.0f, 6.0f), Array(1.0f, 2.0f, 3.0f)))) + + describe("chunk function"): + + it("chunk axis into equal parts"): + val t = Tensor2(Axis[A], Axis[B]).fromArray( + Array( + Array(1.0f, 2.0f, 3.0f, 4.0f), + Array(5.0f, 6.0f, 7.0f, 8.0f) + ) + ) + + val chunks = t.chunk(Axis[B], chunkSize = 2) + chunks should have size 2 + chunks(0).axes shouldBe List("A", "B") + chunks(0).shape(Axis[B]) shouldBe 2 + chunks(0) should approxEqual(Tensor2(Axis[A], Axis[B]).fromArray( + Array(Array(1.0f, 2.0f), Array(5.0f, 6.0f)) + )) + + chunks(1).shape(Axis[B]) shouldBe 2 + chunks(1) should approxEqual(Tensor2(Axis[A], Axis[B]).fromArray( + Array(Array(3.0f, 4.0f), Array(7.0f, 8.0f)) + )) + + describe("unstack function"): + + it("unstack the first axis"): + val t = Tensor2(Axis[A], Axis[B]).fromArray( + Array( + Array(1.0f, 2.0f), + Array(3.0f, 4.0f), + Array(5.0f, 6.0f) + ) + ) + + val unstacked = t.unstack(Axis[A]) + + // Since A has size 3, unstacking should yield 3 tensors of 1D shape B + unstacked should have size 3 + unstacked.foreach(_.axes shouldBe List("B")) + + unstacked(0) should approxEqual(Tensor1(Axis[B]).fromArray(Array(1.0f, 2.0f))) + unstacked(1) should approxEqual(Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f))) + unstacked(2) should approxEqual(Tensor1(Axis[B]).fromArray(Array(5.0f, 6.0f))) + + it("unstack ∘ stack is identity"): + val t1 = Tensor1(Axis[B]).fromArray(Array(1.0f, 2.0f)) + val t2 = Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f)) + val stacked = stack(Seq(t1, t2), Axis[A]) + val unstacked = stacked.unstack(Axis[A]) + + unstacked(0) should approxEqual(t1) + unstacked(1) should approxEqual(t2) + + describe("stack function"): + + it("stack sequence of tensors into a new first axis"): + val t1 = Tensor1(Axis[B]).fromArray(Array(1.0f, 2.0f)) + val t2 = Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f)) + + val stacked = stack(Seq(t1, t2), Axis[A]) + + stacked.axes shouldBe List("A", "B") + stacked.shape(Axis[A]) shouldBe 2 + stacked.shape(Axis[B]) shouldBe 2 + stacked should approxEqual(Tensor2(Axis[A], Axis[B]).fromArray( + Array( + Array(1.0f, 2.0f), + Array(3.0f, 4.0f) + ) + )) + + it("stack sequence of tensors after an existing axis"): + // t1 and t2 have shapes: A=2, B=2 + val t1 = Tensor2(Axis[A], Axis[B]).fromArray( + Array( + Array(1.0f, 2.0f), + Array(3.0f, 4.0f) + ) + ) + val t2 = Tensor2(Axis[A], Axis[B]).fromArray( + Array( + Array(5.0f, 6.0f), + Array(7.0f, 8.0f) + ) + ) + + // Stacking along new axis C, strictly after A. Expected axes: A, C, B + val stacked = stack(Seq(t1, t2), newAxis = Axis[C], afterAxis = Axis[A]) + + stacked.axes shouldBe List("A", "C", "B") + stacked.shape(Axis[A]) shouldBe 2 + stacked.shape(Axis[C]) shouldBe 2 + stacked.shape(Axis[B]) shouldBe 2 + stacked should approxEqual(Tensor3(Axis[A], Axis[C], Axis[B]).fromArray( + Array( + // A = 0 + Array( + Array(1.0f, 2.0f), // C = 0 (from t1) + Array(5.0f, 6.0f) // C = 1 (from t2) + ), + // A = 1 + Array( + Array(3.0f, 4.0f), // C = 0 (from t1) + Array(7.0f, 8.0f) // C = 1 (from t2) + ) + ) + ))