Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 36 additions & 25 deletions axlearn/cloud/gcp/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import Optional, Sequence

import jax
import orbax.checkpoint as ocp
from absl import flags, logging
from ml_goodput_measurement import goodput
from ml_goodput_measurement import monitoring as goodput_monitoring
Expand Down Expand Up @@ -134,6 +135,19 @@ def record_event(self, event: measurement.EventType, *args, **kwargs):
)
# pylint: enable=try-except-raise

def create_checkpoint_logger(self) -> Optional[ocp.logging.CloudLogger]:
    """Builds an Orbax CloudLogger for checkpoint Goodput logging.

    Returns:
        A configured `ocp.logging.CloudLogger`, or None if construction fails
        for any reason (failure is logged as a warning, never raised).
    """
    try:
        logging.info("Creating a Goodput checkpoint logger.")
        logger_options = ocp.logging.CloudLoggerOptions(
            job_name=self._job_name,
            logger_name=self._logger_name,
        )
        return ocp.logging.CloudLogger(options=logger_options)
    except Exception as e:  # pylint: disable=broad-exception-caught
        logging.warning("Failed to create Goodput checkpoint logger: %s", e, exc_info=True)
        return None

@contextlib.contextmanager
def _maybe_monitor_goodput(self, *args, **kwargs):
"""Monitor cumulative goodput if enabled.
Expand Down Expand Up @@ -221,35 +235,32 @@ def record(self, event: measurement.Event, *args, **kwargs):
"""
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
if self._recorder is None:
cfg: GoodputRecorder.Config = self.config
if jax.process_index() == 0:
logging.info("Lazily instantiating goodput recorder.")
self._recorder = goodput.GoodputRecorder(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
job_name=self._job_name,
logger_name=self._logger_name,
logging_enabled=(jax.process_index() == 0),
)

if event == measurement.Event.START_JOB:
self._recorder.record_job_start_time(*args, **kwargs)
elif event == measurement.Event.END_JOB:
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT:
self._recorder.record_custom_badput_event_start_time(*args, **kwargs)
elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT:
self._recorder.record_custom_badput_event_end_time(*args, **kwargs)
start_method_name = f"record_{event.value}_start_time"
end_method_name = f"record_{event.value}_end_time"

record_event_start = getattr(self._recorder, start_method_name, None)
record_event_end = getattr(self._recorder, end_method_name, None)

if record_event_start:
try:
record_event_start(*args, **kwargs)
except RuntimeError as e:
logging.warning(
"Failed to record start of event %s. Error: %s", event.value, e, exc_info=True
)
# pylint: disable=try-except-raise
try:
yield # Run the user code in the context
except Exception:
raise
else:
logging.log_first_n(
logging.WARNING,
Expand Down
43 changes: 43 additions & 0 deletions axlearn/cloud/gcp/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,46 @@ def test_maybe_monitor_all(
else:
mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called()
mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called()

@mock.patch("jax.process_index", return_value=0)
def test_create_checkpoint_logger_success(self, _):
    """Checks that create_checkpoint_logger builds a CloudLogger with the expected options."""
    config = GoodputRecorder.default_config().set(
        name="test-job",
        upload_dir="/test",
        upload_interval=30,
    )
    recorder = GoodputRecorder(config)

    with mock.patch("orbax.checkpoint.logging.CloudLogger") as mock_logger_cls:
        created = recorder.create_checkpoint_logger()

        # The recorder must return exactly the logger instance it constructed.
        mock_logger_cls.assert_called_once()
        self.assertIs(created, mock_logger_cls.return_value)

        # The options passed through must reflect the recorder's job/logger names.
        _, call_kwargs = mock_logger_cls.call_args
        logger_options = call_kwargs["options"]
        self.assertEqual(logger_options.job_name, "test-job")
        self.assertEqual(logger_options.logger_name, "goodput_logger_test-job")

@mock.patch("jax.process_index", return_value=0)
def test_create_checkpoint_logger_failure(self, _):
    """Verifies create_checkpoint_logger returns None and warns when CloudLogger raises."""
    config = GoodputRecorder.default_config().set(
        name="fail-job",
        upload_dir="/test",
        upload_interval=30,
    )
    recorder = GoodputRecorder(config)

    patched_logger = mock.patch(
        "orbax.checkpoint.logging.CloudLogger", side_effect=RuntimeError("TestError")
    )
    with patched_logger as mock_logger_cls, mock.patch.object(logging, "warning") as mock_warning:
        result = recorder.create_checkpoint_logger()

        # Failure is swallowed: None is returned and a single warning is emitted.
        self.assertIsNone(result)
        mock_logger_cls.assert_called_once()
        mock_warning.assert_called_once()
        warning_msg = mock_warning.call_args[0][0]
        self.assertIn("Failed to create Goodput checkpoint logger", warning_msg)
6 changes: 5 additions & 1 deletion axlearn/common/checkpointer_orbax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from absl import logging

from axlearn.common import utils
from axlearn.common import measurement, utils
from axlearn.common.checkpointer import (
STEP_NUM_DIGITS,
STEP_PREFIX,
Expand Down Expand Up @@ -232,6 +232,9 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
step_prefix=STEP_PREFIX,
step_format_fixed_length=STEP_NUM_DIGITS,
)
self._checkpoint_logger = None
if measurement.global_recorder:
self._checkpoint_logger = measurement.global_recorder.create_checkpoint_logger()
self._manager = ocp.CheckpointManager(
directory=cfg.dir,
options=ocp.CheckpointManagerOptions(
Expand All @@ -255,6 +258,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
restore_concurrent_gb=cfg.max_concurrent_restore_gb,
),
},
logger=self._checkpoint_logger,
)

def _get_spec(self, *, step: int, state: Nested[Any]) -> Nested[Any]:
Expand Down
6 changes: 5 additions & 1 deletion axlearn/common/checkpointer_orbax_emergency.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from jax.experimental.array_serialization import serialization

from axlearn.common import file_system as fs
from axlearn.common import utils, utils_spmd
from axlearn.common import measurement, utils, utils_spmd
from axlearn.common.checkpointer import (
STEP_NUM_DIGITS,
STEP_PREFIX,
Expand Down Expand Up @@ -667,6 +667,9 @@ def _composite_save_policy(*, step: int, evaler_summaries: dict[str, Any]):
# See comments of _eval_summaries in `OrbaxCheckpointer`.
self._eval_summaries = None
self._reached_preemption = False
self._checkpoint_logger = None
if measurement.global_recorder:
self._checkpoint_logger = measurement.global_recorder.create_checkpoint_logger()

# pylint: disable-next=redefined-builtin
def ckpt_dir(self, step: int, dir: Optional[str] = None) -> str:
Expand Down Expand Up @@ -731,6 +734,7 @@ def _orbax_save_fn(
cleanup_tmp_directories=True,
enable_async_checkpointing=True,
),
logger=self._checkpoint_logger,
)
return self._tensor_manager

Expand Down
28 changes: 27 additions & 1 deletion axlearn/common/checkpointer_orbax_emergency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import tempfile
from contextlib import ExitStack, closing
from typing import Optional
from unittest import mock

import jax
import numpy as np
Expand All @@ -18,7 +19,7 @@
from absl.testing import parameterized
from jax import numpy as jnp

from axlearn.common import utils_spmd
from axlearn.common import measurement, utils_spmd
from axlearn.common.checkpointer_orbax_emergency import (
OrbaxEmergencyCheckpointer,
_dump_process_info,
Expand Down Expand Up @@ -299,3 +300,28 @@ def start_processes(reverse_process_id: bool = False):
finally:
for p in processes:
p.kill()

@mock.patch("orbax.checkpoint._src.multihost.multihost.initialize_runtime_to_distributed_ids")
@mock.patch("orbax.checkpoint._src.multihost.multihost.initialize_distributed_to_device_ids")
def test_emergency_checkpointer_initializes_logger_from_global_recorder(
    self, mock_init_device_ids, mock_init_runtime
):  # pylint: disable=unused-argument
    """Tests OrbaxEmergencyCheckpointer initializes _checkpoint_logger.

    Note: stacked `mock.patch` decorators inject mocks bottom-up, so the first
    injected argument corresponds to the bottom-most decorator. The parameter
    names were previously swapped relative to the decorators they receive.
    """
    with tempfile.TemporaryDirectory() as temp_dir, mock.patch.object(
        measurement, "global_recorder", mock.MagicMock()
    ) as mock_recorder:
        mock_logger = mock.MagicMock()
        mock_recorder.create_checkpoint_logger.return_value = mock_logger

        cfg = OrbaxEmergencyCheckpointer.default_config().set(
            name="test_logger",
            trainer_dir=temp_dir,
            dir=temp_dir,
            local_dir=temp_dir,
            replica_axis_index=0,
        )

        ckpt: OrbaxEmergencyCheckpointer = cfg.instantiate(parent=None)

        # The checkpointer must ask the global recorder for a logger exactly once
        # and store the result.
        mock_recorder.create_checkpoint_logger.assert_called_once()
        self.assertEqual(ckpt._checkpoint_logger, mock_logger)
21 changes: 20 additions & 1 deletion axlearn/common/checkpointer_orbax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import os
import tempfile
from typing import Sequence
from unittest import mock

import jax
import orbax.checkpoint as ocp
from jax import numpy as jnp
from jax.experimental import mesh_utils

from axlearn.common import test_utils
from axlearn.common import measurement, test_utils
from axlearn.common.checkpointer import read_index_file
from axlearn.common.checkpointer_orbax import OrbaxCheckpointer

Expand Down Expand Up @@ -52,3 +53,21 @@ def test_index(self):
),
)
self.assertEqual(ref_index, test_index["index"])

def test_initializes_checkpoint_logger_from_global_recorder(self):
    """Tests that OrbaxCheckpointer initializes _checkpoint_logger if global_recorder is set."""
    with tempfile.TemporaryDirectory() as temp_dir, mock.patch.object(
        measurement, "global_recorder", mock.MagicMock()
    ) as mock_recorder:
        expected_logger = mock.MagicMock(spec=ocp.logging.CloudLogger)
        mock_recorder.create_checkpoint_logger.return_value = expected_logger

        cfg = OrbaxCheckpointer.default_config().set(name="test", dir=temp_dir)
        ckpt = cfg.instantiate(parent=None)

        # The checkpointer must fetch the logger from the global recorder and keep it.
        mock_recorder.create_checkpoint_logger.assert_called_once()
        self.assertEqual(ckpt._checkpoint_logger, expected_logger)
14 changes: 11 additions & 3 deletions axlearn/common/launch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""Utilities to launch a trainer."""

import contextlib
import json
import os
from typing import Any, Optional
Expand Down Expand Up @@ -128,8 +129,7 @@ def get_trainer_config(
return trainer_config


def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
measurement.record_event(measurement.Event.START_JOB)
def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any:
trainer_config_debug_string = trainer_config.debug_string()
logging.info("Trainer config:\n%s", trainer_config_debug_string)
if jax.process_index() == 0:
Expand All @@ -150,5 +150,13 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
output = trainer.run(prng_key)
measurement.record_event(measurement.Event.END_JOB)
return output


def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
    """Runs the trainer, wrapping the run in a JOB measurement event when available.

    If a global measurement recorder is configured, the whole trainer run is
    bracketed by its JOB event context; otherwise a no-op context is used.
    """
    recorder = measurement.global_recorder
    if recorder:
        job_context = recorder.record_event(measurement.EventType.JOB)
    else:
        job_context = contextlib.nullcontext()
    with job_context:
        return _run_trainer_impl(trainer_config)
1 change: 0 additions & 1 deletion axlearn/common/launch_trainer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def main(_):
launch.setup()
trainer_config = launch_trainer.get_trainer_config()
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
measurement.start_monitoring()
launch_trainer.run_trainer(trainer_config)


Expand Down
4 changes: 4 additions & 0 deletions axlearn/common/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def maybe_monitor_all(self):
"""
yield

def create_checkpoint_logger(self) -> Optional[object]:
    """Optionally returns a fully functional and independent checkpoint logger.

    Returns:
        None in this base implementation. Subclasses may override to return a
        logger object used for recording checkpoint events.
    """
    return None


_recorders: dict[str, type] = {}
_T = TypeVar("_T")
Expand Down
3 changes: 3 additions & 0 deletions axlearn/common/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ def test_initialize(self, recorder_type, expected):
# Ensure that maybe_monitor_all does not fail (just enter and exit context).
with measurement.global_recorder.maybe_monitor_all():
pass

# Ensure that create_checkpoint_logger does not crash.
measurement.global_recorder.create_checkpoint_logger()
Loading
Loading