Skip to content

Commit b5bf7fc

Browse files
author
Mark Lee
authored
Allows batch_axis_names to be optional. (#1285)
1 parent 89991e8 commit b5bf7fc

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

axlearn/common/trainer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class Config(Module.Config):
122122
mesh_axis_names: Required[Sequence[str]] = REQUIRED
123123
# Subset of mesh axis names over which the leaves of the input batch are sharded.
124124
# TODO(markblee): Deprecate this field in favor of `input.input_partitioner`.
125-
batch_axis_names: Union[str, Sequence[str]] = "data"
125+
batch_axis_names: Optional[Union[str, Sequence[str]]] = "data"
126126

127127
# An optional list of (regex, MeshShape) pairs to override the default mesh configuration.
128128
#
@@ -297,11 +297,12 @@ def __init__(
297297
# Create all children within the mesh context so that utils.input_partition_spec() works
298298
# properly.
299299
with self.mesh():
300+
if cfg.batch_axis_names is not None:
301+
cfg.input = maybe_set_config(
302+
cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names)
303+
)
300304
self.input: Input = self._add_child(
301-
"input",
302-
maybe_set_config(
303-
cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names), is_training=True
304-
),
305+
"input", maybe_set_config(cfg.input, is_training=True)
305306
)
306307
# Start from the beginning of the input dataset by default.
307308
self._input_iter = iter(self.input.dataset())
@@ -341,9 +342,10 @@ def __init__(
341342
evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join(
342343
cfg.dir, "summaries", evaler_name
343344
)
344-
maybe_set_config(
345-
evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names)
346-
)
345+
if cfg.batch_axis_names is not None:
346+
maybe_set_config(
347+
evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names)
348+
)
347349
self._evalers[evaler_name] = self._add_child(
348350
evaler_name,
349351
evaler_cfg,
@@ -591,7 +593,7 @@ def run(
591593
input_batch = next(input_iterator)
592594
self._maybe_record_event(measurement.Event.END_DATA_LOADING)
593595
logging.log_first_n(
594-
logging.INFO, "input_batch=%s", 3, utils.shapes(input_batch)
596+
logging.INFO, "host_input_batch=%s", 3, utils.shapes(input_batch)
595597
)
596598

597599
# Stop or start tracing if necessary.
@@ -601,7 +603,7 @@ def run(
601603
self.vlog(3, "Start step %s", self.step)
602604
self._maybe_record_event(measurement.Event.START_STEP, self._step)
603605
output = self._run_step(
604-
utils.host_to_global_device_array(
606+
utils.host_to_global_array(
605607
input_batch,
606608
partition=self._train_step_input_partition_specs(),
607609
),
@@ -1089,6 +1091,7 @@ def _run_step(
10891091
A dict containing 'loss' and 'aux' outputs. If force_run_evals is a set,
10901092
force run the evalers in the set and return 'evaler_summaries' output.
10911093
"""
1094+
logging.log_first_n(logging.INFO, "global_input_batch=%s", 3, utils.shapes(input_batch))
10921095
with jax.profiler.StepTraceAnnotation("train", step_num=self.step):
10931096
run_with_xsc = self._xsc_check_policy and self._xsc_check_policy(self.step)
10941097
compiled_train_step_fn = self._get_compiled_train_step_fn(

axlearn/common/trainer_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,26 @@ def test_input_dispatch_every_other_process(self, multiple: float):
12161216
"""Tests input dispatch with some padding feeds. Requires process_count > 1."""
12171217
self._test_input_dispatch(multiple, backend="tpu")
12181218

1219+
def test_optional_batch_axes(self):
1220+
"""Tests that we can omit batch_axis_names."""
1221+
mesh_shape = (jax.device_count(), 1)
1222+
global_logical_batch_size = mesh_shape[0]
1223+
partition_spec = PartitionSpec("model") # Something other than "data".
1224+
1225+
# Explicitly set a partition spec on input.
1226+
input_cfg = self._dummy_input_checking_input(global_logical_batch_size)
1227+
input_cfg.partition_spec = partition_spec
1228+
1229+
cfg = self._trainer_config(input_cfg)
1230+
cfg.batch_axis_names = None
1231+
cfg.max_step = 3
1232+
cfg.mesh_shape = mesh_shape
1233+
cfg.model = self._dummy_input_checking_model(
1234+
global_logical_batch_size, partition_spec=partition_spec
1235+
)
1236+
trainer: SpmdTrainer = cfg.instantiate(parent=None)
1237+
self.assertEqual(partition_spec, trainer.input.partition_spec)
1238+
12191239

12201240
class SelectMeshConfigTest(test_utils.TestCase):
12211241
def test_select_mesh_config(self):

0 commit comments

Comments
 (0)