44
55 Example:
66
7- # Enable Goodput when launching an AXLearn training job
7+ # Enable Goodput when launching an AxLearn training job
88 axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
99 --bundler_type=artifactregistry --bundler_spec=image=tpu \
1010 --bundler_spec=dockerfile=Dockerfile \
1313 --recorder_spec=name=my-run-with-goodput \
1414 --recorder_spec=upload_dir=my-output-directory/summaries \
1515 --recorder_spec=upload_interval=30 \
16- --recorder_spec=step_deviation_interval_seconds=30
16+ --recorder_spec=enable_rolling_window_goodput_monitoring=True \
17+ --recorder_spec=rolling_window_size=86400,259200,432000
1718
1819"""
1920
21+ import contextlib
22+ import os
23+ from typing import Optional , Sequence
24+
2025import jax
2126from absl import flags , logging
2227from ml_goodput_measurement import goodput
@@ -38,13 +43,22 @@ class Config(measurement.Recorder.Config):
3843 Attributes:
3944 upload_dir: Directory to store metrics for the monitor.
4045 upload_interval: Time interval (seconds) for monitoring uploads.
41- step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
42- uploads. -1 to disable step deviation uploads.
46+ enable_gcp_goodput_metrics: Whether to upload metrics to Google Cloud Monitoring.
47+ enable_pathways_goodput: Whether to enable goodput calculations specific to Pathways.
48+ include_badput_breakdown: Whether to include a detailed breakdown of badput sources.
49+ enable_rolling_window_goodput_monitoring: Enables goodput/badput monitoring over
50+ rolling time windows.
51+ rolling_window_size: A sequence of integers defining the rolling window sizes in
52+ seconds.
4353 """
4454
4555 upload_dir : Required [str ] = REQUIRED
4656 upload_interval : Required [int ] = REQUIRED
47- step_deviation_interval_seconds : int = 30 # Default to 30 seconds
57+ enable_gcp_goodput_metrics : bool = True
58+ enable_pathways_goodput : bool = False
59+ include_badput_breakdown : bool = True
60+ enable_rolling_window_goodput_monitoring : bool = False
61+ rolling_window_size : Sequence [int ] = ()
4862
@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
    """Builds a GoodputRecorder from the `key=value` pairs in `fv.recorder_spec`.

    Config names corresponding to keys will be set to the corresponding values. Values
    parsed from the command line are strings, so they are coerced here to the types
    declared on `Config` before being applied. The following keys are recognized:
        - upload_dir: The directory to write Tensorboard data to.
        - upload_interval: The time interval in seconds at which to query and upload data
            to Tensorboard. Parsed as an integer.
        - enable_rolling_window_goodput_monitoring: Whether to enable rolling window
            Goodput monitoring. Parsed as a boolean.
        - rolling_window_size: Comma-separated list of integers representing rolling
            window sizes in seconds.
        - enable_gcp_goodput_metrics: Whether to upload Goodput metrics to GCM. Parsed as
            a boolean.
        - enable_pathways_goodput: Whether to enable Pathways-specific Goodput
            calculations. Parsed as a boolean.
        - include_badput_breakdown: Whether to include a detailed breakdown of badput
            events in the monitoring. Parsed as a boolean.

    Args:
        fv: Parsed flag values; only `fv.recorder_spec` is read.

    Returns:
        The instantiated recorder.

    Raises:
        ValueError: If an integer or boolean value cannot be parsed.
    """
    true_literals = ("true", "t", "yes", "y", "1")
    false_literals = ("false", "f", "no", "n", "0")
    bool_keys = (
        "enable_gcp_goodput_metrics",
        "enable_pathways_goodput",
        "include_badput_breakdown",
        "enable_rolling_window_goodput_monitoring",
    )

    def to_bool(key: str, raw: str) -> bool:
        # A non-empty string such as "False" is truthy, so the text must be interpreted
        # explicitly instead of being handed to the config as-is.
        text = raw.strip().lower()
        if text in true_literals:
            return True
        if text in false_literals:
            return False
        raise ValueError(f"Invalid boolean value {raw!r} for recorder_spec key {key!r}.")

    cfg: measurement.Recorder.Config = cls.default_config()
    parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")

    if "upload_interval" in parsed_flags:
        parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"])

    window_sizes = parsed_flags.get("rolling_window_size")
    if isinstance(window_sizes, str):
        # Skip empty pieces so that an empty value or a trailing comma does not fail.
        parsed_flags["rolling_window_size"] = [
            int(x) for x in window_sizes.split(",") if x.strip()
        ]

    for key in bool_keys:
        raw_value = parsed_flags.get(key)
        if isinstance(raw_value, str):
            parsed_flags[key] = to_bool(key, raw_value)

    cfg = maybe_set_config(cfg, **parsed_flags)
    return cfg.instantiate()
6595
def __init__(self, cfg):
    """Initializes the recorder without creating any goodput library objects.

    The library recorder and both monitors start out as None; the methods that need them
    create them on first use.

    Args:
        cfg: The recorder config, forwarded to the parent constructor.
    """
    super().__init__(cfg)
    # Monitors for cumulative and rolling-window goodput, created on demand.
    self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None
    self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None
    # The event recorder, created on the first recorded event.
    self._recorder: Optional[goodput.GoodputRecorder] = None
101+
@contextlib.contextmanager
def record_event(self, event: measurement.Event, *args, **kwargs):
    """Records the start and end of `event` around the body of a `with` block.

    The recorder methods are looked up by name as `record_<event.value>_start_time` and
    `record_<event.value>_end_time`. A method that does not exist is skipped. TypeError,
    ValueError and RuntimeError raised by either call are logged as warnings and not
    propagated. The end call is made even if the wrapped body raises.

    Args:
        event: The event to record; its `value` selects the recorder methods and its
            `name` is used in warning messages.
        *args: Positional arguments forwarded to both the start and end calls.
        **kwargs: Keyword arguments forwarded to both the start and end calls.
    """
    recorder = self._recorder
    if recorder is None:
        # Create the library recorder on first use.
        is_process_zero = jax.process_index() == 0
        if is_process_zero:
            logging.info("Lazily instantiating goodput recorder.")
        job_name = self.config.name
        recorder = goodput.GoodputRecorder(
            job_name=job_name,
            logger_name=f"goodput_logger_{job_name}",
            logging_enabled=is_process_zero,
        )
        self._recorder = recorder

    start_fn = getattr(recorder, f"record_{event.value}_start_time", None)
    end_fn = getattr(recorder, f"record_{event.value}_end_time", None)

    def call_quietly(fn, failure_message):
        # Recording is best-effort: a failed call is logged, never raised.
        if not fn:
            return
        try:
            fn(*args, **kwargs)
        except (TypeError, ValueError, RuntimeError) as e:
            logging.warning(failure_message, event.name, e, exc_info=True)

    call_quietly(start_fn, "Failed to record start of event %s. Error: %s")
    try:
        yield
    finally:
        call_quietly(end_fn, "Failed to record end of event %s. Error: %s")
139+
@contextlib.contextmanager
def maybe_monitor_goodput(self, *args, **kwargs):
    """Runs the cumulative goodput uploader for the duration of a `with` block.

    Only process 0 monitors; every other process yields immediately and does nothing.
    On process 0 a `GoodputMonitor` is created on first use from this recorder's config,
    its uploader is started before the body runs, and the uploader is stopped when the
    body exits, whether normally or with an exception.

    Note: per the original documentation, this requires distributed JAX to be
    initialized before it is called, and internal GCP errors from querying and uploading
    are logged by the monitor without affecting the workload. Whether metrics are pushed
    to Google Cloud Monitoring is controlled by `enable_gcp_goodput_metrics` through
    `goodput_monitoring.GCPOptions`.

    Args:
        *args: Positional arguments forwarded to `start_goodput_uploader`.
        **kwargs: Keyword arguments forwarded to `start_goodput_uploader`.
    """
    if jax.process_index() != 0:
        yield
        return

    cfg = self.config
    try:
        if self._monitor is None:
            monitor_kwargs = dict(
                job_name=cfg.name,
                logger_name=f"goodput_logger_{cfg.name}",
                tensorboard_dir=cfg.upload_dir,
                upload_interval=cfg.upload_interval,
                monitoring_enabled=True,
                pathway_enabled=cfg.enable_pathways_goodput,
                include_badput_breakdown=cfg.include_badput_breakdown,
                gcp_options=goodput_monitoring.GCPOptions(
                    enable_gcp_goodput_metrics=cfg.enable_gcp_goodput_metrics,
                ),
            )
            self._monitor = goodput_monitoring.GoodputMonitor(**monitor_kwargs)

        self._monitor.start_goodput_uploader(*args, **kwargs)
        logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
        yield
    finally:
        # The monitor may be missing if its construction failed above.
        if self._monitor:
            self._monitor.stop_goodput_uploader()
            logging.info("Flushed final metrics and safe exited from Goodput monitoring.")
181+
@contextlib.contextmanager
def maybe_monitor_rolling_window_goodput(self):
    """Runs the rolling-window goodput uploader for the duration of a `with` block.

    Does nothing (yields immediately) unless `enable_rolling_window_goodput_monitoring`
    is set and this is process 0. Otherwise a dedicated `GoodputMonitor` is created on
    first use, writing to the `rolling_window_<name>` subdirectory of `upload_dir`. Its
    rolling-window uploader is started with `rolling_window_size` before the body runs
    and stopped when the body exits, whether normally or with an exception.
    """
    cfg = self.config
    if not cfg.enable_rolling_window_goodput_monitoring or jax.process_index() != 0:
        yield
        return
    try:
        if self._rolling_window_monitor is None:
            rolling_window_tensorboard_dir = os.path.join(
                cfg.upload_dir, f"rolling_window_{cfg.name}"
            )
            self._rolling_window_monitor = goodput_monitoring.GoodputMonitor(
                job_name=cfg.name,
                logger_name=f"goodput_logger_{cfg.name}",
                tensorboard_dir=rolling_window_tensorboard_dir,
                upload_interval=cfg.upload_interval,
                monitoring_enabled=True,
                pathway_enabled=cfg.enable_pathways_goodput,
                # Honour the configured setting, as the cumulative monitor does, instead
                # of always forcing the breakdown on. The config default is True.
                include_badput_breakdown=cfg.include_badput_breakdown,
            )
        self._rolling_window_monitor.start_rolling_window_goodput_uploader(
            cfg.rolling_window_size
        )
        logging.info("Started Rolling Window Goodput monitoring in the background!")
        yield
    finally:
        # The monitor may be missing if its construction failed above.
        if self._rolling_window_monitor:
            self._rolling_window_monitor.stop_rolling_window_goodput_uploader()
            logging.info(
                "Flushed final metrics and safe exited from Rolling Window Goodput monitoring."
            )
213+
214+
def create_goodput_recorder(cfg: GoodputRecorder.Config):
    """Constructs a GoodputRecorder from the given config.

    Args:
        cfg: The recorder config passed to the `GoodputRecorder` constructor.

    Returns:
        The newly constructed `GoodputRecorder`.
    """
    recorder = GoodputRecorder(cfg)
    return recorder
0 commit comments