22
33"""Measurement utils for GCP.
44
5+ For detailed documentation and advanced usage, please refer to:
6+ axlearn/docs/05-Goodput-Monitoring.md
7+
58 Example:
69
710 # Enable Goodput when launching an AXLearn training job
1316 --recorder_spec=name=my-run-with-goodput \
1417 --recorder_spec=upload_dir=my-output-directory/summaries \
1518 --recorder_spec=upload_interval=30 \
16- --recorder_spec=step_deviation_interval_seconds=30
19+ --recorder_spec=rolling_window_size=86400,259200,432000
1720
1821"""
1922
23+ import contextlib
24+ import os
25+ from typing import Optional , Sequence
26+
2027import jax
2128from absl import flags , logging
2229from ml_goodput_measurement import goodput
@@ -38,13 +45,19 @@ class Config(measurement.Recorder.Config):
3845 Attributes:
3946 upload_dir: Directory to store metrics for the monitor.
4047 upload_interval: Time interval (seconds) for monitoring uploads.
41- step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
42- uploads. -1 to disable step deviation uploads.
48+ See "How to Monitor Cumulative Goodput Metrics" in
49+ docs/05-Goodput-Monitoring.md for more details.
50+ rolling_window_size: A sequence of integers defining the rolling window sizes in
51+ seconds.
52+ See "How to Monitor Rolling Window Goodput Metrics" in
53+ docs/05-Goodput-Monitoring.md for more details.
54+ jax_backend: JAX backend type. The value "proxy" marks the job as running on Pathways.
4355 """
4456
4557 upload_dir : Required [str ] = REQUIRED
4658 upload_interval : Required [int ] = REQUIRED
47- step_deviation_interval_seconds : int = 30 # Default to 30 seconds
59+ rolling_window_size : Sequence [int ] = []
60+ jax_backend : Optional [str ] = None
4861
@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
    """Builds a GoodputRecorder from the `recorder_spec` flag.

    `fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
    corresponding to keys will be set to the corresponding values. A GoodputRecorder can
    additionally take in following Tensorboard configs in the recorder_spec:
     - upload_dir: The directory to write Tensorboard data to.
     - upload_interval: The time interval in seconds at which to query and upload data
       to Tensorboard.
     - rolling_window_size: Comma-separated list of integers representing rolling window
       sizes in seconds. Blank entries (an empty value or a stray trailing comma) are
       ignored, so `rolling_window_size=` yields an empty list.
     - jax_backend: The type of jax backend.

    Args:
        fv: Parsed flag values; only `fv.recorder_spec` is read.

    Returns:
        The instantiated recorder.

    Raises:
        ValueError: If `upload_interval` or a non-blank rolling window entry is not an
            integer literal.
    """
    cfg: measurement.Recorder.Config = cls.default_config()
    parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
    # Flag values arrive as strings; convert the numeric ones before setting the config.
    if "upload_interval" in parsed_flags:
        parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"])
    if "rolling_window_size" in parsed_flags and isinstance(
        parsed_flags["rolling_window_size"], str
    ):
        # Skip blank tokens so "" and "86400," do not fail on int("").
        parsed_flags["rolling_window_size"] = [
            int(token)
            for token in parsed_flags["rolling_window_size"].split(",")
            if token.strip()
        ]
    return maybe_set_config(cfg, **parsed_flags).instantiate()
6587
def __init__(self, cfg):
    """Stores naming information and leaves the recorder and monitors unset.

    Args:
        cfg: The recorder config; `cfg.name` is used for the job and logger names.
    """
    super().__init__(cfg)
    # Names handed to the goodput recorder and to both monitors.
    self._logger_name = f"goodput_logger_{cfg.name}"
    self._job_name = cfg.name
    # None until first use; see the methods that populate them.
    self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None
    self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None
    self._recorder: Optional[goodput.GoodputRecorder] = None
95+
@contextlib.contextmanager
def record_event(self, event: measurement.Event, *args, **kwargs):
    """Brackets the wrapped code with start/end goodput records for `event`.

    The recorder methods are looked up by name as `record_<event.value>_start_time` and
    `record_<event.value>_end_time`; a method that does not exist is skipped. A
    RuntimeError raised by either call is logged as a warning and not propagated. The
    end record is attempted even if the wrapped code raises.

    Args:
        event: The event to record; its `value` selects the recorder methods.
        *args: Positional arguments forwarded to both recorder methods.
        **kwargs: Keyword arguments forwarded to both recorder methods.
    """
    if self._recorder is None:
        # The recorder is built on first use rather than in __init__.
        if jax.process_index() == 0:
            logging.info("Lazily instantiating goodput recorder.")
        self._recorder = goodput.GoodputRecorder(
            job_name=self._job_name,
            logger_name=self._logger_name,
            logging_enabled=(jax.process_index() == 0),
        )

    on_enter = getattr(self._recorder, f"record_{event.value}_start_time", None)
    on_exit = getattr(self._recorder, f"record_{event.value}_end_time", None)

    def _invoke(hook, failure_msg):
        # Calls `hook` if present, downgrading RuntimeError to a warning.
        try:
            if hook:
                hook(*args, **kwargs)
        except RuntimeError as err:
            logging.warning(failure_msg, event.value, err, exc_info=True)

    _invoke(on_enter, "Failed to record start of event %s. Error: %s")
    try:
        yield
    finally:
        _invoke(on_exit, "Failed to record end of event %s. Error: %s")
133+
@contextlib.contextmanager
def maybe_monitor_goodput(self, *args, **kwargs):
    """Runs the cumulative goodput uploader around the wrapped code on process 0.

    On any process other than index 0 this yields without doing anything. On process 0
    a `goodput_monitoring.GoodputMonitor` is created on first use, its goodput uploader
    is started before the wrapped code runs, and the uploader is stopped when the
    context exits, including on error.

    Note: `jax.process_index()` is called on entry, so distributed JAX is expected to be
    initialized beforehand.

    Args:
        *args: Positional arguments forwarded to `start_goodput_uploader`.
        **kwargs: Keyword arguments forwarded to `start_goodput_uploader`.
    """
    if jax.process_index() != 0:
        yield
        return

    try:
        if self._monitor is None:
            monitor_kwargs = dict(
                job_name=self._job_name,
                logger_name=self._logger_name,
                tensorboard_dir=self.config.upload_dir,
                upload_interval=self.config.upload_interval,
                monitoring_enabled=True,
                # A "proxy" backend marks the job as Pathways-enabled.
                pathway_enabled=self.config.jax_backend == "proxy",
                include_badput_breakdown=True,
            )
            self._monitor = goodput_monitoring.GoodputMonitor(**monitor_kwargs)

        self._monitor.start_goodput_uploader(*args, **kwargs)
        logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
        yield
    finally:
        # The monitor may still be unset if construction itself failed.
        if self._monitor:
            self._monitor.stop_goodput_uploader()
            logging.info("Flushed final metrics and safe exited from Goodput monitoring.")
171+
@contextlib.contextmanager
def maybe_monitor_rolling_window_goodput(self):
    """Runs the rolling window goodput uploader around the wrapped code if enabled.

    Yields without doing anything when `rolling_window_size` is empty or when the
    current process index is not 0. Otherwise a dedicated
    `goodput_monitoring.GoodputMonitor` is created on first use, writing to the
    `rolling_window_<name>` subdirectory of `upload_dir`; its rolling window uploader is
    started with the configured window sizes and stopped when the context exits.
    """
    disabled = not self.config.rolling_window_size or jax.process_index() != 0
    if disabled:
        yield
        return

    try:
        if self._rolling_window_monitor is None:
            # Kept in its own subdirectory, apart from the cumulative metrics.
            rolling_window_tensorboard_dir = os.path.join(
                self.config.upload_dir, f"rolling_window_{self.config.name}"
            )
            monitor_kwargs = dict(
                job_name=self._job_name,
                logger_name=self._logger_name,
                tensorboard_dir=rolling_window_tensorboard_dir,
                upload_interval=self.config.upload_interval,
                monitoring_enabled=True,
                pathway_enabled=self.config.jax_backend == "proxy",
                include_badput_breakdown=True,
            )
            self._rolling_window_monitor = goodput_monitoring.GoodputMonitor(**monitor_kwargs)
        window_sizes = self.config.rolling_window_size
        self._rolling_window_monitor.start_rolling_window_goodput_uploader(window_sizes)
        logging.info("Started Rolling Window Goodput monitoring in the background!")
        yield
    finally:
        if self._rolling_window_monitor:
            self._rolling_window_monitor.stop_rolling_window_goodput_uploader()
            logging.info(
                "Flushed final metrics and safe exited from Rolling Window Goodput monitoring."
            )
203+
204+
def create_goodput_recorder(cfg: GoodputRecorder.Config):
    """Factory that constructs a GoodputRecorder from its config.

    Args:
        cfg: The recorder config passed to the GoodputRecorder constructor.

    Returns:
        The newly constructed GoodputRecorder.
    """
    recorder = GoodputRecorder(cfg)
    return recorder
0 commit comments