Skip to content

Add mapLeaves operation to Tree#108

Merged
benikm91 merged 2 commits into
dimwit-dev:mainfrom
benikm91:add-map-leaves
May 26, 2026
Merged

Add mapLeaves operation to Tree#108
benikm91 merged 2 commits into
dimwit-dev:mainfrom
benikm91:add-map-leaves

Conversation

@benikm91
Copy link
Copy Markdown
Collaborator

Add mapLeaves operation to Tree (similar to jax.tree_util.tree_leaves).

General Motivation: Add reduce operation over a TensorTree structure.
Specific Motivation: Calculate Gradient Norm for Gradient Clipping (GPT-2 in deepwit).

This PR adds mapLeaves to TensorTree and FloatTree, which is similar to jax.tree_util.tree_leaves from JAX.

Difference to JAX: jax.tree_util.tree_leaves just returns a List of Any (without mapping), which makes less sense in a strongly typed language. Therefore, we require a mapping function f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => A) to unify all leaves (tensors) to a shared type A. Reduction across A can then be done on the Iterator[A] object.

Note:

  • We can't return a Tree of A's as the Tree Structure can be a case class of (different-shaped) tensors. So, e.g., Params[A] does not make sense in such cases. Therefore, we (must) return a flat Iterator over leaves.
  • Returning an Iterator without a map function f is not possible, as Tensor[T, V] has to type parameters T and V, which differ across leaves. It would require an Iterator with Polymorphic functions in map, reduce, etc.

@benikm91 benikm91 requested a review from marcelluethi May 21, 2026 15:14
@benikm91
Copy link
Copy Markdown
Collaborator Author

Last commit implements #87 in new TensorTree and FloatTree.

Copy link
Copy Markdown
Contributor

@marcelluethi marcelluethi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change. I suppose this supersedes #87, which then could be closed.

@benikm91 benikm91 merged commit a597d5b into dimwit-dev:main May 26, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants