Skip to content

Commit 563f9e6

Browse files
dipannita08SujeethJinesh
authored andcommitted
Integrate AXLearn with Goodput v12
1 parent a96aeb7 commit 563f9e6

File tree

3 files changed

+133
-78
lines changed

3 files changed

+133
-78
lines changed

axlearn/cloud/gcp/measurement.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
55
Example:
66
7-
# Enable Goodput when launching an AxLearn training job
7+
# Enable Goodput when launching an AXLearn training job
88
axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
99
--bundler_type=artifactregistry --bundler_spec=image=tpu \
1010
--bundler_spec=dockerfile=Dockerfile \
@@ -13,7 +13,6 @@
1313
--recorder_spec=name=my-run-with-goodput \
1414
--recorder_spec=upload_dir=my-output-directory/summaries \
1515
--recorder_spec=upload_interval=30 \
16-
--recorder_spec=enable_rolling_window_goodput_monitoring=True \
1716
--recorder_spec=rolling_window_size=86400,259200,432000
1817
1918
"""
@@ -43,22 +42,13 @@ class Config(measurement.Recorder.Config):
4342
Attributes:
4443
upload_dir: Directory to store metrics for the monitor.
4544
upload_interval: Time interval (seconds) for monitoring uploads.
46-
enable_gcp_goodput_metrics: Whether to upload metrics to Google Cloud Monitoring.
47-
enable_pathways_goodput: Whether to enable goodput calculations specific to Pathways.
48-
include_badput_breakdown: Whether to include a detailed breakdown of badput sources.
49-
enable_rolling_window_goodput_monitoring: Enables goodput/badput monitoring over
50-
rolling time windows.
5145
rolling_window_size: A sequence of integers defining the rolling window sizes in
5246
seconds.
5347
"""
5448

5549
upload_dir: Required[str] = REQUIRED
5650
upload_interval: Required[int] = REQUIRED
57-
enable_gcp_goodput_metrics: bool = True
58-
enable_pathways_goodput: bool = False
59-
include_badput_breakdown: bool = True
60-
enable_rolling_window_goodput_monitoring: bool = False
61-
rolling_window_size: Sequence[int] = ()
51+
rolling_window_size: Sequence[int] = []
6252

6353
@classmethod
6454
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
@@ -70,15 +60,8 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
7060
- upload_dir: The directory to write Tensorboard data to.
7161
- upload_interval: The time interval in seconds at which to query and upload data
7262
to Tensorboard.
73-
- enable_rolling_window_goodput_monitoring: Whether to enable rolling window Goodput
74-
monitoring.
7563
- rolling_window_size: Comma-separated list of integers representing rolling window
7664
sizes in seconds.
77-
- enable_gcp_goodput_metrics: Whether to upload Goodput metrics to GCM.
78-
- enable_pathways_goodput: Whether to enable Pathways-specific Goodput
79-
calculations.
80-
- include_badput_breakdown: Whether to include a detailed breakdown of
81-
badput events in the monitoring.
8265
"""
8366
cfg: measurement.Recorder.Config = cls.default_config()
8467
parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
@@ -90,8 +73,7 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
9073
parsed_flags["rolling_window_size"] = [
9174
int(x) for x in parsed_flags["rolling_window_size"].split(",")
9275
]
93-
cfg = maybe_set_config(cfg, **parsed_flags)
94-
return cfg.instantiate()
76+
return maybe_set_config(cfg, **parsed_flags).instantiate()
9577

9678
def __init__(self, cfg):
9779
super().__init__(cfg)
@@ -121,9 +103,9 @@ def record_event(self, event: measurement.Event, *args, **kwargs):
121103
try:
122104
if record_event_start:
123105
record_event_start(*args, **kwargs)
124-
except (TypeError, ValueError, RuntimeError) as e:
106+
except RuntimeError as e:
125107
logging.warning(
126-
"Failed to record start of event %s. Error: %s", event.name, e, exc_info=True
108+
"Failed to record start of event %s. Error: %s", event.value, e, exc_info=True
127109
)
128110

129111
try:
@@ -132,9 +114,9 @@ def record_event(self, event: measurement.Event, *args, **kwargs):
132114
try:
133115
if record_event_end:
134116
record_event_end(*args, **kwargs)
135-
except (TypeError, ValueError, RuntimeError) as e:
117+
except RuntimeError as e:
136118
logging.warning(
137-
"Failed to record end of event %s. Error: %s", event.name, e, exc_info=True
119+
"Failed to record end of event %s. Error: %s", event.value, e, exc_info=True
138120
)
139121

140122
@contextlib.contextmanager
@@ -157,18 +139,17 @@ def maybe_monitor_goodput(self, *args, **kwargs):
157139
return
158140
try:
159141
if self._monitor is None:
160-
gcp_options = goodput_monitoring.GCPOptions(
161-
enable_gcp_goodput_metrics=self.config.enable_gcp_goodput_metrics,
142+
pathways_enabled = (
143+
hasattr(flags.FLAGS, "jax_backend") and flags.FLAGS.jax_backend == "proxy"
162144
)
163145
self._monitor = goodput_monitoring.GoodputMonitor(
164146
job_name=self.config.name,
165147
logger_name=f"goodput_logger_{self.config.name}",
166148
tensorboard_dir=self.config.upload_dir,
167149
upload_interval=self.config.upload_interval,
168150
monitoring_enabled=True,
169-
pathway_enabled=self.config.enable_pathways_goodput,
170-
include_badput_breakdown=self.config.include_badput_breakdown,
171-
gcp_options=gcp_options,
151+
pathway_enabled=pathways_enabled,
152+
include_badput_breakdown=True,
172153
)
173154

174155
self._monitor.start_goodput_uploader(*args, **kwargs)
@@ -182,11 +163,14 @@ def maybe_monitor_goodput(self, *args, **kwargs):
182163
@contextlib.contextmanager
183164
def maybe_monitor_rolling_window_goodput(self):
184165
"""Monitor rolling window goodput if enabled."""
185-
if not self.config.enable_rolling_window_goodput_monitoring or jax.process_index() != 0:
166+
if not self.config.rolling_window_size or jax.process_index() != 0:
186167
yield
187168
return
188169
try:
189170
if self._rolling_window_monitor is None:
171+
pathways_enabled = (
172+
hasattr(flags.FLAGS, "jax_backend") and flags.FLAGS.jax_backend == "proxy"
173+
)
190174
rolling_window_tensorboard_dir = os.path.join(
191175
self.config.upload_dir, f"rolling_window_{self.config.name}"
192176
)
@@ -196,7 +180,7 @@ def maybe_monitor_rolling_window_goodput(self):
196180
tensorboard_dir=rolling_window_tensorboard_dir,
197181
upload_interval=self.config.upload_interval,
198182
monitoring_enabled=True,
199-
pathway_enabled=self.config.enable_pathways_goodput,
183+
pathway_enabled=pathways_enabled,
200184
include_badput_breakdown=True,
201185
)
202186
self._rolling_window_monitor.start_rolling_window_goodput_uploader(

axlearn/cloud/gcp/measurement_test.py

Lines changed: 113 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,59 +16,90 @@
1616
class GoodputRecorderTest(parameterized.TestCase):
1717
"""Tests GoodputRecorder."""
1818

19-
def test_from_flags_with_spec(self):
20-
"""Tests that flags are correctly parsed."""
21-
fv = flags.FlagValues()
22-
measurement.define_flags(flag_values=fv)
23-
fv.set_default(
24-
"recorder_spec",
25-
[
19+
def setUp(self):
20+
super().setUp()
21+
self.mock_flags = mock.MagicMock(spec=flags.FlagValues)
22+
self.mock_flags.jax_backend = None
23+
self.patch_flags_global = mock.patch.object(flags, "FLAGS", new=self.mock_flags)
24+
self.patch_flags_global.start()
25+
26+
def tearDown(self):
27+
super().tearDown()
28+
self.patch_flags_global.stop()
29+
30+
@parameterized.parameters(
31+
dict(
32+
recorder_spec=[
33+
"name=test-name",
34+
"upload_dir=/test/path",
35+
"upload_interval=15",
36+
],
37+
expected_rolling_window_size=[],
38+
),
39+
dict(
40+
recorder_spec=[
2641
"name=test-name",
2742
"upload_dir=/test/path",
2843
"upload_interval=15",
29-
"enable_rolling_window_goodput_monitoring=True",
3044
"rolling_window_size=1,2,3",
3145
],
32-
)
33-
fv.mark_as_parsed()
34-
recorder = GoodputRecorder.from_flags(fv)
46+
expected_rolling_window_size=[1, 2, 3],
47+
),
48+
)
49+
def test_from_flags(
50+
self,
51+
recorder_spec,
52+
expected_rolling_window_size,
53+
):
54+
"""Tests that flags are correctly parsed into the config."""
55+
mock_fv = mock.MagicMock(spec=flags.FlagValues)
56+
mock_fv.recorder_spec = recorder_spec
57+
mock_fv.jax_backend = "tpu"
58+
59+
recorder = GoodputRecorder.from_flags(mock_fv)
60+
3561
self.assertEqual("test-name", recorder.config.name)
3662
self.assertEqual("/test/path", recorder.config.upload_dir)
3763
self.assertEqual(15, recorder.config.upload_interval)
38-
self.assertTrue(recorder.config.enable_rolling_window_goodput_monitoring)
39-
self.assertEqual([1, 2, 3], recorder.config.rolling_window_size)
64+
self.assertEqual(expected_rolling_window_size, recorder.config.rolling_window_size)
4065

4166
def test_from_flags_missing_required(self):
4267
"""Tests that missing required flags raise an error."""
43-
fv = flags.FlagValues()
44-
measurement.define_flags(flag_values=fv)
45-
fv.set_default("recorder_spec", ["name=test-name"]) # Missing upload_dir/interval
46-
fv.mark_as_parsed()
68+
mock_fv = mock.MagicMock(spec=flags.FlagValues)
69+
mock_fv.recorder_spec = ["name=test-name"] # Missing upload_dir/interval
70+
mock_fv.jax_backend = "tpu"
4771
with self.assertRaisesRegex(RequiredFieldMissingError, "upload_dir"):
48-
GoodputRecorder.from_flags(fv)
72+
GoodputRecorder.from_flags(mock_fv)
4973

5074
@mock.patch("jax.process_index", return_value=0)
5175
def test_record_event_context_manager(self, _):
5276
"""Tests the record_event context manager."""
53-
recorder = GoodputRecorder(GoodputRecorder.default_config().set(name="test"))
54-
with mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder:
55-
mock_instance = mock_recorder.return_value
77+
cfg = GoodputRecorder.default_config().set(
78+
name="test",
79+
upload_dir="/tmp/test",
80+
upload_interval=1,
81+
)
82+
recorder = GoodputRecorder(cfg)
83+
with mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls:
84+
mock_instance = mock_recorder_cls.return_value
5685
with recorder.record_event(measurement.Event.JOB):
5786
pass
58-
mock_recorder.assert_called_once()
87+
mock_recorder_cls.assert_called_once()
5988
mock_instance.record_job_start_time.assert_called_once()
6089
mock_instance.record_job_end_time.assert_called_once()
6190

91+
@parameterized.parameters(
92+
dict(is_pathways_job=False, mock_jax_backend="tpu"),
93+
dict(is_pathways_job=True, mock_jax_backend="proxy"),
94+
)
6295
@mock.patch("jax.process_index", return_value=0)
63-
def test_maybe_monitor_goodput(self, _):
96+
def test_maybe_monitor_goodput(self, _, is_pathways_job, mock_jax_backend):
6497
"""Tests the maybe_monitor_goodput context manager."""
98+
self.mock_flags.jax_backend = mock_jax_backend
6599
cfg = GoodputRecorder.default_config().set(
66100
name="test-monitor",
67101
upload_dir="/test",
68102
upload_interval=30,
69-
enable_gcp_goodput_metrics=True,
70-
enable_pathways_goodput=False,
71-
include_badput_breakdown=True,
72103
)
73104
recorder = GoodputRecorder(cfg)
74105

@@ -84,47 +115,82 @@ def test_maybe_monitor_goodput(self, _):
84115
tensorboard_dir="/test",
85116
upload_interval=30,
86117
monitoring_enabled=True,
87-
pathway_enabled=False,
118+
pathway_enabled=is_pathways_job,
88119
include_badput_breakdown=True,
89-
gcp_options=mock.ANY,
90120
)
91-
# Verify the start and stop methods were called.
92121
mock_monitor_instance.start_goodput_uploader.assert_called_once()
93122
mock_monitor_instance.stop_goodput_uploader.assert_called_once()
94123

124+
@parameterized.parameters(
125+
dict(
126+
is_rolling_window_enabled=True,
127+
rolling_window_size=[10, 20],
128+
is_pathways_job=False,
129+
mock_jax_backend="tpu",
130+
),
131+
dict(
132+
is_rolling_window_enabled=False,
133+
rolling_window_size=[],
134+
is_pathways_job=False,
135+
mock_jax_backend="tpu",
136+
),
137+
dict(
138+
is_rolling_window_enabled=True,
139+
rolling_window_size=[50],
140+
is_pathways_job=True,
141+
mock_jax_backend="proxy",
142+
),
143+
)
95144
@mock.patch("jax.process_index", return_value=0)
96-
def test_maybe_monitor_rolling_window(self, _):
145+
def test_maybe_monitor_rolling_window(
146+
self,
147+
mock_process_index,
148+
is_rolling_window_enabled,
149+
rolling_window_size,
150+
is_pathways_job,
151+
mock_jax_backend,
152+
): # pylint: disable=unused-argument
97153
"""Tests the rolling window monitoring context manager."""
154+
self.mock_flags.jax_backend = mock_jax_backend
98155
cfg = GoodputRecorder.default_config().set(
99156
name="test-rolling",
100157
upload_dir="/test",
101158
upload_interval=30,
102-
enable_rolling_window_goodput_monitoring=True,
103-
rolling_window_size=[10, 20],
159+
rolling_window_size=rolling_window_size,
104160
)
105161
recorder = GoodputRecorder(cfg)
106162

107163
with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls:
108164
mock_monitor_instance = mock_monitor_cls.return_value
165+
if not is_rolling_window_enabled:
166+
with recorder.maybe_monitor_rolling_window_goodput():
167+
pass
168+
mock_monitor_cls.assert_not_called()
169+
return
109170
with recorder.maybe_monitor_rolling_window_goodput():
110171
pass
111172

112-
# Verify that GoodputMonitor was instantiated for rolling window.
113-
mock_monitor_cls.assert_called_once()
114-
self.assertEqual(
115-
"/test/rolling_window_test-rolling",
116-
mock_monitor_cls.call_args.kwargs["tensorboard_dir"],
173+
mock_monitor_cls.assert_called_once_with(
174+
job_name="test-rolling",
175+
logger_name="goodput_logger_test-rolling",
176+
tensorboard_dir="/test/rolling_window_test-rolling",
177+
upload_interval=30,
178+
monitoring_enabled=True,
179+
pathway_enabled=is_pathways_job,
180+
include_badput_breakdown=True,
117181
)
118182

119-
# Verify the correct the start and stop methods were called.
120-
mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_with([10, 20])
183+
mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_with(
184+
rolling_window_size
185+
)
121186
mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once()
122187

123188
@mock.patch("jax.process_index", return_value=1)
124189
def test_non_zero_process_index_skips_monitoring(
125190
self, mock_process_index
126191
): # pylint: disable=unused-argument
127192
"""Tests that monitoring is skipped on non-zero process indices."""
193+
self.mock_flags.jax_backend = "tpu"
128194
cfg = GoodputRecorder.default_config().set(
129195
name="test", upload_dir="/test", upload_interval=30
130196
)
@@ -136,8 +202,13 @@ def test_non_zero_process_index_skips_monitoring(
136202
pass
137203
mock_monitor_cls.assert_not_called()
138204

139-
# Test maybe_monitor_rolling_window_goodput
140-
recorder.config.enable_rolling_window_goodput_monitoring = True
141-
with recorder.maybe_monitor_rolling_window_goodput():
205+
cfg_rolling = GoodputRecorder.default_config().set(
206+
name="test-rolling-skip",
207+
upload_dir="/test",
208+
upload_interval=30,
209+
rolling_window_size=[10, 20],
210+
)
211+
recorder_rolling = GoodputRecorder(cfg_rolling)
212+
with recorder_rolling.maybe_monitor_rolling_window_goodput():
142213
pass
143214
mock_monitor_cls.assert_not_called()

docs/05-Goodput-Monitoring.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ML Goodput Monitoring
2-
AxLearn supports automatic measurement and upload of a wide range of workload
2+
AXLearn supports automatic measurement and upload of a wide range of workload
33
metrics using the **ML Goodput Measurement** library. This includes:
44
* **Goodput** and **Badput Breakdown**
55
* **Step Metrics** (Ideal Step Time, Step Time Deviation, Last Productive Step etc.)
@@ -97,7 +97,7 @@ Goodput/Badput metrics of a previous workload along with your current workload.
9797

9898
### How to Monitor Cumulative Goodput Metrics
9999

100-
To enable Goodput recording and monitoring on AxLearn, follow the example below.
100+
To enable Goodput recording and monitoring on AXLearn, follow the example below.
101101

102102

103103
```bash
@@ -133,11 +133,11 @@ axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
133133
### Visualize on Tensorboard
134134

135135
1. Requires packages: `tensorboard-plugin-profile`, `tensorflow` and `tensorboard`.
136-
2. Use the Tensorboard URL on AxLearn logs to view all metrics in one location.
136+
2. Use the Tensorboard URL on AXLearn logs to view all metrics in one location.
137137

138138
### Enabling Google Cloud Monitoring
139139

140-
By default, when Goodput monitoring is enabled via the recorder, AxLearn automatically pushes metrics to Google Cloud Monitoring.
140+
By default, when Goodput monitoring is enabled via the recorder, AXLearn automatically pushes metrics to Google Cloud Monitoring.
141141

142142
- **Cumulative Metrics** are enabled by default when you specify the `recorder_type`.
143143
To disable this, you would need to set `enable_gcp_goodput_metrics` to `False` in

0 commit comments

Comments
 (0)