@@ -373,3 +373,46 @@ def test_maybe_monitor_all(
373373 else :
374374 mock_monitor_instance .start_rolling_window_goodput_uploader .assert_not_called ()
375375 mock_monitor_instance .stop_rolling_window_goodput_uploader .assert_not_called ()
376+
377+ @mock .patch ("jax.process_index" , return_value = 0 )
378+ def test_create_checkpoint_logger_success (self , _ ):
379+ """Tests that create_checkpoint_logger creates a CloudLogger with correct config."""
380+ cfg = GoodputRecorder .default_config ().set (
381+ name = "test-job" ,
382+ upload_dir = "/test" ,
383+ upload_interval = 30 ,
384+ )
385+ recorder = GoodputRecorder (cfg )
386+
387+ with mock .patch ("orbax.checkpoint.logging.CloudLogger" ) as mock_logger_cls :
388+ mock_logger_instance = mock_logger_cls .return_value
389+ logger = recorder .create_checkpoint_logger ()
390+
391+ mock_logger_cls .assert_called_once ()
392+ self .assertIs (logger , mock_logger_instance )
393+
394+ _ , kwargs = mock_logger_cls .call_args
395+ options = kwargs ["options" ]
396+ self .assertEqual (options .job_name , "test-job" )
397+ self .assertEqual (options .logger_name , "goodput_logger_test-job" )
398+
399+ @mock .patch ("jax.process_index" , return_value = 0 )
400+ def test_create_checkpoint_logger_failure (self , _ ):
401+ """Tests that create_checkpoint_logger logs a warning on failure and returns None."""
402+ cfg = GoodputRecorder .default_config ().set (
403+ name = "fail-job" ,
404+ upload_dir = "/test" ,
405+ upload_interval = 30 ,
406+ )
407+ recorder = GoodputRecorder (cfg )
408+
409+ with mock .patch (
410+ "orbax.checkpoint.logging.CloudLogger" , side_effect = RuntimeError ("TestError" )
411+ ) as mock_logger_cls , mock .patch .object (logging , "warning" ) as mock_warning :
412+ logger = recorder .create_checkpoint_logger ()
413+ self .assertIsNone (logger )
414+ mock_logger_cls .assert_called_once ()
415+ mock_warning .assert_called_once ()
416+ self .assertIn (
417+ "Failed to create Goodput checkpoint logger" , mock_warning .call_args [0 ][0 ]
418+ )
0 commit comments