Deprecate optax.global_norm in favor of optax.tree.norm. #1368
Conversation
Did you check that the HLOs match?
Diff:

module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
%0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
- %1 = stablehlo.convert %0 : tensor<f32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
- %2 = stablehlo.reduce(%1 init: %cst) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = stablehlo.multiply %cst, %cst : tensor<f32>
+ %2 = stablehlo.add %0, %1 : tensor<f32>
+ %3 = stablehlo.convert %2 : tensor<f32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
- %3 = stablehlo.add %cst_0, %2 : tensor<f32>
- %4 = stablehlo.sqrt %3 : tensor<f32>
- return %4 : tensor<f32>
+ %4 = stablehlo.reduce(%3 init: %cst_0) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+ %5 = stablehlo.add %cst_1, %4 : tensor<f32>
+ %6 = stablehlo.sqrt %5 : tensor<f32>
+ return %6 : tensor<f32>
}
}

Separately, for reference, here are the current implementations of the two squaring helpers. A difference is that the current implementation of optax.global_norm uses

def abs_sq(x):
    return (x.conj() * x).real

whereas optax.tree.norm uses

def _square(leaf):
    return jnp.square(leaf.real) + jnp.square(leaf.imag)

We should probably pick the most efficient of these two for both functions. Based on HLO size, it looks like …
|
On a pytree input, the HLO is as follows. For optax.global_norm: …. For optax.tree.norm: …. These are identical except ….
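One way to obtain such dumps for comparison is to lower a jitted function on a pytree argument and print its StableHLO text; a sketch (the `tree_l2_norm` helper here is an illustrative stand-in, not the library implementation):

```python
import jax
import jax.numpy as jnp

def tree_l2_norm(tree):
    # Stand-in for the implementations under comparison:
    # sqrt of the sum of squared entries over all leaves.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum((x.conj() * x).real) for x in leaves))

tree = {"a": jnp.zeros((3,)), "b": jnp.zeros((2, 2))}
hlo_text = jax.jit(tree_l2_norm).lower(tree).as_text()
print(hlo_text)  # StableHLO module text; diff against the other implementation's dump
```

Dumping both implementations this way and diffing the text is how a mismatch like the one shown above can be surfaced.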
Deprecates the redundant function optax.global_norm in favor of optax.tree.norm. Split into a new PR from #1365.