Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions core/src/main/scala/dimwit/tensor/TensorOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,20 @@ object TensorOps:
val indices = ev.indices
Tensor(Jax.jnp.transpose(tensor.jaxValue, indices.toPythonProxy))

/** Splits the tensor along the specified axis at the given indices, returning a sequence of tensors corresponding to the splits.
*
* @param unstackAxis the axis to split, specified as an Axis (e.g. Axis[Ax1])
* @return a sequence of tensors resulting from the split, each with the specified axis removed
*/
def unstack[L: Label, R <: Tuple](unstackAxis: Axis[L])(using
labels: Labels[T],
ev: AxisRemover[T, L, R],
labelR: Labels[R]
): Seq[Tensor[R, V]] =
val axisIdx = ev.index
val unstacked = Jax.jnp.split(tensor.jaxValue, tensor.shape.dimensions(axisIdx), axis = axisIdx).as[Seq[Jax.PyDynamic]]
unstacked.map(x => Tensor[R, V](x))

def chunk[splitL: Label](splitAxis: Axis[splitL], chunkSize: Int)(using
labels: Labels[T],
axisIndex: AxisIndex[T, splitL]
Expand Down
108 changes: 108 additions & 0 deletions core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,111 @@ class TensorOpsStructureSuite extends DimwitTest:
val t = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
t.vapply(Axis[B])(_.roll(shift = 1)) shouldEqual (Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(3.0f, 1.0f, 2.0f), Array(6.0f, 4.0f, 5.0f))))
t.vapply(Axis[A])(_.roll(shift = 1)) shouldEqual (Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(4.0f, 5.0f, 6.0f), Array(1.0f, 2.0f, 3.0f))))

describe("chunk function"):

it("chunk axis into equal parts"):
val t = Tensor2(Axis[A], Axis[B]).fromArray(
Array(
Array(1.0f, 2.0f, 3.0f, 4.0f),
Array(5.0f, 6.0f, 7.0f, 8.0f)
)
)

val chunks = t.chunk(Axis[B], chunkSize = 2)
chunks should have size 2
chunks(0).axes shouldBe List("A", "B")
chunks(0).shape(Axis[B]) shouldBe 2
chunks(0) should approxEqual(Tensor2(Axis[A], Axis[B]).fromArray(
Array(Array(1.0f, 2.0f), Array(5.0f, 6.0f))
))

chunks(1).shape(Axis[B]) shouldBe 2
chunks(1) should approxEqual(Tensor2(Axis[A], Axis[B]).fromArray(
Array(Array(3.0f, 4.0f), Array(7.0f, 8.0f))
))

describe("unstack function"):

it("unstack the first axis"):
val t = Tensor2(Axis[A], Axis[B]).fromArray(
Array(
Array(1.0f, 2.0f),
Array(3.0f, 4.0f),
Array(5.0f, 6.0f)
)
)

val unstacked = t.unstack(Axis[A])

// Since A has size 3, unstacking should yield 3 tensors of 1D shape B
unstacked should have size 3
unstacked.foreach(_.axes shouldBe List("B"))

unstacked(0) should approxEqual(Tensor1(Axis[B]).fromArray(Array(1.0f, 2.0f)))
unstacked(1) should approxEqual(Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f)))
unstacked(2) should approxEqual(Tensor1(Axis[B]).fromArray(Array(5.0f, 6.0f)))

it("unstack ∘ stack is identity"):
val t1 = Tensor1(Axis[B]).fromArray(Array(1.0f, 2.0f))
val t2 = Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f))
val stacked = stack(Seq(t1, t2), Axis[A])
val unstacked = stacked.unstack(Axis[A])

unstacked(0) should approxEqual(t1)
unstacked(1) should approxEqual(t2)

describe("stack function"):

it("stack sequence of tensors into a new first axis"):
val t1 = Tensor1(Axis[B]).fromArray(Array(1.0f, 2.0f))
val t2 = Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f))

val stacked = stack(Seq(t1, t2), Axis[A])

stacked.axes shouldBe List("A", "B")
stacked.shape(Axis[A]) shouldBe 2
stacked.shape(Axis[B]) shouldBe 2
stacked should approxEqual(Tensor2(Axis[A], Axis[B]).fromArray(
Array(
Array(1.0f, 2.0f),
Array(3.0f, 4.0f)
)
))

it("stack sequence of tensors after an existing axis"):
// t1 and t2 have shapes: A=2, B=2
val t1 = Tensor2(Axis[A], Axis[B]).fromArray(
Array(
Array(1.0f, 2.0f),
Array(3.0f, 4.0f)
)
)
val t2 = Tensor2(Axis[A], Axis[B]).fromArray(
Array(
Array(5.0f, 6.0f),
Array(7.0f, 8.0f)
)
)

// Stacking along new axis C, strictly after A. Expected axes: A, C, B
val stacked = stack(Seq(t1, t2), newAxis = Axis[C], afterAxis = Axis[A])

stacked.axes shouldBe List("A", "C", "B")
stacked.shape(Axis[A]) shouldBe 2
stacked.shape(Axis[C]) shouldBe 2
stacked.shape(Axis[B]) shouldBe 2
stacked should approxEqual(Tensor3(Axis[A], Axis[C], Axis[B]).fromArray(
Array(
// A = 0
Array(
Array(1.0f, 2.0f), // C = 0 (from t1)
Array(5.0f, 6.0f) // C = 1 (from t2)
),
// A = 1
Array(
Array(3.0f, 4.0f), // C = 0 (from t1)
Array(7.0f, 8.0f) // C = 1 (from t2)
)
)
))
Loading