From b861fcb9f8df29d693a5e0c97a624fc8bba146f8 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Fri, 29 May 2026 21:39:06 +0200 Subject: [PATCH 1/2] add a toArray method --- .../scala/dimwit/tensor/ArrayReader.scala | 42 +++++++++++++++++ core/src/main/scala/dimwit/tensor/DType.scala | 30 +++++++++++++ .../main/scala/dimwit/tensor/HasScalar.scala | 14 ++++++ .../main/scala/dimwit/tensor/TensorOps.scala | 22 +++++++++ .../scala/dimwit/tensor/ToArraySuite.scala | 45 +++++++++++++++++++ 5 files changed, 153 insertions(+) create mode 100644 core/src/main/scala/dimwit/tensor/ArrayReader.scala create mode 100644 core/src/main/scala/dimwit/tensor/HasScalar.scala create mode 100644 core/src/test/scala/dimwit/tensor/ToArraySuite.scala diff --git a/core/src/main/scala/dimwit/tensor/ArrayReader.scala b/core/src/main/scala/dimwit/tensor/ArrayReader.scala new file mode 100644 index 00000000..3462f2df --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/ArrayReader.scala @@ -0,0 +1,42 @@ +package dimwit.tensor + +import java.nio.ByteBuffer +import java.nio.ByteOrder +import me.shadaj.scalapy.py + +object ArrayReader: + + private def readBytes(jaxValue: py.Dynamic): Array[Byte] = + jaxValue.tobytes().as[Seq[Byte]].toArray + + def readBooleanArray(jaxValue: py.Dynamic): Array[Boolean] = + val bytes = readBytes(jaxValue) + Array.tabulate(bytes.length)(i => bytes(i) != 0) + + def readByteArray(jaxValue: py.Dynamic): Array[Byte] = + readBytes(jaxValue) + + def readShortArray(jaxValue: py.Dynamic): Array[Short] = + val bytes = readBytes(jaxValue) + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer() + Array.tabulate(buf.remaining())(buf.get) + + def readIntArray(jaxValue: py.Dynamic): Array[Int] = + val bytes = readBytes(jaxValue) + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer() + Array.tabulate(buf.remaining())(buf.get) + + def readLongArray(jaxValue: py.Dynamic): Array[Long] = + val bytes = readBytes(jaxValue) + val buf = 
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer() + Array.tabulate(buf.remaining())(buf.get) + + def readFloatArray(jaxValue: py.Dynamic): Array[Float] = + val bytes = readBytes(jaxValue) + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer() + Array.tabulate(buf.remaining())(buf.get) + + def readDoubleArray(jaxValue: py.Dynamic): Array[Double] = + val bytes = readBytes(jaxValue) + val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer() + Array.tabulate(buf.remaining())(buf.get) diff --git a/core/src/main/scala/dimwit/tensor/DType.scala b/core/src/main/scala/dimwit/tensor/DType.scala index 42e5b86d..8365a456 100644 --- a/core/src/main/scala/dimwit/tensor/DType.scala +++ b/core/src/main/scala/dimwit/tensor/DType.scala @@ -2,7 +2,9 @@ package dimwit.tensor import dimwit.jax.JaxDType import java.nio.ByteBuffer import java.nio.ByteOrder +import me.shadaj.scalapy.py import dimwit.tensor.TensorOps.{IsFloating, IsInteger, IsBoolean} +import dimwit.tensor.HasScalar object DType: @@ -54,6 +56,34 @@ object DType: given boolIsBoolean: IsBoolean[Bool] with def dtype: DType = DType.Bool + given HasScalar[Bool, Boolean] with + def readFlat(jv: py.Dynamic): Array[Boolean] = ArrayReader.readBooleanArray(jv) + def classTag: scala.reflect.ClassTag[Boolean] = scala.reflect.ClassTag.Boolean + + given HasScalar[Int8, Byte] with + def readFlat(jv: py.Dynamic): Array[Byte] = ArrayReader.readByteArray(jv) + def classTag: scala.reflect.ClassTag[Byte] = scala.reflect.ClassTag.Byte + + given HasScalar[Int16, Short] with + def readFlat(jv: py.Dynamic): Array[Short] = ArrayReader.readShortArray(jv) + def classTag: scala.reflect.ClassTag[Short] = scala.reflect.ClassTag.Short + + given HasScalar[Int32, Int] with + def readFlat(jv: py.Dynamic): Array[Int] = ArrayReader.readIntArray(jv) + def classTag: scala.reflect.ClassTag[Int] = scala.reflect.ClassTag.Int + + given HasScalar[Int64, Long] with + def readFlat(jv: py.Dynamic): 
Array[Long] = ArrayReader.readLongArray(jv) + def classTag: scala.reflect.ClassTag[Long] = scala.reflect.ClassTag.Long + + given HasScalar[Float32, Float] with + def readFlat(jv: py.Dynamic): Array[Float] = ArrayReader.readFloatArray(jv) + def classTag: scala.reflect.ClassTag[Float] = scala.reflect.ClassTag.Float + + given HasScalar[Float64, Double] with + def readFlat(jv: py.Dynamic): Array[Double] = ArrayReader.readDoubleArray(jv) + def classTag: scala.reflect.ClassTag[Double] = scala.reflect.ClassTag.Double + enum DType(val name: String, val size: Int): case BFloat16 extends DType("bfloat16", 2) case Float16 extends DType("float16", 2) diff --git a/core/src/main/scala/dimwit/tensor/HasScalar.scala b/core/src/main/scala/dimwit/tensor/HasScalar.scala new file mode 100644 index 00000000..71a8897d --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/HasScalar.scala @@ -0,0 +1,14 @@ +package dimwit.tensor + +import scala.annotation.implicitNotFound +import me.shadaj.scalapy.py + +@implicitNotFound( + "No Scala type mapping for DType ${V}. Supported: Bool, Int8, Int16, Int32, Int64, Float32, Float64." +) +trait HasScalar[V, X]: + def readFlat(jaxValue: py.Dynamic): Array[X] + def classTag: scala.reflect.ClassTag[X] + +object HasScalar: + def apply[V, X](using ev: HasScalar[V, X]): HasScalar[V, X] = ev diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index c3854459..c87d106c 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -27,6 +27,7 @@ import dimwit.tensor.ShapeTypeHelpers.AxesMerger import dimwit.OnError import dimwit.DType.* import dimwit.DType.given +import dimwit.tensor.HasScalar import Tuple.:* import Tuple.++ @@ -1542,6 +1543,11 @@ object TensorOps: // TODO understand why toPythonCopy is needed and toPythonProxy fails! 
Tensor(Jax.lax.dynamic_slice(t.jaxValue, Seq(dynamicStart.jaxValue).toPythonCopy, Seq(staticSize).toPythonCopy)) + extension [L, V, X](t: Tensor1[L, V])(using ev: HasScalar[V, X]) + def toArray: Array[X] = + require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") + ev.readFlat(t.jaxValue) + object Tensor2Ops: extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) @@ -1550,10 +1556,26 @@ object TensorOps: def transpose: Tensor2[L2, L1, V] = t.transpose(Axis[L2], Axis[L1]) def transpose(axis2: Axis[L2], axis1: Axis[L1]): Tensor2[L2, L1, V] = TensorOps.Structural.transpose(t)(axis2, axis1) + extension [L1, L2, V, X](t: Tensor2[L1, L2, V])(using ev: HasScalar[V, X]) + def toArray: Array[Array[X]] = + require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") + given scala.reflect.ClassTag[X] = ev.classTag + ev.readFlat(t.jaxValue).grouped(t.shape.dimensions(1)).toArray + + object Tensor3Ops: + + extension [L1, L2, L3, V, X](t: Tensor3[L1, L2, L3, V])(using ev: HasScalar[V, X]) + def toArray: Array[Array[Array[X]]] = + require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") + given scala.reflect.ClassTag[X] = ev.classTag + val d1 = t.shape.dimensions(1); val d2 = t.shape.dimensions(2) + ev.readFlat(t.jaxValue).grouped(d1 * d2).map(_.grouped(d2).toArray).toArray + export Tensor0Ops.* export ValueOps.* export Tensor1Ops.* export Tensor2Ops.* + export Tensor3Ops.* end TensorOps diff --git a/core/src/test/scala/dimwit/tensor/ToArraySuite.scala b/core/src/test/scala/dimwit/tensor/ToArraySuite.scala new file mode 100644 index 00000000..0b223e46 --- /dev/null +++ b/core/src/test/scala/dimwit/tensor/ToArraySuite.scala @@ -0,0 +1,45 @@ +package dimwit.tensor + +import dimwit.* + +class ToArraySuite extends DimwitTest: + + describe("Tensor1.toArray"): + it("Float32 roundtrip"): + val t = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f, 3.0f)) + t.toArray shouldBe Array(1.0f, 2.0f, 3.0f) + + it("Int32 roundtrip"): + val t = 
Tensor1(Axis[A]).fromArray(Array(10, 20, 30)) + t.toArray shouldBe Array(10, 20, 30) + + it("Bool roundtrip"): + val t = Tensor1(Axis[A]).fromArray(Array(true, false, true)) + t.toArray shouldBe Array(true, false, true) + + describe("Tensor2.toArray"): + it("Float32 roundtrip"): + val data = Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f)) + val t = Tensor2(Axis[A], Axis[B]).fromArray(data) + t.toArray shouldBe data + + it("Int32 roundtrip"): + val data = Array(Array(1, 2, 3), Array(4, 5, 6)) + val t = Tensor2(Axis[A], Axis[B]).fromArray(data) + t.toArray shouldBe data + + describe("Tensor3.toArray"): + it("Float32 roundtrip"): + val data = Array( + Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f)), + Array(Array(5.0f, 6.0f), Array(7.0f, 8.0f)) + ) + val t = Tensor(Shape3(Axis[A] -> 2, Axis[B] -> 2, Axis[C] -> 2)).fromArray( + data.flatten.flatten + ) + t.toArray shouldBe data + + describe("toArray with filled tensors"): + it("fill value is reflected in array"): + val t = Tensor(Shape2(Axis[A] -> 3, Axis[B] -> 2)).fill(7.0f) + t.toArray shouldBe Array(Array(7.0f, 7.0f), Array(7.0f, 7.0f), Array(7.0f, 7.0f)) From ddef73f64019706fbebb1ac19cf2a6d0b4079c74 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Sat, 30 May 2026 12:07:23 +0200 Subject: [PATCH 2/2] add some documentation --- .../main/scala/dimwit/tensor/ArrayReader.scala | 2 ++ core/src/main/scala/dimwit/tensor/HasScalar.scala | 14 ++++++++++++++ core/src/main/scala/dimwit/tensor/TensorOps.scala | 15 +++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/core/src/main/scala/dimwit/tensor/ArrayReader.scala b/core/src/main/scala/dimwit/tensor/ArrayReader.scala index 3462f2df..42b0b2b2 100644 --- a/core/src/main/scala/dimwit/tensor/ArrayReader.scala +++ b/core/src/main/scala/dimwit/tensor/ArrayReader.scala @@ -4,6 +4,8 @@ import java.nio.ByteBuffer import java.nio.ByteOrder import me.shadaj.scalapy.py +/** Utility object for reading flat arrays of scalar values from JAX tensors. 
+ */ object ArrayReader: private def readBytes(jaxValue: py.Dynamic): Array[Byte] = diff --git a/core/src/main/scala/dimwit/tensor/HasScalar.scala b/core/src/main/scala/dimwit/tensor/HasScalar.scala index 71a8897d..727f4f2f 100644 --- a/core/src/main/scala/dimwit/tensor/HasScalar.scala +++ b/core/src/main/scala/dimwit/tensor/HasScalar.scala @@ -6,8 +6,22 @@ import me.shadaj.scalapy.py @implicitNotFound( "No Scala type mapping for DType ${V}. Supported: Bool, Int8, Int16, Int32, Int64, Float32, Float64." ) + +/** Type class to map a DType to a Scala type for + * reading scalar values from JAX tensors. + * + * @tparam V DType (e.g., Float32, Int32) + * @tparam X Corresponding Scala type (e.g., Float, Int) + */ trait HasScalar[V, X]: + + /** Reads a flat array of scalar values from a JAX tensor value. + */ def readFlat(jaxValue: py.Dynamic): Array[X] + + /** ClassTag for the Scala type, used for array creation + * and pattern matching. + */ def classTag: scala.reflect.ClassTag[X] object HasScalar: diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index c87d106c..28ae0a89 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -1544,6 +1544,11 @@ object TensorOps: Tensor(Jax.lax.dynamic_slice(t.jaxValue, Seq(dynamicStart.jaxValue).toPythonCopy, Seq(staticSize).toPythonCopy)) extension [L, V, X](t: Tensor1[L, V])(using ev: HasScalar[V, X]) + /** Converts a Tensor1 to a Scala Array. + * The user must ensure that the tensor is not a JAX Tracer + * (i.e., it is not part of a JAX computation graph) before calling this method, + * otherwise an IllegalArgumentException is thrown.
+ */ def toArray: Array[X] = require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") ev.readFlat(t.jaxValue) @@ -1557,6 +1562,11 @@ object TensorOps: def transpose(axis2: Axis[L2], axis1: Axis[L1]): Tensor2[L2, L1, V] = TensorOps.Structural.transpose(t)(axis2, axis1) extension [L1, L2, V, X](t: Tensor2[L1, L2, V])(using ev: HasScalar[V, X]) + /** Converts a Tensor2 to a nested Scala Array (Array of Arrays). + * The user must ensure that the tensor is not a JAX Tracer + * (i.e., it is not part of a JAX computation graph) before calling this method, + * otherwise an IllegalArgumentException is thrown. + */ def toArray: Array[Array[X]] = require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") given scala.reflect.ClassTag[X] = ev.classTag @@ -1565,6 +1575,11 @@ object TensorOps: object Tensor3Ops: extension [L1, L2, L3, V, X](t: Tensor3[L1, L2, L3, V])(using ev: HasScalar[V, X]) + /** Converts a Tensor3 to a nested Scala Array (Array of Arrays of Arrays). + * The user must ensure that the tensor is not a JAX Tracer + * (i.e., it is not part of a JAX computation graph) before calling this method, + * otherwise an IllegalArgumentException is thrown. + */ def toArray: Array[Array[Array[X]]] = require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") given scala.reflect.ClassTag[X] = ev.classTag