Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions core/src/main/scala/dimwit/autodiff/FloatTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,34 @@ object FloatTree:
def map[NewV](f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, V] => Tensor[T, NewV])): P =
tt.map(p, [T <: Tuple, V0] => (n: Labels[T]) ?=> (t: Tensor[T, V0]) => f[T](using n)(t.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]])

/** Maps a function over the TensorTree along with the structural path,
* providing knowledge that tensors are of type V
*/
def mapWithName[NewV](f: [T <: Tuple] => Labels[T] ?=> ((String, Tensor[T, V]) => Tensor[T, NewV]), path: String = ""): P =
tt.mapWithName(
p,
[T <: Tuple, V0] => (n: Labels[T]) ?=> (pth: String, t: Tensor[T, V0]) => f[T](using n)(pth, t.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]],
path
)

/** Foreach over the TensorTree, providing knowledge that tensors are of type V
*/
def foreach(f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, V] => Unit)): Unit =
tt.foreach(
p,
[T <: Tuple, V0] => (n: Labels[T]) ?=> (t: Tensor[T, V0]) => f[T](using n)(t.asInstanceOf[Tensor[T, V]])
)

/** Foreach over the TensorTree along with the structural path,
* providing knowledge that tensors are of type V
*/
def foreachWithName(f: [T <: Tuple] => Labels[T] ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit =
tt.foreachWithName(
p,
[T <: Tuple, V0] => (n: Labels[T]) ?=> (pth: String, t: Tensor[T, V0]) => f[T](using n)(pth, t.asInstanceOf[Tensor[T, V]]),
path
)

/** Zipmaps a function over the TensorTree, as for tensor tree,
* but provides knowledge that tensors are of type V
*/
Expand All @@ -57,6 +85,9 @@ object FloatTree:
[T <: Tuple, V0] => (n: Labels[T]) ?=> (t1: Tensor[T, V0], t2: Tensor[T, V0]) => f[T](using n)(t1.asInstanceOf[Tensor[T, V]], t2.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]]
)

def mapLeaves[A](f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, V] => A)): Iterator[A] =
tt.mapLeaves(p, [T <: Tuple, V0] => (n: Labels[T]) ?=> (t: Tensor[T, V0]) => f[T](using n)(t.asInstanceOf[Tensor[T, V]]))

/** Arithmetic and math operations for tensor trees of floating-point types.
*/
object ops:
Expand Down
14 changes: 14 additions & 0 deletions core/src/main/scala/dimwit/autodiff/Grad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,23 @@ object Grad:
given [T](using ev: TensorTree[T]): TensorTree[Grad[T]] with
def map(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> Tensor[U, V] => Tensor[U, V]): Grad[T] =
Grad(ev.map(g, f))

def mapWithName(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> ((String, Tensor[U, V]) => Tensor[U, V]), path: String = ""): Grad[T] =
Grad(ev.mapWithName(g, f, path))

def foreach(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> (Tensor[U, V] => Unit)): Unit =
ev.foreach(g, f)

def foreachWithName(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> ((String, Tensor[U, V]) => Unit), path: String = ""): Unit =
ev.foreachWithName(g, f, path)

def zipMap(g1: Grad[T], g2: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> (Tensor[U, V], Tensor[U, V]) => Tensor[U, V]): Grad[T] =
Grad(ev.zipMap(g1, g2, f))

def mapLeaves[A](p: Grad[T], f: [T <: Tuple, V] => (x: Labels[T]) ?=> (t: Tensor[T, V]) => A): Iterator[A] = ev.mapLeaves(p, f)

def toPyTree(g: Grad[T]): Jax.PyAny = ev.toPyTree(g)

def fromPyTree(pyVal: Jax.PyAny): Grad[T] = Grad(ev.fromPyTree(pyVal))

// FloatTree witness for gradient math (++, --, scale, etc.)
Expand Down
152 changes: 129 additions & 23 deletions core/src/main/scala/dimwit/autodiff/TensorTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ trait TensorTree[P]:
*/
def map(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): P

/** Similar to `map`, but also provides the string path (e.g., "layer1.weights") to the tensor.
*/
def mapWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): P

def mapLeaves[A](p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A]

/** A polymorphic foreach over the tensor tree.
*/
def foreach(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit

/** Similar to `foreach`, but also provides the string path (e.g., "layer1.weights") to the tensor.
*/
def foreachWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit

/** A polymorphic zipMap over two tensor trees of the same structure.
*
* @param p1 The first structure containing the tensors
Expand All @@ -53,9 +67,27 @@ object TensorTree: // extends TensorTreeLowPriority:
def map(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (Tensor[T, V2] => Tensor[T, V2])): Tensor[Q, V] =
import TensorOps.retag
f[Q, V](using n)(t.retag[Q](using n))

def mapWithName(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> ((String, Tensor[T, V2]) => Tensor[T, V2]), path: String = ""): Tensor[Q, V] =
import TensorOps.retag
f[Q, V](using n)(path, t.retag[Q](using n))

def mapLeaves[A](t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (Tensor[T, V2] => A)): Iterator[A] =
import TensorOps.retag
Iterator(f[Q, V](using n)(t.retag[Q](using n)))

def foreach(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (Tensor[T, V2] => Unit)): Unit =
import TensorOps.retag
f[Q, V](using n)(t.retag[Q](using n))

def foreachWithName(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> ((String, Tensor[T, V2]) => Unit), path: String = ""): Unit =
import TensorOps.retag
f[Q, V](using n)(path, t.retag[Q](using n))

def zipMap(p1: Tensor[Q, V], p2: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> ((Tensor[T, V2], Tensor[T, V2]) => Tensor[T, V2])): Tensor[Q, V] =
import TensorOps.retag
f[Q, V](using n)(p1.retag[Q](using n), p2.retag[Q](using n))

def toPyTree(p: Tensor[Q, V]): Jax.PyAny = p.jaxValue
def fromPyTree(pyVal: Jax.PyAny): Tensor[Q, V] = Tensor(pyVal.as[Jax.PyDynamic])

Expand All @@ -64,50 +96,93 @@ object TensorTree: // extends TensorTreeLowPriority:
*/
given TensorTree[Unit] with
def map(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): Unit = ()
def mapWithName(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): Unit = ()
def mapLeaves[A](p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] = Iterator.empty
def foreach(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit = ()
def foreachWithName(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = ()
def zipMap(p1: Unit, p2: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): Unit = ()
def toPyTree(p: Unit): Jax.PyAny = py.Dynamic.global.None
def fromPyTree(pyVal: Jax.PyAny): Unit = ()

/** Instance for a tuple of two tensors
*/
given tupleInstance[A, B](using ta: TensorTree[A], tb: TensorTree[B]): TensorTree[(A, B)] with
def map(p: (A, B), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): (A, B) =
(ta.map(p._1, f), tb.map(p._2, f))
def zipMap(p1: (A, B), p2: (A, B), f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): (A, B) =
(ta.zipMap(p1._1, p2._1, f), tb.zipMap(p1._2, p2._2, f))
def toPyTree(p: (A, B)): Jax.PyAny =
py.Dynamic.global.tuple(Seq(ta.toPyTree(p._1), tb.toPyTree(p._2)).toPythonProxy)
def fromPyTree(pyVal: Jax.PyAny): (A, B) =
/** Instance for a tuple of two tensors */
given tupleInstance[P1, P2](using t1: TensorTree[P1], t2: TensorTree[P2]): TensorTree[(P1, P2)] with
def map(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): (P1, P2) =
(t1.map(p._1, f), t2.map(p._2, f))

def mapWithName(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): (P1, P2) =
val p1Path = if path.isEmpty then "_1" else s"$path._1"
val p2Path = if path.isEmpty then "_2" else s"$path._2"
(t1.mapWithName(p._1, f, p1Path), t2.mapWithName(p._2, f, p2Path))

def mapLeaves[A](p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] =
t1.mapLeaves(p._1, f) ++ t2.mapLeaves(p._2, f)

def foreach(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit =
t1.foreach(p._1, f)
t2.foreach(p._2, f)

def foreachWithName(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit =
val p1Path = if path.isEmpty then "_1" else s"$path._1"
val p2Path = if path.isEmpty then "_2" else s"$path._2"
t1.foreachWithName(p._1, f, p1Path)
t2.foreachWithName(p._2, f, p2Path)

def zipMap(p1: (P1, P2), p2: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): (P1, P2) =
(t1.zipMap(p1._1, p2._1, f), t2.zipMap(p1._2, p2._2, f))

def toPyTree(p: (P1, P2)): Jax.PyAny =
py.Dynamic.global.tuple(Seq(t1.toPyTree(p._1), t2.toPyTree(p._2)).toPythonProxy)

def fromPyTree(pyVal: Jax.PyAny): (P1, P2) =
val pyTuple = pyVal.as[py.Dynamic]
(ta.fromPyTree(pyTuple.bracketAccess(0)), tb.fromPyTree(pyTuple.bracketAccess(1)))
(t1.fromPyTree(pyTuple.bracketAccess(0)), t2.fromPyTree(pyTuple.bracketAccess(1)))

/** Instance for a list of tensor trees
*/
given listInstance[A](using ta: TensorTree[A]): TensorTree[List[A]] with
def map(l: List[A], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): List[A] =
l.map(elem => ta.map(elem, f))
def zipMap(l1: List[A], l2: List[A], f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): List[A] =
l1.zip(l2).map { case (e1, e2) => ta.zipMap(e1, e2, f) }
def toPyTree(l: List[A]): Jax.PyAny =
val pyItems = l.map(a => ta.toPyTree(a))
given listInstance[P](using tp: TensorTree[P]): TensorTree[List[P]] with
def map(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): List[P] =
l.map(elem => tp.map(elem, f))

def mapWithName(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): List[P] =
l.zipWithIndex.map: (elem, i) =>
val nextPath = if path.isEmpty then s"[$i]" else s"$path[$i]"
tp.mapWithName(elem, f, nextPath)

def mapLeaves[A](l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] =
l.iterator.flatMap(elem => tp.mapLeaves(elem, f))

def foreach(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit =
l.foreach(elem => tp.foreach(elem, f))

def foreachWithName(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit =
l.zipWithIndex.foreach: (elem, i) =>
val nextPath = if path.isEmpty then s"[$i]" else s"$path[$i]"
tp.foreachWithName(elem, f, nextPath)

def zipMap(l1: List[P], l2: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): List[P] =
l1.zip(l2).map { case (e1, e2) => tp.zipMap(e1, e2, f) }

def toPyTree(l: List[P]): Jax.PyAny =
val pyItems = l.map(a => tp.toPyTree(a))
py.Dynamic.global.list(pyItems.toPythonProxy)
def fromPyTree(pyVal: Jax.PyAny): List[A] =

def fromPyTree(pyVal: Jax.PyAny): List[P] =
val pyList = pyVal.as[py.Dynamic]
val len = py.Dynamic.global.len(pyList).as[Int]
List.tabulate(len)(i => ta.fromPyTree(pyList.bracketAccess(i)))
List.tabulate(len)(i => tp.fromPyTree(pyList.bracketAccess(i)))

/** automatically derive a TensorTree instance for any case class (or product type)
* whose fields all have TensorTree instances.
* The derived instance will map over each field using the
* corresponding field's TensorTree instance, and preserve the overall structure of the case class.
*/
inline given derived[P <: Product](using m: Mirror.ProductOf[P]): TensorTree[P] =
val elemInstances = summonAll[Tuple.Map[m.MirroredElemTypes, TensorTree]]
val instances = elemInstances.toList.asInstanceOf[List[TensorTree[Any]]]
derivedImpl(instances, m)
val fieldNames = constValueTuple[m.MirroredElemLabels].toList.map(_.toString)
derivedImpl(instances, fieldNames, m)

private def derivedImpl[P <: Product](
instances: List[TensorTree[Any]],
fieldNames: List[String],
m: Mirror.ProductOf[P]
): TensorTree[P] = new TensorTree[P]:
def map(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): P =
Expand All @@ -118,6 +193,37 @@ object TensorTree: // extends TensorTreeLowPriority:
case (elem, inst) => inst.map(elem, f)
m.fromProduct(Tuple.fromArray(mappedElems.map(_.asInstanceOf[Object]).toArray))

def mapWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): P =
val inputs = p.productIterator.toList
val mappedElems = inputs
.zip(instances)
.zip(fieldNames)
.map:
case ((elem, inst), fieldName) =>
val nextPath = if path.isEmpty then fieldName else s"$path.$fieldName"
inst.mapWithName(elem, f, nextPath)
m.fromProduct(Tuple.fromArray(mappedElems.map(_.asInstanceOf[Object]).toArray))

def mapLeaves[A](p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] =
val inputs = p.productIterator
inputs.zip(instances.iterator).flatMap:
case (elem, inst) => inst.mapLeaves(elem, f)

def foreach(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit =
val inputs = p.productIterator
inputs.zip(instances.iterator).foreach:
case (elem, inst) => inst.foreach(elem, f)

def foreachWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit =
val inputs = p.productIterator.toList
inputs
.zip(instances)
.zip(fieldNames)
.foreach:
case ((elem, inst), fieldName) =>
val nextPath = if path.isEmpty then fieldName else s"$path.$fieldName"
inst.foreachWithName(elem, f, nextPath)

def zipMap(p1: P, p2: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P =
val inputs1 = p1.productIterator.toList
val inputs2 = p2.productIterator.toList
Expand Down
12 changes: 12 additions & 0 deletions core/src/main/scala/dimwit/random/Random.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,24 @@ object Random:
* so this instance allows them to be used
* seamlessly with autodiff and JIT compilation.
*/

given TensorTree[Key] with
// map is really a noop as it is not a tensor
def map(p: Key, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): Key = p

def mapWithName(p: Key, f: [T <: Tuple, V] => Labels[T] ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): Key = p

def foreach(p: Key, f: [T <: Tuple, V] => Labels[T] ?=> (Tensor[T, V] => Unit)): Unit = ()

def foreachWithName(p: Key, f: [T <: Tuple, V] => Labels[T] ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = ()

// zipmap is also a noop, just return the first key
def zipMap(p1: Key, p2: Key, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): Key = p1

def mapLeaves[A](p: Key, f: [T <: Tuple, V] => (Labels[T]) ?=> Tensor[T, V] => A): Iterator[A] = Iterator.empty

def toPyTree(p: Key): Jax.PyAny = p.jaxKey

def fromPyTree(pyVal: Jax.PyAny): Key = Key(pyVal.as[Jax.PyDynamic])

/** Create a random key from an integer seed */
Expand Down
13 changes: 13 additions & 0 deletions core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,16 @@ class FloatTensorTreeSuite extends DimwitTest:
res.b.shape shouldBe params.b.shape
res.w.approxElementEquals(Tensor.like(res.w).fill(99f)).all.item shouldBe true
res.b.approxElementEquals(Tensor.like(res.b).fill(99f)).all.item shouldBe true

describe("mapLeaves"):

it("Calculate norm over tree structure"):
trait Norm derives Label
case class Params(val w: Tensor1[A, Float32], val b: Tensor0[Float32])
val params1 = Params(Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), Tensor0(0))
val leaveNorms = stack(
params1.mapLeaves([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float32]) => x.norm).toSeq,
newAxis = Axis[Norm]
)
val norm = leaveNorms.norm
norm should approxEqual((params1.w.pow(2).sum + params1.b.pow(2).sum).sqrt)
Loading
Loading