Skip to content

Commit a96aeb7

Browse files
dipannita08 authored and SujeethJinesh committed
Add workload hang monitoring & rolling window goodput support
1 parent 684850d commit a96aeb7

File tree

9 files changed

+658
-516
lines changed

9 files changed

+658
-516
lines changed

axlearn/cloud/gcp/measurement.py

Lines changed: 132 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
55
Example:
66
7-
# Enable Goodput when launching an AXLearn training job
7+
# Enable Goodput when launching an AxLearn training job
88
axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
99
--bundler_type=artifactregistry --bundler_spec=image=tpu \
1010
--bundler_spec=dockerfile=Dockerfile \
@@ -13,10 +13,15 @@
1313
--recorder_spec=name=my-run-with-goodput \
1414
--recorder_spec=upload_dir=my-output-directory/summaries \
1515
--recorder_spec=upload_interval=30 \
16-
--recorder_spec=step_deviation_interval_seconds=30
16+
--recorder_spec=enable_rolling_window_goodput_monitoring=True \
17+
--recorder_spec=rolling_window_size=86400,259200,432000
1718
1819
"""
1920

21+
import contextlib
22+
import os
23+
from typing import Optional, Sequence
24+
2025
import jax
2126
from absl import flags, logging
2227
from ml_goodput_measurement import goodput
@@ -38,13 +43,22 @@ class Config(measurement.Recorder.Config):
3843
Attributes:
3944
upload_dir: Directory to store metrics for the monitor.
4045
upload_interval: Time interval (seconds) for monitoring uploads.
41-
step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
42-
uploads. -1 to disable step deviation uploads.
46+
enable_gcp_goodput_metrics: Whether to upload metrics to Google Cloud Monitoring.
47+
enable_pathways_goodput: Whether to enable goodput calculations specific to Pathways.
48+
include_badput_breakdown: Whether to include a detailed breakdown of badput sources.
49+
enable_rolling_window_goodput_monitoring: Enables goodput/badput monitoring over
50+
rolling time windows.
51+
rolling_window_size: A sequence of integers defining the rolling window sizes in
52+
seconds.
4353
"""
4454

4555
upload_dir: Required[str] = REQUIRED
4656
upload_interval: Required[int] = REQUIRED
47-
step_deviation_interval_seconds: int = 30 # Default to 30 seconds
57+
enable_gcp_goodput_metrics: bool = True
58+
enable_pathways_goodput: bool = False
59+
include_badput_breakdown: bool = True
60+
enable_rolling_window_goodput_monitoring: bool = False
61+
rolling_window_size: Sequence[int] = ()
4862

4963
@classmethod
5064
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
@@ -53,68 +67,83 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
5367
`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
5468
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
5569
additionally take in the following Tensorboard configs in the recorder_spec:
56-
- upload_dir: The directory to write Tensorboard data to.
57-
- upload_interval: The time interval in seconds at which to query and upload data
58-
to Tensorboard.
59-
- step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
60-
uploads. Set to less than or equal to 0 to disable step deviation uploads.
70+
- upload_dir: The directory to write Tensorboard data to.
71+
- upload_interval: The time interval in seconds at which to query and upload data
72+
to Tensorboard.
73+
- enable_rolling_window_goodput_monitoring: Whether to enable rolling window Goodput
74+
monitoring.
75+
- rolling_window_size: Comma-separated list of integers representing rolling window
76+
sizes in seconds.
77+
- enable_gcp_goodput_metrics: Whether to upload Goodput metrics to GCM.
78+
- enable_pathways_goodput: Whether to enable Pathways-specific Goodput
79+
calculations.
80+
- include_badput_breakdown: Whether to include a detailed breakdown of
81+
badput events in the monitoring.
6182
"""
6283
cfg: measurement.Recorder.Config = cls.default_config()
63-
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
84+
parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
85+
if "upload_interval" in parsed_flags:
86+
parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"])
87+
if "rolling_window_size" in parsed_flags and isinstance(
88+
parsed_flags["rolling_window_size"], str
89+
):
90+
parsed_flags["rolling_window_size"] = [
91+
int(x) for x in parsed_flags["rolling_window_size"].split(",")
92+
]
93+
cfg = maybe_set_config(cfg, **parsed_flags)
6494
return cfg.instantiate()
6595

6696
def __init__(self, cfg):
6797
super().__init__(cfg)
68-
cfg: GoodputRecorder.Config = self.config
69-
self._recorder = None
70-
self._monitor = None
71-
72-
def record(self, event: measurement.Event, *args, **kwargs):
73-
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
98+
self._recorder: Optional[goodput.GoodputRecorder] = None
99+
self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None
100+
self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None
101+
102+
@contextlib.contextmanager
103+
def record_event(self, event: measurement.Event, *args, **kwargs):
104+
"""Records a goodput event using a context manager."""
105+
# Lazily instantiate the recorder if it hasn't been already.
74106
if self._recorder is None:
75-
cfg: GoodputRecorder.Config = self.config
107+
if jax.process_index() == 0:
108+
logging.info("Lazily instantiating goodput recorder.")
76109
self._recorder = goodput.GoodputRecorder(
77-
job_name=cfg.name,
78-
logger_name=f"goodput_logger_{cfg.name}",
110+
job_name=self.config.name,
111+
logger_name=f"goodput_logger_{self.config.name}",
79112
logging_enabled=(jax.process_index() == 0),
80113
)
81114

82-
if event == measurement.Event.START_JOB:
83-
self._recorder.record_job_start_time(*args, **kwargs)
84-
elif event == measurement.Event.END_JOB:
85-
self._recorder.record_job_end_time(*args, **kwargs)
86-
elif event == measurement.Event.START_STEP:
87-
self._recorder.record_step_start_time(*args, **kwargs)
88-
elif event == measurement.Event.START_ACCELERATOR_INIT:
89-
self._recorder.record_tpu_init_start_time(*args, **kwargs)
90-
elif event == measurement.Event.END_ACCELERATOR_INIT:
91-
self._recorder.record_tpu_init_end_time(*args, **kwargs)
92-
elif event == measurement.Event.START_TRAINING_PREPARATION:
93-
self._recorder.record_training_preparation_start_time(*args, **kwargs)
94-
elif event == measurement.Event.END_TRAINING_PREPARATION:
95-
self._recorder.record_training_preparation_end_time(*args, **kwargs)
96-
elif event == measurement.Event.START_DATA_LOADING:
97-
self._recorder.record_data_loading_start_time(*args, **kwargs)
98-
elif event == measurement.Event.END_DATA_LOADING:
99-
self._recorder.record_data_loading_end_time(*args, **kwargs)
100-
elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT:
101-
self._recorder.record_custom_badput_event_start_time(*args, **kwargs)
102-
elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT:
103-
self._recorder.record_custom_badput_event_end_time(*args, **kwargs)
104-
else:
105-
logging.log_first_n(
106-
logging.WARNING,
107-
"Ignoring unknown event %s",
108-
1,
109-
event,
115+
start_method_name = f"record_{event.value}_start_time"
116+
end_method_name = f"record_{event.value}_end_time"
117+
118+
record_event_start = getattr(self._recorder, start_method_name, None)
119+
record_event_end = getattr(self._recorder, end_method_name, None)
120+
121+
try:
122+
if record_event_start:
123+
record_event_start(*args, **kwargs)
124+
except (TypeError, ValueError, RuntimeError) as e:
125+
logging.warning(
126+
"Failed to record start of event %s. Error: %s", event.name, e, exc_info=True
110127
)
111128

112-
def start_monitoring(self, *args, **kwargs):
113-
"""Starts Monitoring of Goodput.
129+
try:
130+
yield
131+
finally:
132+
try:
133+
if record_event_end:
134+
record_event_end(*args, **kwargs)
135+
except (TypeError, ValueError, RuntimeError) as e:
136+
logging.warning(
137+
"Failed to record end of event %s. Error: %s", event.name, e, exc_info=True
138+
)
139+
140+
@contextlib.contextmanager
141+
def maybe_monitor_goodput(self, *args, **kwargs):
142+
"""Monitor cumulative goodput if enabled.
114143
115144
Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
116-
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
117-
directory.
145+
Goodput, Badput, Step & Disruption Information at the upload_interval and upload them to
146+
the specified TensorBoard directory and Google Cloud Monitoring.
118147
Note: This function requires initialization of distributed JAX before it is called.
119148
If there are internal GCP errors from querying and uploading data, these will be
120149
logged without affecting the workload. GoodputMonitor logs will provide further
@@ -123,33 +152,66 @@ def start_monitoring(self, *args, **kwargs):
123152
Default behavior is to push metrics to Google Cloud Monitoring.
124153
This behavior can be overridden by configuring `goodput_monitoring.GCPOptions`.
125154
"""
126-
cfg: GoodputRecorder.Config = self.config
127-
include_step_deviation = True
128-
if jax.process_index() == 0:
155+
if jax.process_index() != 0:
156+
yield
157+
return
158+
try:
129159
if self._monitor is None:
130-
if int(cfg.step_deviation_interval_seconds) <= 0:
131-
include_step_deviation = False
132-
133160
gcp_options = goodput_monitoring.GCPOptions(
134-
enable_gcp_goodput_metrics=True,
135-
enable_gcp_step_deviation_metrics=include_step_deviation,
161+
enable_gcp_goodput_metrics=self.config.enable_gcp_goodput_metrics,
136162
)
137163
self._monitor = goodput_monitoring.GoodputMonitor(
138-
job_name=cfg.name,
139-
logger_name=f"goodput_logger_{cfg.name}",
140-
tensorboard_dir=cfg.upload_dir,
141-
upload_interval=int(cfg.upload_interval),
164+
job_name=self.config.name,
165+
logger_name=f"goodput_logger_{self.config.name}",
166+
tensorboard_dir=self.config.upload_dir,
167+
upload_interval=self.config.upload_interval,
142168
monitoring_enabled=True,
143-
include_badput_breakdown=True,
144-
include_step_deviation=include_step_deviation,
145-
step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds),
169+
pathway_enabled=self.config.enable_pathways_goodput,
170+
include_badput_breakdown=self.config.include_badput_breakdown,
146171
gcp_options=gcp_options,
147172
)
148173

149174
self._monitor.start_goodput_uploader(*args, **kwargs)
150175
logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
151-
if include_step_deviation:
152-
self._monitor.start_step_deviation_uploader(*args, **kwargs)
176+
yield
177+
finally:
178+
if self._monitor:
179+
self._monitor.stop_goodput_uploader()
180+
logging.info("Flushed final metrics and safe exited from Goodput monitoring.")
181+
182+
@contextlib.contextmanager
183+
def maybe_monitor_rolling_window_goodput(self):
184+
"""Monitor rolling window goodput if enabled."""
185+
if not self.config.enable_rolling_window_goodput_monitoring or jax.process_index() != 0:
186+
yield
187+
return
188+
try:
189+
if self._rolling_window_monitor is None:
190+
rolling_window_tensorboard_dir = os.path.join(
191+
self.config.upload_dir, f"rolling_window_{self.config.name}"
192+
)
193+
self._rolling_window_monitor = goodput_monitoring.GoodputMonitor(
194+
job_name=self.config.name,
195+
logger_name=f"goodput_logger_{self.config.name}",
196+
tensorboard_dir=rolling_window_tensorboard_dir,
197+
upload_interval=self.config.upload_interval,
198+
monitoring_enabled=True,
199+
pathway_enabled=self.config.enable_pathways_goodput,
200+
include_badput_breakdown=True,
201+
)
202+
self._rolling_window_monitor.start_rolling_window_goodput_uploader(
203+
self.config.rolling_window_size
204+
)
205+
logging.info("Started Rolling Window Goodput monitoring in the background!")
206+
yield
207+
finally:
208+
if self._rolling_window_monitor:
209+
self._rolling_window_monitor.stop_rolling_window_goodput_uploader()
153210
logging.info(
154-
"Started Step Deviation upload to Tensorboard & GCM in the background!"
211+
"Flushed final metrics and safe exited from Rolling Window Goodput monitoring."
155212
)
213+
214+
215+
def create_goodput_recorder(cfg: GoodputRecorder.Config):
216+
"""Factory method to create GoodputRecorder."""
217+
return GoodputRecorder(cfg)

0 commit comments

Comments
 (0)