diff --git a/core/src/main/scala/dimwit/autodiff/FloatTree.scala b/core/src/main/scala/dimwit/autodiff/FloatTree.scala index f1d0a75..32128c1 100644 --- a/core/src/main/scala/dimwit/autodiff/FloatTree.scala +++ b/core/src/main/scala/dimwit/autodiff/FloatTree.scala @@ -47,6 +47,34 @@ object FloatTree: def map[NewV](f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, V] => Tensor[T, NewV])): P = tt.map(p, [T <: Tuple, V0] => (n: Labels[T]) ?=> (t: Tensor[T, V0]) => f[T](using n)(t.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]]) + /** Maps a function over the TensorTree along with the structural path, + * providing knowledge that tensors are of type V + */ + def mapWithName[NewV](f: [T <: Tuple] => Labels[T] ?=> ((String, Tensor[T, V]) => Tensor[T, NewV]), path: String = ""): P = + tt.mapWithName( + p, + [T <: Tuple, V0] => (n: Labels[T]) ?=> (pth: String, t: Tensor[T, V0]) => f[T](using n)(pth, t.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]], + path + ) + + /** Foreach over the TensorTree, providing knowledge that tensors are of type V + */ + def foreach(f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, V] => Unit)): Unit = + tt.foreach( + p, + [T <: Tuple, V0] => (n: Labels[T]) ?=> (t: Tensor[T, V0]) => f[T](using n)(t.asInstanceOf[Tensor[T, V]]) + ) + + /** Foreach over the TensorTree along with the structural path, + * providing knowledge that tensors are of type V + */ + def foreachWithName(f: [T <: Tuple] => Labels[T] ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = + tt.foreachWithName( + p, + [T <: Tuple, V0] => (n: Labels[T]) ?=> (pth: String, t: Tensor[T, V0]) => f[T](using n)(pth, t.asInstanceOf[Tensor[T, V]]), + path + ) + /** Zipmaps a function over the TensorTree, as for tensor tree, * but provides knowledge that tensors are of type V */ @@ -57,6 +85,9 @@ object FloatTree: [T <: Tuple, V0] => (n: Labels[T]) ?=> (t1: Tensor[T, V0], t2: Tensor[T, V0]) => f[T](using n)(t1.asInstanceOf[Tensor[T, V]], t2.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]] ) + def mapLeaves[A](f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, V] => A)): Iterator[A] = + tt.mapLeaves(p, [T <: Tuple, V0] => (n: Labels[T]) ?=> (t: Tensor[T, V0]) => f[T](using n)(t.asInstanceOf[Tensor[T, V]])) + /** Arithmetic and math operations for tensor trees of floating-point types. */ object ops: diff --git a/core/src/main/scala/dimwit/autodiff/Grad.scala b/core/src/main/scala/dimwit/autodiff/Grad.scala index 472e3d5..aec2759 100644 --- a/core/src/main/scala/dimwit/autodiff/Grad.scala +++ b/core/src/main/scala/dimwit/autodiff/Grad.scala @@ -31,9 +31,23 @@ object Grad: given [T](using ev: TensorTree[T]): TensorTree[Grad[T]] with def map(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> Tensor[U, V] => Tensor[U, V]): Grad[T] = Grad(ev.map(g, f)) + + def mapWithName(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> ((String, Tensor[U, V]) => Tensor[U, V]), path: String = ""): Grad[T] = + Grad(ev.mapWithName(g, f, path)) + + def foreach(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> (Tensor[U, V] => Unit)): Unit = + ev.foreach(g, f) + + def foreachWithName(g: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> ((String, Tensor[U, V]) => Unit), path: String = ""): Unit = + ev.foreachWithName(g, f, path) + def zipMap(g1: Grad[T], g2: Grad[T], f: [U <: Tuple, V] => Labels[U] ?=> (Tensor[U, V], Tensor[U, V]) => Tensor[U, V]): Grad[T] = Grad(ev.zipMap(g1, g2, f)) + + def mapLeaves[A](p: Grad[T], f: [T <: Tuple, V] => (x: Labels[T]) ?=> (t: Tensor[T, V]) => A): Iterator[A] = ev.mapLeaves(p, f) + def toPyTree(g: Grad[T]): Jax.PyAny = ev.toPyTree(g) + def fromPyTree(pyVal: Jax.PyAny): Grad[T] = Grad(ev.fromPyTree(pyVal)) // FloatTree witness for gradient math (++, --, scale, etc.) diff --git a/core/src/main/scala/dimwit/autodiff/TensorTree.scala b/core/src/main/scala/dimwit/autodiff/TensorTree.scala index 84804ac..3e403ee 100644 --- a/core/src/main/scala/dimwit/autodiff/TensorTree.scala +++ b/core/src/main/scala/dimwit/autodiff/TensorTree.scala @@ -27,6 +27,20 @@ trait TensorTree[P]: */ def map(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): P + /** Similar to `map`, but also provides the string path (e.g., "layer1.weights") to the tensor. + */ + def mapWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): P + + def mapLeaves[A](p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] + + /** A polymorphic foreach over the tensor tree. + */ + def foreach(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit + + /** Similar to `foreach`, but also provides the string path (e.g., "layer1.weights") to the tensor. + */ + def foreachWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit + /** A polymorphic zipMap over two tensor trees of the same structure. * * @param p1 The first structure containing the tensors @@ -53,9 +67,27 @@ object TensorTree: // extends TensorTreeLowPriority: def map(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (Tensor[T, V2] => Tensor[T, V2])): Tensor[Q, V] = import TensorOps.retag f[Q, V](using n)(t.retag[Q](using n)) + + def mapWithName(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> ((String, Tensor[T, V2]) => Tensor[T, V2]), path: String = ""): Tensor[Q, V] = + import TensorOps.retag + f[Q, V](using n)(path, t.retag[Q](using n)) + + def mapLeaves[A](t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (Tensor[T, V2] => A)): Iterator[A] = + import TensorOps.retag + Iterator(f[Q, V](using n)(t.retag[Q](using n))) + + def foreach(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (Tensor[T, V2] => Unit)): Unit = + import TensorOps.retag + f[Q, V](using n)(t.retag[Q](using n)) + + def foreachWithName(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> ((String, Tensor[T, V2]) => Unit), path: String = ""): Unit = + import TensorOps.retag + f[Q, V](using n)(path, t.retag[Q](using n)) + def zipMap(p1: Tensor[Q, V], p2: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> ((Tensor[T, V2], Tensor[T, V2]) => Tensor[T, V2])): Tensor[Q, V] = import TensorOps.retag f[Q, V](using n)(p1.retag[Q](using n), p2.retag[Q](using n)) + def toPyTree(p: Tensor[Q, V]): Jax.PyAny = p.jaxValue def fromPyTree(pyVal: Jax.PyAny): Tensor[Q, V] = Tensor(pyVal.as[Jax.PyDynamic]) @@ -64,50 +96,93 @@ object TensorTree: // extends TensorTreeLowPriority: */ given TensorTree[Unit] with def map(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): Unit = () + def mapWithName(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): Unit = () + def mapLeaves[A](p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] = Iterator.empty + def foreach(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit = () + def foreachWithName(p: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = () def zipMap(p1: Unit, p2: Unit, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): Unit = () def toPyTree(p: Unit): Jax.PyAny = py.Dynamic.global.None def fromPyTree(pyVal: Jax.PyAny): Unit = () - /** Instance for a tuple of two tensors - */ - given tupleInstance[A, B](using ta: TensorTree[A], tb: TensorTree[B]): TensorTree[(A, B)] with - def map(p: (A, B), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): (A, B) = - (ta.map(p._1, f), tb.map(p._2, f)) - def zipMap(p1: (A, B), p2: (A, B), f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): (A, B) = - (ta.zipMap(p1._1, p2._1, f), tb.zipMap(p1._2, p2._2, f)) - def toPyTree(p: (A, B)): Jax.PyAny = - py.Dynamic.global.tuple(Seq(ta.toPyTree(p._1), tb.toPyTree(p._2)).toPythonProxy) - def fromPyTree(pyVal: Jax.PyAny): (A, B) = + /** Instance for a tuple of two tensors */ + given tupleInstance[P1, P2](using t1: TensorTree[P1], t2: TensorTree[P2]): TensorTree[(P1, P2)] with + def map(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): (P1, P2) = + (t1.map(p._1, f), t2.map(p._2, f)) + + def mapWithName(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): (P1, P2) = + val p1Path = if path.isEmpty then "_1" else s"$path._1" + val p2Path = if path.isEmpty then "_2" else s"$path._2" + (t1.mapWithName(p._1, f, p1Path), t2.mapWithName(p._2, f, p2Path)) + + def mapLeaves[A](p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] = + t1.mapLeaves(p._1, f) ++ t2.mapLeaves(p._2, f) + + def foreach(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit = + t1.foreach(p._1, f) + t2.foreach(p._2, f) + + def foreachWithName(p: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = + val p1Path = if path.isEmpty then "_1" else s"$path._1" + val p2Path = if path.isEmpty then "_2" else s"$path._2" + t1.foreachWithName(p._1, f, p1Path) + t2.foreachWithName(p._2, f, p2Path) + + def zipMap(p1: (P1, P2), p2: (P1, P2), f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): (P1, P2) = + (t1.zipMap(p1._1, p2._1, f), t2.zipMap(p1._2, p2._2, f)) + + def toPyTree(p: (P1, P2)): Jax.PyAny = + py.Dynamic.global.tuple(Seq(t1.toPyTree(p._1), t2.toPyTree(p._2)).toPythonProxy) + + def fromPyTree(pyVal: Jax.PyAny): (P1, P2) = val pyTuple = pyVal.as[py.Dynamic] - (ta.fromPyTree(pyTuple.bracketAccess(0)), tb.fromPyTree(pyTuple.bracketAccess(1))) + (t1.fromPyTree(pyTuple.bracketAccess(0)), t2.fromPyTree(pyTuple.bracketAccess(1))) /** Instance for a list of tensor trees */ - given listInstance[A](using ta: TensorTree[A]): TensorTree[List[A]] with - def map(l: List[A], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): List[A] = - l.map(elem => ta.map(elem, f)) - def zipMap(l1: List[A], l2: List[A], f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): List[A] = - l1.zip(l2).map { case (e1, e2) => ta.zipMap(e1, e2, f) } - def toPyTree(l: List[A]): Jax.PyAny = - val pyItems = l.map(a => ta.toPyTree(a)) + given listInstance[P](using tp: TensorTree[P]): TensorTree[List[P]] with + def map(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): List[P] = + l.map(elem => tp.map(elem, f)) + + def mapWithName(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): List[P] = + l.zipWithIndex.map: (elem, i) => + val nextPath = if path.isEmpty then s"[$i]" else s"$path[$i]" + tp.mapWithName(elem, f, nextPath) + + def mapLeaves[A](l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] = + l.iterator.flatMap(elem => tp.mapLeaves(elem, f)) + + def foreach(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit = + l.foreach(elem => tp.foreach(elem, f)) + + def foreachWithName(l: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = + l.zipWithIndex.foreach: (elem, i) => + val nextPath = if path.isEmpty then s"[$i]" else s"$path[$i]" + tp.foreachWithName(elem, f, nextPath) + + def zipMap(l1: List[P], l2: List[P], f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): List[P] = + l1.zip(l2).map { case (e1, e2) => tp.zipMap(e1, e2, f) } + + def toPyTree(l: List[P]): Jax.PyAny = + val pyItems = l.map(a => tp.toPyTree(a)) py.Dynamic.global.list(pyItems.toPythonProxy) - def fromPyTree(pyVal: Jax.PyAny): List[A] = + + def fromPyTree(pyVal: Jax.PyAny): List[P] = val pyList = pyVal.as[py.Dynamic] val len = py.Dynamic.global.len(pyList).as[Int] - List.tabulate(len)(i => ta.fromPyTree(pyList.bracketAccess(i))) + List.tabulate(len)(i => tp.fromPyTree(pyList.bracketAccess(i))) /** automatically derive a TensorTree instance for any case class (or product type) * whose fields all have TensorTree instances. - * The derived instance will map over each field using the - * corresponding field's TensorTree instance, and preserve the overall structure of the case class. */ inline given derived[P <: Product](using m: Mirror.ProductOf[P]): TensorTree[P] = val elemInstances = summonAll[Tuple.Map[m.MirroredElemTypes, TensorTree]] val instances = elemInstances.toList.asInstanceOf[List[TensorTree[Any]]] - derivedImpl(instances, m) + val fieldNames = constValueTuple[m.MirroredElemLabels].toList.map(_.toString) + derivedImpl(instances, fieldNames, m) private def derivedImpl[P <: Product]( instances: List[TensorTree[Any]], + fieldNames: List[String], m: Mirror.ProductOf[P] ): TensorTree[P] = new TensorTree[P]: def map(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): P = @@ -118,6 +193,37 @@ object TensorTree: // extends TensorTreeLowPriority: case (elem, inst) => inst.map(elem, f) m.fromProduct(Tuple.fromArray(mappedElems.map(_.asInstanceOf[Object]).toArray)) + def mapWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): P = + val inputs = p.productIterator.toList + val mappedElems = inputs + .zip(instances) + .zip(fieldNames) + .map: + case ((elem, inst), fieldName) => + val nextPath = if path.isEmpty then fieldName else s"$path.$fieldName" + inst.mapWithName(elem, f, nextPath) + m.fromProduct(Tuple.fromArray(mappedElems.map(_.asInstanceOf[Object]).toArray)) + + def mapLeaves[A](p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A)): Iterator[A] = + val inputs = p.productIterator + inputs.zip(instances.iterator).flatMap: + case (elem, inst) => inst.mapLeaves(elem, f) + + def foreach(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Unit)): Unit = + val inputs = p.productIterator + inputs.zip(instances.iterator).foreach: + case (elem, inst) => inst.foreach(elem, f) + + def foreachWithName(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = + val inputs = p.productIterator.toList + inputs + .zip(instances) + .zip(fieldNames) + .foreach: + case ((elem, inst), fieldName) => + val nextPath = if path.isEmpty then fieldName else s"$path.$fieldName" + inst.foreachWithName(elem, f, nextPath) + def zipMap(p1: P, p2: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P = val inputs1 = p1.productIterator.toList val inputs2 = p2.productIterator.toList diff --git a/core/src/main/scala/dimwit/random/Random.scala b/core/src/main/scala/dimwit/random/Random.scala index ab586ce..492c22d 100644 --- a/core/src/main/scala/dimwit/random/Random.scala +++ b/core/src/main/scala/dimwit/random/Random.scala @@ -95,12 +95,24 @@ object Random: * so this instance allows them to be used * seamlessly with autodiff and JIT compilation. */ + given TensorTree[Key] with // map is really a noop as it is not a tensor def map(p: Key, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): Key = p + + def mapWithName(p: Key, f: [T <: Tuple, V] => Labels[T] ?=> ((String, Tensor[T, V]) => Tensor[T, V]), path: String = ""): Key = p + + def foreach(p: Key, f: [T <: Tuple, V] => Labels[T] ?=> (Tensor[T, V] => Unit)): Unit = () + + def foreachWithName(p: Key, f: [T <: Tuple, V] => Labels[T] ?=> ((String, Tensor[T, V]) => Unit), path: String = ""): Unit = () + // zipmap is also a noop, just return the first key def zipMap(p1: Key, p2: Key, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): Key = p1 + + def mapLeaves[A](p: Key, f: [T <: Tuple, V] => (Labels[T]) ?=> Tensor[T, V] => A): Iterator[A] = Iterator.empty + def toPyTree(p: Key): Jax.PyAny = p.jaxKey + def fromPyTree(pyVal: Jax.PyAny): Key = Key(pyVal.as[Jax.PyDynamic]) /** Create a random key from an integer seed */ diff --git a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala index e7b4341..ad4fd8f 100644 --- a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala @@ -230,3 +230,16 @@ class FloatTensorTreeSuite extends DimwitTest: res.b.shape shouldBe params.b.shape res.w.approxElementEquals(Tensor.like(res.w).fill(99f)).all.item shouldBe true res.b.approxElementEquals(Tensor.like(res.b).fill(99f)).all.item shouldBe true + + describe("mapLeaves"): + + it("Calculate norm over tree structure"): + trait Norm derives Label + case class Params(val w: Tensor1[A, Float32], val b: Tensor0[Float32]) + val params1 = Params(Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), Tensor0(0)) + val leaveNorms = stack( + params1.mapLeaves([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float32]) => x.norm).toSeq, + newAxis = Axis[Norm] + ) + val norm = leaveNorms.norm + norm should approxEqual((params1.w.pow(2).sum + params1.b.pow(2).sum).sqrt) diff --git a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala index 134f850..7e8587c 100644 --- a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala @@ -40,3 +40,145 @@ class TensorTreeSuite extends DimwitTest: val res = ftTree.zipMap(params1, params2, [T <: Tuple, V] => (labels: Labels[T]) ?=> (x1: Tensor[T, V], x2: Tensor[T, V]) => maximum(x1, x2)) res.w1 should approxEqual(maximum(params1.w1, params2.w1)) res.b1 should equal(maximum(params1.b1, params2.b1)) + + describe("mapLeaves"): + + it("1-level case class"): + case class Data( + val numbers: Tensor1[A, Float32], + val counts: Tensor1[A, Int32], + val flags: Tensor1[A, Bool] + ) + val params = Data( + Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), + Tensor1(Axis[A]).fromArray(Array(1, 2, 3)), + Tensor1(Axis[A]).fromArray(Array(true, false, true)) + ) + val tree = TensorTree[Data] + val leavesCount = tree.mapLeaves(params, [T <: Tuple, V] => (labels: Labels[T]) ?=> (x: Tensor[T, V]) => 1).sum + leavesCount should equal(3) + + it("nested structures (tuple of case classes)"): + case class Params(val w1: Tensor1[A, Float32], val b1: Tensor0[Int32]) + val params1 = Params(Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), Tensor0(0)) + val params2 = Params(Tensor1(Axis[A]).fromArray(Array(0.4f, 0.5f, 0.6f)), Tensor0(1)) + val tree = TensorTree[(Params, Params)] + val leaves = tree.mapLeaves((params1, params2), [T <: Tuple, V] => (labels: Labels[T]) ?=> (x: Tensor[T, V]) => "leaf").toList + leaves should equal(List("leaf", "leaf", "leaf", "leaf")) + leaves.size should equal(4) + + it("list of structures"): + case class Params(w: Tensor0[Float32]) + val paramsList = List( + Params(Tensor0(1.0f)), + Params(Tensor0(2.0f)), + Params(Tensor0(3.0f)) + ) + val tree = TensorTree[List[Params]] + val leavesCount = tree.mapLeaves(paramsList, [T <: Tuple, V] => (labels: Labels[T]) ?=> (x: Tensor[T, V]) => 1).sum + leavesCount should equal(3) + + describe("foreach"): + it("1-level case class"): + case class Data( + val numbers: Tensor1[A, Float32], + val counts: Tensor1[A, Int32] + ) + val params = Data( + Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), + Tensor1(Axis[A]).fromArray(Array(1, 2, 3)) + ) + val tree = TensorTree[Data] + var visitCount = 0 + tree.foreach( + params, + [T <: Tuple, V] => + (labels: Labels[T]) ?=> + (x: Tensor[T, V]) => + visitCount += 1 + ) + visitCount should equal(2) + + describe("mapWithName"): + it("1-level case class"): + case class Params( + val w1: Tensor1[A, Float32], + val b1: Tensor0[Int32] + ) + val params = Params( + Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), + Tensor0(0) + ) + val tree = TensorTree[Params] + var paths = Vector.empty[String] + val tree2 = tree.mapWithName( + params, + [T <: Tuple, V] => + (labels: Labels[T]) ?=> + (path: String, x: Tensor[T, V]) => + paths = paths :+ path + x + ) + paths.toList should equal(List("w1", "b1")) + tree2.w1 should approxEqual(params.w1) + tree2.b1 should equal(params.b1) + + it("nested structures (lists and tuples)"): + case class Model( + val layers: List[Tensor0[Float32]], + val extra: (Tensor0[Int32], Tensor0[Int32]) + ) + val params = Model( + List(Tensor0(1.0f), Tensor0(2.0f)), + (Tensor0(3), Tensor0(4)) + ) + val tree = TensorTree[Model] + var paths = Vector.empty[String] + tree.mapWithName( + params, + [T <: Tuple, V] => + (labels: Labels[T]) ?=> + (path: String, x: Tensor[T, V]) => + paths = paths :+ path + x + ) + paths.toList should equal(List("layers[0]", "layers[1]", "extra._1", "extra._2")) + + describe("foreachWithName"): + it("1-level case class"): + case class Data( + val numbers: Tensor1[A, Float32], + val counts: Tensor1[A, Int32], + val flags: Tensor1[A, Bool] + ) + val params = Data( + Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), + Tensor1(Axis[A]).fromArray(Array(1, 2, 3)), + Tensor1(Axis[A]).fromArray(Array(true, false, true)) + ) + val tree = TensorTree[Data] + var paths = Vector.empty[String] + tree.foreachWithName( + params, + [T <: Tuple, V] => + (labels: Labels[T]) ?=> + (path: String, x: Tensor[T, V]) => + paths = paths :+ path + ) + paths.toList should equal(List("numbers", "counts", "flags")) + + it("nested structures (case class inside case class)"): + case class Inner(w: Tensor0[Float32]) + case class Outer(inner1: Inner, inner2: Inner) + + val params = Outer(Inner(Tensor0(1.0f)), Inner(Tensor0(2.0f))) + val tree = TensorTree[Outer] + var paths = Vector.empty[String] + tree.foreachWithName( + params, + [T <: Tuple, V] => + (labels: Labels[T]) ?=> + (path: String, x: Tensor[T, V]) => + paths = paths :+ path + ) + paths.toList should equal(List("inner1.w", "inner2.w"))