Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions core/src/main/scala/dimwit/tensor/ArrayReader.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package dimwit.tensor

import java.nio.ByteBuffer
import java.nio.ByteOrder
import me.shadaj.scalapy.py

/** Utility object for reading flat arrays of scalar values from JAX tensors.
*/
object ArrayReader:

private def readBytes(jaxValue: py.Dynamic): Array[Byte] =
jaxValue.tobytes().as[Seq[Byte]].toArray

def readBooleanArray(jaxValue: py.Dynamic): Array[Boolean] =
val bytes = readBytes(jaxValue)
Array.tabulate(bytes.length)(i => bytes(i) != 0)

def readByteArray(jaxValue: py.Dynamic): Array[Byte] =
readBytes(jaxValue)

def readShortArray(jaxValue: py.Dynamic): Array[Short] =
val bytes = readBytes(jaxValue)
val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
Array.tabulate(buf.remaining())(buf.get)

def readIntArray(jaxValue: py.Dynamic): Array[Int] =
val bytes = readBytes(jaxValue)
val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer()
Array.tabulate(buf.remaining())(buf.get)

def readLongArray(jaxValue: py.Dynamic): Array[Long] =
val bytes = readBytes(jaxValue)
val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer()
Array.tabulate(buf.remaining())(buf.get)

def readFloatArray(jaxValue: py.Dynamic): Array[Float] =
val bytes = readBytes(jaxValue)
val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer()
Array.tabulate(buf.remaining())(buf.get)

def readDoubleArray(jaxValue: py.Dynamic): Array[Double] =
val bytes = readBytes(jaxValue)
val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer()
Array.tabulate(buf.remaining())(buf.get)
30 changes: 30 additions & 0 deletions core/src/main/scala/dimwit/tensor/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package dimwit.tensor
import dimwit.jax.JaxDType
import java.nio.ByteBuffer
import java.nio.ByteOrder
import me.shadaj.scalapy.py
import dimwit.tensor.TensorOps.{IsFloating, IsInteger, IsBoolean}
import dimwit.tensor.HasScalar

object DType:

Expand Down Expand Up @@ -54,6 +56,34 @@ object DType:
given boolIsBoolean: IsBoolean[Bool] with
def dtype: DType = DType.Bool

given HasScalar[Bool, Boolean] with
def readFlat(jv: py.Dynamic): Array[Boolean] = ArrayReader.readBooleanArray(jv)
def classTag: scala.reflect.ClassTag[Boolean] = scala.reflect.ClassTag.Boolean

given HasScalar[Int8, Byte] with
def readFlat(jv: py.Dynamic): Array[Byte] = ArrayReader.readByteArray(jv)
def classTag: scala.reflect.ClassTag[Byte] = scala.reflect.ClassTag.Byte

given HasScalar[Int16, Short] with
def readFlat(jv: py.Dynamic): Array[Short] = ArrayReader.readShortArray(jv)
def classTag: scala.reflect.ClassTag[Short] = scala.reflect.ClassTag.Short

given HasScalar[Int32, Int] with
def readFlat(jv: py.Dynamic): Array[Int] = ArrayReader.readIntArray(jv)
def classTag: scala.reflect.ClassTag[Int] = scala.reflect.ClassTag.Int

given HasScalar[Int64, Long] with
def readFlat(jv: py.Dynamic): Array[Long] = ArrayReader.readLongArray(jv)
def classTag: scala.reflect.ClassTag[Long] = scala.reflect.ClassTag.Long

given HasScalar[Float32, Float] with
def readFlat(jv: py.Dynamic): Array[Float] = ArrayReader.readFloatArray(jv)
def classTag: scala.reflect.ClassTag[Float] = scala.reflect.ClassTag.Float

given HasScalar[Float64, Double] with
def readFlat(jv: py.Dynamic): Array[Double] = ArrayReader.readDoubleArray(jv)
def classTag: scala.reflect.ClassTag[Double] = scala.reflect.ClassTag.Double

enum DType(val name: String, val size: Int):
case BFloat16 extends DType("bfloat16", 2)
case Float16 extends DType("float16", 2)
Expand Down
28 changes: 28 additions & 0 deletions core/src/main/scala/dimwit/tensor/HasScalar.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dimwit.tensor

import scala.annotation.implicitNotFound
import me.shadaj.scalapy.py

@implicitNotFound(
"No Scala type mapping for DType ${V}. Supported: Bool, Int8, Int16, Int32, Int64, Float32, Float64."
)

/** Type class to map a DType to a Scala type for
* reading scalar values from JAX tensors.
*
* @tparam V DType (e.g., Float32, Int32)
* @tparam X Corresponding Scala type (e.g., Float, Int)
*/
trait HasScalar[V, X]:

/** read a flat array of scalar values from a JAX tensor value.
*/
def readFlat(jaxValue: py.Dynamic): Array[X]

/** ClassTag for the Scala type, used for array creation
* and pattern matching.
*/
def classTag: scala.reflect.ClassTag[X]

object HasScalar:
def apply[V, X](using ev: HasScalar[V, X]): HasScalar[V, X] = ev
37 changes: 37 additions & 0 deletions core/src/main/scala/dimwit/tensor/TensorOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import dimwit.tensor.ShapeTypeHelpers.AxesMerger
import dimwit.OnError
import dimwit.DType.*
import dimwit.DType.given
import dimwit.tensor.HasScalar

import Tuple.:*
import Tuple.++
Expand Down Expand Up @@ -1542,6 +1543,16 @@ object TensorOps:
// TODO understand why toPythonCopy is needed and toPythonProxy fails!
Tensor(Jax.lax.dynamic_slice(t.jaxValue, Seq(dynamicStart.jaxValue).toPythonCopy, Seq(staticSize).toPythonCopy))

extension [L, V, X](t: Tensor1[L, V])(using ev: HasScalar[V, X])
/** Converts a Tensor1 to a Scala Array.
* The user must ensure that the tensor is not a JAX Tracer
* (i.e., it is not part of a JAX computation graph) before calling this method,
* otherwise a runtime error will occur.
*/
def toArray: Array[X] =
require(!t.isTracer, "Cannot convert a JAX Tracer to an array.")
ev.readFlat(t.jaxValue)

object Tensor2Ops:

extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V])
Expand All @@ -1550,10 +1561,36 @@ object TensorOps:
def transpose: Tensor2[L2, L1, V] = t.transpose(Axis[L2], Axis[L1])
def transpose(axis2: Axis[L2], axis1: Axis[L1]): Tensor2[L2, L1, V] = TensorOps.Structural.transpose(t)(axis2, axis1)

extension [L1, L2, V, X](t: Tensor2[L1, L2, V])(using ev: HasScalar[V, X])
/** Converts a Tensor2 to a nested Scala Array (Array of Arrays).
* The user must ensure that the tensor is not a JAX Tracer
* (i.e., it is not part of a JAX computation graph) before calling this method,
* otherwise a runtime error will occur.
*/
def toArray: Array[Array[X]] =
require(!t.isTracer, "Cannot convert a JAX Tracer to an array.")
given scala.reflect.ClassTag[X] = ev.classTag
ev.readFlat(t.jaxValue).grouped(t.shape.dimensions(1)).toArray

object Tensor3Ops:

extension [L1, L2, L3, V, X](t: Tensor3[L1, L2, L3, V])(using ev: HasScalar[V, X])
/** Converts a Tensor3 to a nested Scala Array (Array of Arrays of Arrays).
* The user must ensure that the tensor is not a JAX Tracer
* (i.e., it is not part of a JAX computation graph) before calling this method,
* otherwise a runtime error will occur.
*/
def toArray: Array[Array[Array[X]]] =
require(!t.isTracer, "Cannot convert a JAX Tracer to an array.")
given scala.reflect.ClassTag[X] = ev.classTag
val d1 = t.shape.dimensions(1); val d2 = t.shape.dimensions(2)
ev.readFlat(t.jaxValue).grouped(d1 * d2).map(_.grouped(d2).toArray).toArray

export Tensor0Ops.*
export ValueOps.*
export Tensor1Ops.*
export Tensor2Ops.*
export Tensor3Ops.*

end TensorOps

Expand Down
45 changes: 45 additions & 0 deletions core/src/test/scala/dimwit/tensor/ToArraySuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package dimwit.tensor

import dimwit.*

class ToArraySuite extends DimwitTest:

describe("Tensor1.toArray"):
it("Float32 roundtrip"):
val t = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f, 3.0f))
t.toArray shouldBe Array(1.0f, 2.0f, 3.0f)

it("Int32 roundtrip"):
val t = Tensor1(Axis[A]).fromArray(Array(10, 20, 30))
t.toArray shouldBe Array(10, 20, 30)

it("Bool roundtrip"):
val t = Tensor1(Axis[A]).fromArray(Array(true, false, true))
t.toArray shouldBe Array(true, false, true)

describe("Tensor2.toArray"):
it("Float32 roundtrip"):
val data = Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f))
val t = Tensor2(Axis[A], Axis[B]).fromArray(data)
t.toArray shouldBe data

it("Int32 roundtrip"):
val data = Array(Array(1, 2, 3), Array(4, 5, 6))
val t = Tensor2(Axis[A], Axis[B]).fromArray(data)
t.toArray shouldBe data

describe("Tensor3.toArray"):
it("Float32 roundtrip"):
val data = Array(
Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f)),
Array(Array(5.0f, 6.0f), Array(7.0f, 8.0f))
)
val t = Tensor(Shape3(Axis[A] -> 2, Axis[B] -> 2, Axis[C] -> 2)).fromArray(
data.flatten.flatten
)
t.toArray shouldBe data

describe("toArray with filled tensors"):
it("fill value is reflected in array"):
val t = Tensor(Shape2(Axis[A] -> 3, Axis[B] -> 2)).fill(7.0f)
t.toArray shouldBe Array(Array(7.0f, 7.0f), Array(7.0f, 7.0f), Array(7.0f, 7.0f))
Loading