Skip to content

Commit 5ecc953

Browse files
dipannita08SujeethJinesh
authored andcommitted
Integrate AXLearn with Goodput v12
1 parent ce3a4ca commit 5ecc953

File tree

9 files changed

+586
-461
lines changed

9 files changed

+586
-461
lines changed

axlearn/cloud/gcp/measurement.py

Lines changed: 124 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
"""Measurement utils for GCP.
44
5+
For detailed documentation and advanced usage, please refer to:
6+
axlearn/docs/05-Goodput-Monitoring.md
7+
58
Example:
69
710
# Enable Goodput when launching an AXLearn training job
@@ -13,10 +16,14 @@
1316
--recorder_spec=name=my-run-with-goodput \
1417
--recorder_spec=upload_dir=my-output-directory/summaries \
1518
--recorder_spec=upload_interval=30 \
16-
--recorder_spec=step_deviation_interval_seconds=30
19+
--recorder_spec=rolling_window_size=86400,259200,432000
1720
1821
"""
1922

23+
import contextlib
24+
import os
25+
from typing import Optional, Sequence
26+
2027
import jax
2128
from absl import flags, logging
2229
from ml_goodput_measurement import goodput
@@ -38,13 +45,19 @@ class Config(measurement.Recorder.Config):
3845
Attributes:
3946
upload_dir: Directory to store metrics for the monitor.
4047
upload_interval: Time interval (seconds) for monitoring uploads.
41-
step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
42-
uploads. -1 to disable step deviation uploads.
48+
See "How to Monitor Cumulative Goodput Metrics" in
49+
docs/05-Goodput-Monitoring.md for more details.
50+
rolling_window_size: A sequence of integers defining the rolling window sizes in
51+
seconds.
52+
See "How to Monitor Rolling Window Goodput Metrics" in
53+
docs/05-Goodput-Monitoring.md for more details.
54+
jax_backend: Jax backend type to infer Pathways environment.
4355
"""
4456

4557
upload_dir: Required[str] = REQUIRED
4658
upload_interval: Required[int] = REQUIRED
47-
step_deviation_interval_seconds: int = 30 # Default to 30 seconds
59+
rolling_window_size: Sequence[int] = []
60+
jax_backend: Optional[str] = None
4861

4962
@classmethod
5063
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
@@ -53,68 +66,78 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
5366
`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
5467
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
5568
additionally take in following Tensorboard configs in the recorder_spec:
56-
- upload_dir: The directory to write Tensorboard data to.
57-
- upload_interval: The time interval in seconds at which to query and upload data
58-
to Tensorboard.
59-
- step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
60-
uploads. Set to less than or equal to 0 to disable step deviation uploads.
69+
- upload_dir: The directory to write Tensorboard data to.
70+
- upload_interval: The time interval in seconds at which to query and upload data
71+
to Tensorboard.
72+
- rolling_window_size: Comma-separated list of integers representing rolling window
73+
sizes in seconds.
74+
- jax_backend: The type of jax backend.
6175
"""
6276
cfg: measurement.Recorder.Config = cls.default_config()
63-
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
64-
return cfg.instantiate()
77+
parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
78+
if "upload_interval" in parsed_flags:
79+
parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"])
80+
if "rolling_window_size" in parsed_flags and isinstance(
81+
parsed_flags["rolling_window_size"], str
82+
):
83+
parsed_flags["rolling_window_size"] = [
84+
int(x) for x in parsed_flags["rolling_window_size"].split(",")
85+
]
86+
return maybe_set_config(cfg, **parsed_flags).instantiate()
6587

6688
def __init__(self, cfg):
6789
super().__init__(cfg)
68-
cfg: GoodputRecorder.Config = self.config
69-
self._recorder = None
70-
self._monitor = None
71-
72-
def record(self, event: measurement.Event, *args, **kwargs):
73-
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
90+
self._recorder: Optional[goodput.GoodputRecorder] = None
91+
self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None
92+
self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None
93+
self._job_name = cfg.name
94+
self._logger_name = f"goodput_logger_{cfg.name}"
95+
96+
@contextlib.contextmanager
97+
def record_event(self, event: measurement.Event, *args, **kwargs):
98+
"""Records a goodput event using a context manager."""
99+
# Lazily instantiate the recorder if it hasn't been already.
74100
if self._recorder is None:
75-
cfg: GoodputRecorder.Config = self.config
101+
if jax.process_index() == 0:
102+
logging.info("Lazily instantiating goodput recorder.")
76103
self._recorder = goodput.GoodputRecorder(
77-
job_name=cfg.name,
78-
logger_name=f"goodput_logger_{cfg.name}",
104+
job_name=self._job_name,
105+
logger_name=self._logger_name,
79106
logging_enabled=(jax.process_index() == 0),
80107
)
81108

82-
if event == measurement.Event.START_JOB:
83-
self._recorder.record_job_start_time(*args, **kwargs)
84-
elif event == measurement.Event.END_JOB:
85-
self._recorder.record_job_end_time(*args, **kwargs)
86-
elif event == measurement.Event.START_STEP:
87-
self._recorder.record_step_start_time(*args, **kwargs)
88-
elif event == measurement.Event.START_ACCELERATOR_INIT:
89-
self._recorder.record_tpu_init_start_time(*args, **kwargs)
90-
elif event == measurement.Event.END_ACCELERATOR_INIT:
91-
self._recorder.record_tpu_init_end_time(*args, **kwargs)
92-
elif event == measurement.Event.START_TRAINING_PREPARATION:
93-
self._recorder.record_training_preparation_start_time(*args, **kwargs)
94-
elif event == measurement.Event.END_TRAINING_PREPARATION:
95-
self._recorder.record_training_preparation_end_time(*args, **kwargs)
96-
elif event == measurement.Event.START_DATA_LOADING:
97-
self._recorder.record_data_loading_start_time(*args, **kwargs)
98-
elif event == measurement.Event.END_DATA_LOADING:
99-
self._recorder.record_data_loading_end_time(*args, **kwargs)
100-
elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT:
101-
self._recorder.record_custom_badput_event_start_time(*args, **kwargs)
102-
elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT:
103-
self._recorder.record_custom_badput_event_end_time(*args, **kwargs)
104-
else:
105-
logging.log_first_n(
106-
logging.WARNING,
107-
"Ignoring unknown event %s",
108-
1,
109-
event,
109+
start_method_name = f"record_{event.value}_start_time"
110+
end_method_name = f"record_{event.value}_end_time"
111+
112+
record_event_start = getattr(self._recorder, start_method_name, None)
113+
record_event_end = getattr(self._recorder, end_method_name, None)
114+
115+
try:
116+
if record_event_start:
117+
record_event_start(*args, **kwargs)
118+
except RuntimeError as e:
119+
logging.warning(
120+
"Failed to record start of event %s. Error: %s", event.value, e, exc_info=True
110121
)
111122

112-
def start_monitoring(self, *args, **kwargs):
113-
"""Starts Monitoring of Goodput.
123+
try:
124+
yield
125+
finally:
126+
try:
127+
if record_event_end:
128+
record_event_end(*args, **kwargs)
129+
except RuntimeError as e:
130+
logging.warning(
131+
"Failed to record end of event %s. Error: %s", event.value, e, exc_info=True
132+
)
133+
134+
@contextlib.contextmanager
135+
def maybe_monitor_goodput(self, *args, **kwargs):
136+
"""Monitor cumulative goodput if enabled.
114137
115138
Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
116-
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
117-
directory.
139+
Goodput, Badput, Step & Disruption Information at the upload_interval to the
140+
specified TensorBoard directory and Google Cloud Monitoring.
118141
Note: This function requires initialization of distributed JAX before it is called.
119142
If there are internal GCP errors from querying and uploading data, these will be
120143
logged without affecting the workload. GoodputMonitor logs will provide further
@@ -123,33 +146,62 @@ def start_monitoring(self, *args, **kwargs):
123146
Default behavior is to push metrics to Google Cloud Monitoring.
124147
This behavior can be overridden by configuring `goodput_monitoring.GCPOptions`
125148
"""
126-
cfg: GoodputRecorder.Config = self.config
127-
include_step_deviation = True
128-
if jax.process_index() == 0:
149+
if jax.process_index() != 0:
150+
yield
151+
return
152+
try:
129153
if self._monitor is None:
130-
if int(cfg.step_deviation_interval_seconds) <= 0:
131-
include_step_deviation = False
132-
133-
gcp_options = goodput_monitoring.GCPOptions(
134-
enable_gcp_goodput_metrics=True,
135-
enable_gcp_step_deviation_metrics=include_step_deviation,
136-
)
137154
self._monitor = goodput_monitoring.GoodputMonitor(
138-
job_name=cfg.name,
139-
logger_name=f"goodput_logger_{cfg.name}",
140-
tensorboard_dir=cfg.upload_dir,
141-
upload_interval=int(cfg.upload_interval),
155+
job_name=self._job_name,
156+
logger_name=self._logger_name,
157+
tensorboard_dir=self.config.upload_dir,
158+
upload_interval=self.config.upload_interval,
142159
monitoring_enabled=True,
160+
pathway_enabled=self.config.jax_backend == "proxy",
143161
include_badput_breakdown=True,
144-
include_step_deviation=include_step_deviation,
145-
step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds),
146-
gcp_options=gcp_options,
147162
)
148163

149164
self._monitor.start_goodput_uploader(*args, **kwargs)
150165
logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
151-
if include_step_deviation:
152-
self._monitor.start_step_deviation_uploader(*args, **kwargs)
166+
yield
167+
finally:
168+
if self._monitor:
169+
self._monitor.stop_goodput_uploader()
170+
logging.info("Flushed final metrics and safe exited from Goodput monitoring.")
171+
172+
@contextlib.contextmanager
173+
def maybe_monitor_rolling_window_goodput(self):
174+
"""Monitor rolling window goodput if enabled."""
175+
if not self.config.rolling_window_size or jax.process_index() != 0:
176+
yield
177+
return
178+
try:
179+
if self._rolling_window_monitor is None:
180+
rolling_window_tensorboard_dir = os.path.join(
181+
self.config.upload_dir, f"rolling_window_{self.config.name}"
182+
)
183+
self._rolling_window_monitor = goodput_monitoring.GoodputMonitor(
184+
job_name=self._job_name,
185+
logger_name=self._logger_name,
186+
tensorboard_dir=rolling_window_tensorboard_dir,
187+
upload_interval=self.config.upload_interval,
188+
monitoring_enabled=True,
189+
pathway_enabled=self.config.jax_backend == "proxy",
190+
include_badput_breakdown=True,
191+
)
192+
self._rolling_window_monitor.start_rolling_window_goodput_uploader(
193+
self.config.rolling_window_size
194+
)
195+
logging.info("Started Rolling Window Goodput monitoring in the background!")
196+
yield
197+
finally:
198+
if self._rolling_window_monitor:
199+
self._rolling_window_monitor.stop_rolling_window_goodput_uploader()
153200
logging.info(
154-
"Started Step Deviation upload to Tensorboard & GCM in the background!"
201+
"Flushed final metrics and safe exited from Rolling Window Goodput monitoring."
155202
)
203+
204+
205+
def create_goodput_recorder(cfg: GoodputRecorder.Config):
206+
"""Factory method to create GoodputRecorder."""
207+
return GoodputRecorder(cfg)

0 commit comments

Comments
 (0)