Conversation

@carlosgmartin (Contributor)

Deprecates the redundant function optax.global_norm in favor of optax.tree.norm.

Split into a new PR from #1365.
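
For anyone migrating, here is a minimal usage sketch of my own (assuming the optax version from this PR, where optax.tree.norm computes the same global L2 norm over a pytree):

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}

# Before (now deprecated):
# norm = optax.global_norm(params)

# After: global L2 norm over all leaves of the pytree.
norm = optax.tree.norm(params)
print(norm)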

@vroulet (Collaborator) commented Jun 27, 2025

Did you check that the HLOs match?

@carlosgmartin (Contributor, author) commented Jun 28, 2025

$ prog="import jax, optax; print(jax.jit(optax.global_norm).lower(0.).as_text())"
$ git checkout main; py -c $prog > ~/desktop/main.txt
$ git checkout deprecate_optax_global_norm; py -c $prog > ~/desktop/deprecate_optax_global_norm.txt
$ git diff ~/desktop/main.txt ~/desktop/deprecate_optax_global_norm.txt
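
The same comparison can also be done in-process on a single checkout of main, before the deprecation lands (a sketch of my own, not from the PR; note that the module names and source metadata in the two texts will differ even where the ops match):

import jax
import optax

lower = lambda f: jax.jit(f).lower(0.).as_text()
print(lower(optax.global_norm))  # current implementation on main
print(lower(optax.tree.norm))    # proposed replacement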

Diff:

 module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
     %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
-    %1 = stablehlo.convert %0 : tensor<f32>
     %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-    %2 = stablehlo.reduce(%1 init: %cst) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %1 = stablehlo.multiply %cst, %cst : tensor<f32>
+    %2 = stablehlo.add %0, %1 : tensor<f32>
+    %3 = stablehlo.convert %2 : tensor<f32>
     %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-    %3 = stablehlo.add %cst_0, %2 : tensor<f32>
-    %4 = stablehlo.sqrt %3 : tensor<f32>
-    return %4 : tensor<f32>
+    %4 = stablehlo.reduce(%3 init: %cst_0) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+    %5 = stablehlo.add %cst_1, %4 : tensor<f32>
+    %6 = stablehlo.sqrt %5 : tensor<f32>
+    return %6 : tensor<f32>
   }
 }

Separately, here are the full old and new lowerings:

Old
module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %1 = stablehlo.convert %0 : tensor<f32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = stablehlo.reduce(%1 init: %cst) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = stablehlo.add %cst_0, %2 : tensor<f32>
    %4 = stablehlo.sqrt %3 : tensor<f32>
    return %4 : tensor<f32>
  }
}
New
module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.multiply %cst, %cst : tensor<f32>
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.convert %2 : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %4 = stablehlo.reduce(%3 init: %cst_0) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %5 = stablehlo.add %cst_1, %4 : tensor<f32>
    %6 = stablehlo.sqrt %5 : tensor<f32>
    return %6 : tensor<f32>
  }
}

For reference, comparing the current implementations of global_norm and tree_norm, one difference is that global_norm squares each leaf with numerics.abs_sq:

def abs_sq(x):
  return (x.conj() * x).real

whereas tree_norm uses _tree_math._square:

def _square(leaf):
  return jnp.square(leaf.real) + jnp.square(leaf.imag)

We should probably pick the most efficient of these two for both functions.

Based on HLO size, it looks like abs_sq is slightly more efficient than _square. The latter adds an extra add instruction and an extra multiply instruction.
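
To double-check both the numerics and the relative HLO size locally, here is a small sketch of my own (the two helpers are copied from above; the comparison harness assumes only jax is installed):

import jax
import jax.numpy as jnp

def abs_sq(x):  # copied from optax._src.numerics
  return (x.conj() * x).real

def _square(leaf):  # copied from optax.tree_utils._tree_math
  return jnp.square(leaf.real) + jnp.square(leaf.imag)

x = jnp.array([1.0 + 2.0j, 3.0 - 4.0j])

# Both should agree numerically on complex (and real) leaves.
assert jnp.allclose(abs_sq(x), _square(x))

# Rough efficiency proxy: size of the lowered StableHLO.
for fn in (abs_sq, _square):
  n_lines = len(jax.jit(fn).lower(x).as_text().splitlines())
  print(fn.__name__, n_lines, 'lines of StableHLO')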

@emilyfertig (Collaborator)

On a pytree input, the compiled HLO is as follows:

import jax
import jax.numpy as jnp
import optax
from optax._src import linear_algebra

nested_updates = {
    'a': jnp.array([2.0, 4.0], dtype=jnp.float32),
    'b': jnp.array([3.0, 5.0], dtype=jnp.float32),
}
make_hlo = lambda f, *args: jax.jit(f).lower(*args).compile().as_text()
print(make_hlo(linear_algebra.global_norm, nested_updates))
print(make_hlo(optax.tree.norm, nested_updates))

For global_norm:

HloModule jit_global_norm, is_scheduled=true, entry_computation_layout={(f32[2]{0}, f32[2]{0})->f32[]}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

%region_1.2 (reduce_sum.10: f32[], reduce_sum.11: f32[]) -> f32[] {
  %reduce_sum.10 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.11 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.12 = f32[] add(%reduce_sum.10, %reduce_sum.11), metadata={op_name="jit(global_norm)/reduce_sum" source_file="third_party/py/optax/_src/linear_algebra.py" source_line=38 source_end_line=38 source_column=10 source_end_column=37}
}

%region_0.1 (reduce_sum.3: f32[], reduce_sum.4: f32[]) -> f32[] {
  %reduce_sum.3 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.4 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.5 = f32[] add(%reduce_sum.3, %reduce_sum.4), metadata={op_name="jit(global_norm)/reduce_sum" source_file="third_party/py/optax/_src/linear_algebra.py" source_line=38 source_end_line=38 source_column=10 source_end_column=37}
}

%fused_computation (param_0.3: f32[2], param_1.5: f32[2]) -> f32[] {
  %param_1.5 = f32[2]{0} parameter(1)
  %mul.1 = f32[2]{0} multiply(%param_1.5, %param_1.5), metadata={op_name="jit(global_norm)/mul" source_file="third_party/py/optax/_src/numerics.py" source_line=45 source_end_line=45 source_column=10 source_end_column=22}
  %constant.0 = f32[] constant(0)
  %reduce_sum.1 = f32[] reduce(%mul.1, %constant.0), dimensions={0}, to_apply=%region_0.1, metadata={op_name="jit(global_norm)/reduce_sum" source_file="third_party/py/optax/_src/linear_algebra.py" source_line=38 source_end_line=38 source_column=10 source_end_column=37}
  %param_0.3 = f32[2]{0} parameter(0)
  %mul.0 = f32[2]{0} multiply(%param_0.3, %param_0.3), metadata={op_name="jit(global_norm)/mul" source_file="third_party/py/optax/_src/numerics.py" source_line=45 source_end_line=45 source_column=10 source_end_column=22}
  %reduce_sum.0 = f32[] reduce(%mul.0, %constant.0), dimensions={0}, to_apply=%region_1.2, metadata={op_name="jit(global_norm)/reduce_sum" source_file="third_party/py/optax/_src/linear_algebra.py" source_line=38 source_end_line=38 source_column=10 source_end_column=37}
  %add.0 = f32[] add(%reduce_sum.1, %reduce_sum.0), metadata={op_name="jit(global_norm)/add" source_file="third_party/py/optax/_src/linear_algebra.py" source_line=38 source_end_line=38 source_column=6 source_end_column=72}
  ROOT %sqrt.0 = f32[] sqrt(%add.0), metadata={op_name="jit(global_norm)/sqrt" source_file="third_party/py/optax/_src/linear_algebra.py" source_line=37 source_end_line=39 source_column=9 source_end_column=3}
}

ENTRY %main.3 (updates__a__.1: f32[2], updates__b__.1: f32[2]) -> f32[] {
  %updates__a__.1 = f32[2]{0} parameter(0), metadata={op_name="updates[\'a\']"}
  %updates__b__.1 = f32[2]{0} parameter(1), metadata={op_name="updates[\'b\']"}
  ROOT %add_sqrt_fusion = f32[] fusion(%updates__b__.1, %updates__a__.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(global_norm)/sqrt" source_file="third_party/py/optax/_src/linear_algebra.py" source_line=37 source_end_line=39 source_column=9 source_end_column=3}
}

For tree_norm:

HloModule jit_tree_norm, is_scheduled=true, entry_computation_layout={(f32[2]{0}, f32[2]{0})->f32[]}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

%region_1.2 (reduce_sum.10: f32[], reduce_sum.11: f32[]) -> f32[] {
  %reduce_sum.10 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.11 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.12 = f32[] add(%reduce_sum.10, %reduce_sum.11), metadata={op_name="jit(tree_norm)/reduce_sum" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=179 source_end_line=179 source_column=9 source_end_column=36}
}

%region_0.1 (reduce_sum.3: f32[], reduce_sum.4: f32[]) -> f32[] {
  %reduce_sum.3 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.4 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.5 = f32[] add(%reduce_sum.3, %reduce_sum.4), metadata={op_name="jit(tree_norm)/reduce_sum" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=179 source_end_line=179 source_column=9 source_end_column=36}
}

%fused_computation (param_0.3: f32[2], param_1.5: f32[2]) -> f32[] {
  %param_1.5 = f32[2]{0} parameter(1)
  %square.1 = f32[2]{0} multiply(%param_1.5, %param_1.5), metadata={op_name="jit(tree_norm)/square" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=248 source_end_line=248 source_column=9 source_end_column=30}
  %constant.0 = f32[] constant(0)
  %reduce_sum.1 = f32[] reduce(%square.1, %constant.0), dimensions={0}, to_apply=%region_0.1, metadata={op_name="jit(tree_norm)/reduce_sum" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=179 source_end_line=179 source_column=9 source_end_column=36}
  %param_0.3 = f32[2]{0} parameter(0)
  %square.0 = f32[2]{0} multiply(%param_0.3, %param_0.3), metadata={op_name="jit(tree_norm)/square" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=248 source_end_line=248 source_column=9 source_end_column=30}
  %reduce_sum.0 = f32[] reduce(%square.0, %constant.0), dimensions={0}, to_apply=%region_1.2, metadata={op_name="jit(tree_norm)/reduce_sum" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=179 source_end_line=179 source_column=9 source_end_column=36}
  %add.0 = f32[] add(%reduce_sum.1, %reduce_sum.0), metadata={op_name="jit(tree_norm)/add" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=180 source_end_line=180 source_column=9 source_end_column=59}
  ROOT %sqrt.0 = f32[] sqrt(%add.0), metadata={op_name="jit(tree_norm)/sqrt" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=267 source_end_line=267 source_column=44 source_end_column=60}
}

ENTRY %main.3 (tree__a__.1: f32[2], tree__b__.1: f32[2]) -> f32[] {
  %tree__a__.1 = f32[2]{0} parameter(0), metadata={op_name="tree[\'a\']"}
  %tree__b__.1 = f32[2]{0} parameter(1), metadata={op_name="tree[\'b\']"}
  ROOT %add_sqrt_fusion = f32[] fusion(%tree__b__.1, %tree__a__.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(tree_norm)/sqrt" source_file="third_party/py/optax/tree_utils/_tree_math.py" source_line=267 source_end_line=267 source_column=44 source_end_column=60}
}

These are identical except that tree.norm's elementwise product is labeled square where global_norm's is labeled mul; in both fused computations the compiled instruction is a plain multiply. I think the difference should be negligible (if anything, square might be marginally more efficient), so we can merge this.
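
As a quick sanity check that the square vs. mul distinction is only in the op metadata, one could lower both forms on a real float32 input (a sketch of my own, not from the review):

import jax
import jax.numpy as jnp

x = jnp.array([2.0, 4.0], dtype=jnp.float32)
print(jax.jit(jnp.square).lower(x).as_text())       # squaring as used by tree.norm's _square
print(jax.jit(lambda v: v * v).lower(x).as_text())  # explicit elementwise multiply, analogous to abs_sq on real input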

emilyfertig self-assigned this on Nov 10, 2025
copybara-service bot merged commit 5db134c into google-deepmind:main on Nov 10, 2025
15 checks passed
carlosgmartin deleted the deprecate_optax_global_norm branch on November 10, 2025 at 19:58