@@ -122,7 +122,7 @@ class Config(Module.Config):
122122 mesh_axis_names : Required [Sequence [str ]] = REQUIRED
123123 # Subset of mesh axis names over which the leaves of the input batch are sharded.
124124 # TODO(markblee): Deprecate this field in favor of `input.input_partitioner`.
125- batch_axis_names : Union [str , Sequence [str ]] = "data"
125+ batch_axis_names : Optional [ Union [str , Sequence [str ] ]] = "data"
126126
127127 # An optional list of (regex, MeshShape) pairs to override the default mesh configuration.
128128 #
@@ -297,11 +297,12 @@ def __init__(
297297 # Create all children within the mesh context so that utils.input_partition_spec() works
298298 # properly.
299299 with self .mesh ():
300+ if cfg .batch_axis_names is not None :
301+ cfg .input = maybe_set_config (
302+ cfg .input , partition_spec = PartitionSpec (cfg .batch_axis_names )
303+ )
300304 self .input : Input = self ._add_child (
301- "input" ,
302- maybe_set_config (
303- cfg .input , partition_spec = PartitionSpec (cfg .batch_axis_names ), is_training = True
304- ),
305+ "input" , maybe_set_config (cfg .input , is_training = True )
305306 )
306307 # Start from the beginning of the input dataset by default.
307308 self ._input_iter = iter (self .input .dataset ())
@@ -341,9 +342,10 @@ def __init__(
341342 evaler_cfg .summary_writer .dir = evaler_cfg .summary_writer .dir or os .path .join (
342343 cfg .dir , "summaries" , evaler_name
343344 )
344- maybe_set_config (
345- evaler_cfg .input , partition_spec = PartitionSpec (cfg .batch_axis_names )
346- )
345+ if cfg .batch_axis_names is not None :
346+ maybe_set_config (
347+ evaler_cfg .input , partition_spec = PartitionSpec (cfg .batch_axis_names )
348+ )
347349 self ._evalers [evaler_name ] = self ._add_child (
348350 evaler_name ,
349351 evaler_cfg ,
@@ -591,7 +593,7 @@ def run(
591593 input_batch = next (input_iterator )
592594 self ._maybe_record_event (measurement .Event .END_DATA_LOADING )
593595 logging .log_first_n (
594- logging .INFO , "input_batch =%s" , 3 , utils .shapes (input_batch )
596+ logging .INFO , "host_input_batch =%s" , 3 , utils .shapes (input_batch )
595597 )
596598
597599 # Stop or start tracing if necessary.
@@ -601,7 +603,7 @@ def run(
601603 self .vlog (3 , "Start step %s" , self .step )
602604 self ._maybe_record_event (measurement .Event .START_STEP , self ._step )
603605 output = self ._run_step (
604- utils .host_to_global_device_array (
606+ utils .host_to_global_array (
605607 input_batch ,
606608 partition = self ._train_step_input_partition_specs (),
607609 ),
@@ -1089,6 +1091,7 @@ def _run_step(
10891091 A dict containing 'loss' and 'aux' outputs. If force_run_evals is a set,
10901092 force run the evalers in the set and return 'evaler_summaries' output.
10911093 """
1094+ logging .log_first_n (logging .INFO , "global_input_batch=%s" , 3 , utils .shapes (input_batch ))
10921095 with jax .profiler .StepTraceAnnotation ("train" , step_num = self .step ):
10931096 run_with_xsc = self ._xsc_check_policy and self ._xsc_check_policy (self .step )
10941097 compiled_train_step_fn = self ._get_compiled_train_step_fn (
0 commit comments