1616class GoodputRecorderTest (parameterized .TestCase ):
1717 """Tests GoodputRecorder."""
1818
19- def test_from_flags_with_spec (self ):
20- """Tests that flags are correctly parsed."""
21- fv = flags .FlagValues ()
22- measurement .define_flags (flag_values = fv )
23- fv .set_default (
24- "recorder_spec" ,
25- [
19+ def setUp (self ):
20+ super ().setUp ()
21+ self .mock_flags = mock .MagicMock (spec = flags .FlagValues )
22+ self .mock_flags .jax_backend = None
23+ self .patch_flags_global = mock .patch .object (flags , "FLAGS" , new = self .mock_flags )
24+ self .patch_flags_global .start ()
25+
26+ def tearDown (self ):
27+ super ().tearDown ()
28+ self .patch_flags_global .stop ()
29+
30+ @parameterized .parameters (
31+ dict (
32+ recorder_spec = [
33+ "name=test-name" ,
34+ "upload_dir=/test/path" ,
35+ "upload_interval=15" ,
36+ ],
37+ expected_rolling_window_size = [],
38+ ),
39+ dict (
40+ recorder_spec = [
2641 "name=test-name" ,
2742 "upload_dir=/test/path" ,
2843 "upload_interval=15" ,
29- "enable_rolling_window_goodput_monitoring=True" ,
3044 "rolling_window_size=1,2,3" ,
3145 ],
32- )
33- fv .mark_as_parsed ()
34- recorder = GoodputRecorder .from_flags (fv )
46+ expected_rolling_window_size = [1 , 2 , 3 ],
47+ ),
48+ )
49+ def test_from_flags (
50+ self ,
51+ recorder_spec ,
52+ expected_rolling_window_size ,
53+ ):
54+ """Tests that flags are correctly parsed into the config."""
55+ mock_fv = mock .MagicMock (spec = flags .FlagValues )
56+ mock_fv .recorder_spec = recorder_spec
57+ mock_fv .jax_backend = "tpu"
58+
59+ recorder = GoodputRecorder .from_flags (mock_fv )
60+
3561 self .assertEqual ("test-name" , recorder .config .name )
3662 self .assertEqual ("/test/path" , recorder .config .upload_dir )
3763 self .assertEqual (15 , recorder .config .upload_interval )
38- self .assertTrue (recorder .config .enable_rolling_window_goodput_monitoring )
39- self .assertEqual ([1 , 2 , 3 ], recorder .config .rolling_window_size )
64+ self .assertEqual (expected_rolling_window_size , recorder .config .rolling_window_size )
4065
4166 def test_from_flags_missing_required (self ):
4267 """Tests that missing required flags raise an error."""
43- fv = flags .FlagValues ()
44- measurement .define_flags (flag_values = fv )
45- fv .set_default ("recorder_spec" , ["name=test-name" ]) # Missing upload_dir/interval
46- fv .mark_as_parsed ()
68+ mock_fv = mock .MagicMock (spec = flags .FlagValues )
69+ mock_fv .recorder_spec = ["name=test-name" ] # Missing upload_dir/interval
70+ mock_fv .jax_backend = "tpu"
4771 with self .assertRaisesRegex (RequiredFieldMissingError , "upload_dir" ):
48- GoodputRecorder .from_flags (fv )
72+ GoodputRecorder .from_flags (mock_fv )
4973
5074 @mock .patch ("jax.process_index" , return_value = 0 )
5175 def test_record_event_context_manager (self , _ ):
5276 """Tests the record_event context manager."""
53- recorder = GoodputRecorder (GoodputRecorder .default_config ().set (name = "test" ))
54- with mock .patch ("ml_goodput_measurement.goodput.GoodputRecorder" ) as mock_recorder :
55- mock_instance = mock_recorder .return_value
77+ cfg = GoodputRecorder .default_config ().set (
78+ name = "test" ,
79+ upload_dir = "/tmp/test" ,
80+ upload_interval = 1 ,
81+ )
82+ recorder = GoodputRecorder (cfg )
83+ with mock .patch ("ml_goodput_measurement.goodput.GoodputRecorder" ) as mock_recorder_cls :
84+ mock_instance = mock_recorder_cls .return_value
5685 with recorder .record_event (measurement .Event .JOB ):
5786 pass
58- mock_recorder .assert_called_once ()
87+ mock_recorder_cls .assert_called_once ()
5988 mock_instance .record_job_start_time .assert_called_once ()
6089 mock_instance .record_job_end_time .assert_called_once ()
6190
91+ @parameterized .parameters (
92+ dict (is_pathways_job = False , mock_jax_backend = "tpu" ),
93+ dict (is_pathways_job = True , mock_jax_backend = "proxy" ),
94+ )
6295 @mock .patch ("jax.process_index" , return_value = 0 )
63- def test_maybe_monitor_goodput (self , _ ):
96+ def test_maybe_monitor_goodput (self , _ , is_pathways_job , mock_jax_backend ):
6497 """Tests the maybe_monitor_goodput context manager."""
98+ self .mock_flags .jax_backend = mock_jax_backend
6599 cfg = GoodputRecorder .default_config ().set (
66100 name = "test-monitor" ,
67101 upload_dir = "/test" ,
68102 upload_interval = 30 ,
69- enable_gcp_goodput_metrics = True ,
70- enable_pathways_goodput = False ,
71- include_badput_breakdown = True ,
72103 )
73104 recorder = GoodputRecorder (cfg )
74105
@@ -84,47 +115,82 @@ def test_maybe_monitor_goodput(self, _):
84115 tensorboard_dir = "/test" ,
85116 upload_interval = 30 ,
86117 monitoring_enabled = True ,
87- pathway_enabled = False ,
118+ pathway_enabled = is_pathways_job ,
88119 include_badput_breakdown = True ,
89- gcp_options = mock .ANY ,
90120 )
91- # Verify the start and stop methods were called.
92121 mock_monitor_instance .start_goodput_uploader .assert_called_once ()
93122 mock_monitor_instance .stop_goodput_uploader .assert_called_once ()
94123
124+ @parameterized .parameters (
125+ dict (
126+ is_rolling_window_enabled = True ,
127+ rolling_window_size = [10 , 20 ],
128+ is_pathways_job = False ,
129+ mock_jax_backend = "tpu" ,
130+ ),
131+ dict (
132+ is_rolling_window_enabled = False ,
133+ rolling_window_size = [],
134+ is_pathways_job = False ,
135+ mock_jax_backend = "tpu" ,
136+ ),
137+ dict (
138+ is_rolling_window_enabled = True ,
139+ rolling_window_size = [50 ],
140+ is_pathways_job = True ,
141+ mock_jax_backend = "proxy" ,
142+ ),
143+ )
95144 @mock .patch ("jax.process_index" , return_value = 0 )
96- def test_maybe_monitor_rolling_window (self , _ ):
145+ def test_maybe_monitor_rolling_window (
146+ self ,
147+ mock_process_index ,
148+ is_rolling_window_enabled ,
149+ rolling_window_size ,
150+ is_pathways_job ,
151+ mock_jax_backend ,
152+ ): # pylint: disable=unused-argument
97153 """Tests the rolling window monitoring context manager."""
154+ self .mock_flags .jax_backend = mock_jax_backend
98155 cfg = GoodputRecorder .default_config ().set (
99156 name = "test-rolling" ,
100157 upload_dir = "/test" ,
101158 upload_interval = 30 ,
102- enable_rolling_window_goodput_monitoring = True ,
103- rolling_window_size = [10 , 20 ],
159+ rolling_window_size = rolling_window_size ,
104160 )
105161 recorder = GoodputRecorder (cfg )
106162
107163 with mock .patch ("ml_goodput_measurement.monitoring.GoodputMonitor" ) as mock_monitor_cls :
108164 mock_monitor_instance = mock_monitor_cls .return_value
165+ if not is_rolling_window_enabled :
166+ with recorder .maybe_monitor_rolling_window_goodput ():
167+ pass
168+ mock_monitor_cls .assert_not_called ()
169+ return
109170 with recorder .maybe_monitor_rolling_window_goodput ():
110171 pass
111172
112- # Verify that GoodputMonitor was instantiated for rolling window.
113- mock_monitor_cls .assert_called_once ()
114- self .assertEqual (
115- "/test/rolling_window_test-rolling" ,
116- mock_monitor_cls .call_args .kwargs ["tensorboard_dir" ],
173+ mock_monitor_cls .assert_called_once_with (
174+ job_name = "test-rolling" ,
175+ logger_name = "goodput_logger_test-rolling" ,
176+ tensorboard_dir = "/test/rolling_window_test-rolling" ,
177+ upload_interval = 30 ,
178+ monitoring_enabled = True ,
179+ pathway_enabled = is_pathways_job ,
180+ include_badput_breakdown = True ,
117181 )
118182
119- # Verify the correct the start and stop methods were called.
120- mock_monitor_instance .start_rolling_window_goodput_uploader .assert_called_with ([10 , 20 ])
183+ mock_monitor_instance .start_rolling_window_goodput_uploader .assert_called_with (
184+ rolling_window_size
185+ )
121186 mock_monitor_instance .stop_rolling_window_goodput_uploader .assert_called_once ()
122187
123188 @mock .patch ("jax.process_index" , return_value = 1 )
124189 def test_non_zero_process_index_skips_monitoring (
125190 self , mock_process_index
126191 ): # pylint: disable=unused-argument
127192 """Tests that monitoring is skipped on non-zero process indices."""
193+ self .mock_flags .jax_backend = "tpu"
128194 cfg = GoodputRecorder .default_config ().set (
129195 name = "test" , upload_dir = "/test" , upload_interval = 30
130196 )
@@ -136,8 +202,13 @@ def test_non_zero_process_index_skips_monitoring(
136202 pass
137203 mock_monitor_cls .assert_not_called ()
138204
139- # Test maybe_monitor_rolling_window_goodput
140- recorder .config .enable_rolling_window_goodput_monitoring = True
141- with recorder .maybe_monitor_rolling_window_goodput ():
205+ cfg_rolling = GoodputRecorder .default_config ().set (
206+ name = "test-rolling-skip" ,
207+ upload_dir = "/test" ,
208+ upload_interval = 30 ,
209+ rolling_window_size = [10 , 20 ],
210+ )
211+ recorder_rolling = GoodputRecorder (cfg_rolling )
212+ with recorder_rolling .maybe_monitor_rolling_window_goodput ():
142213 pass
143214 mock_monitor_cls .assert_not_called ()
0 commit comments