@@ -432,6 +432,36 @@ def test_set_tracking_uri(mlflow_mock):
432432 mlflow_mock .set_tracking_uri .assert_called_with ("the_tracking_uri" )
433433
434434
435+ @mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ())
436+ def test_mlflow_log_model_with_checkpoint_path_prefix (mlflow_mock , tmp_path ):
437+ """Test that the logger creates the folders and files in the right place with a prefix."""
438+ client = mlflow_mock .tracking .MlflowClient
439+
440+ # Get model, logger, trainer and train
441+ model = BoringModel ()
442+ logger = MLFlowLogger ("test" , save_dir = str (tmp_path ), log_model = "all" , checkpoint_path_prefix = "my_prefix" )
443+ logger = mock_mlflow_run_creation (logger , experiment_id = "test-id" )
444+
445+ trainer = Trainer (
446+ default_root_dir = tmp_path ,
447+ logger = logger ,
448+ max_epochs = 2 ,
449+ limit_train_batches = 3 ,
450+ limit_val_batches = 3 ,
451+ )
452+ trainer .fit (model )
453+
454+ # Checkpoint log
455+ assert client .return_value .log_artifact .call_count == 2
456+ # Metadata and aliases log
457+ assert client .return_value .log_artifacts .call_count == 2
458+
459+ # Check that the prefix is used in the artifact path
460+ for call in client .return_value .log_artifact .call_args_list :
461+ args , _ = call
462+ assert str (args [2 ]).startswith ("my_prefix" )
463+
464+
435465@mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ())
436466def test_mlflow_multiple_checkpoints_top_k (mlflow_mock , tmp_path ):
437467 """Test that multiple ModelCheckpoint callbacks with top_k parameters work correctly with MLFlowLogger.
0 commit comments