11# ruff: noqa: ARG002
22import inspect
3- from typing import Any , Dict , Optional , Union
3+ from pathlib import Path
4+ from typing import Any , Dict , List , Optional , Union
45
56try :
67 from lightning .fabric .utilities .logger import (
1011 )
1112 from lightning .pytorch .callbacks .model_checkpoint import ModelCheckpoint
1213 from lightning .pytorch .loggers .logger import Logger , rank_zero_experiment
14+ from lightning .pytorch .loggers .utilities import _scan_checkpoints
1315 from lightning .pytorch .utilities import rank_zero_only
1416except ImportError :
1517 from lightning_fabric .utilities .logger import (
2022 from pytorch_lightning .callbacks .model_checkpoint import ModelCheckpoint
2123 from pytorch_lightning .loggers .logger import Logger , rank_zero_experiment
2224 from pytorch_lightning .utilities import rank_zero_only
25+ from pytorch_lightning .utilities .logger import _scan_checkpoints
2326from torch import is_tensor
2427
2528from dvclive import Live
@@ -45,7 +48,7 @@ def _should_call_next_step():
4548
4649
4750class DVCLiveLogger (Logger ):
48- def __init__ (
51+ def __init__ ( # noqa: PLR0913
4952 self ,
5053 run_name : Optional [str ] = "dvclive_run" ,
5154 prefix = "" ,
@@ -75,7 +78,9 @@ def __init__(
7578 # Force Live instantiation
7679 self .experiment # noqa: B018
7780 self ._log_model = log_model
81+ self ._logged_model_time : Dict [str , float ] = {}
7882 self ._checkpoint_callback : Optional [ModelCheckpoint ] = None
83+ self ._all_checkpoint_paths : List [str ] = []
7984
8085 @property
8186 def name (self ):
@@ -140,18 +145,37 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
140145 if self ._log_model == "all" or (
141146 self ._log_model is True and checkpoint_callback .save_top_k == - 1
142147 ):
143- self .experiment . log_artifact (checkpoint_callback . dirpath )
148+ self ._save_checkpoints (checkpoint_callback )
144149
@rank_zero_only
def finalize(self, status: str) -> None:
    """Log checkpoints according to ``log_model`` and end the experiment.

    When ``log_model`` is ``True`` the checkpoint directory is saved; when it
    is ``True`` or ``"all"`` the best checkpoint is additionally logged as a
    model artifact named ``"best"``. The experiment is always ended.

    Args:
        status: Final status supplied by the caller; not used here.
    """
    checkpoint_callback = self._checkpoint_callback
    # `_checkpoint_callback` starts out as None, so it may never have been
    # set; in that case there is nothing to log.
    if checkpoint_callback is not None:
        # Save model checkpoints.
        if self._log_model is True:
            self._save_checkpoints(checkpoint_callback)
        # Log best model.
        if self._log_model in (True, "all"):
            best_model_path = checkpoint_callback.best_model_path
            # Skip when no best checkpoint path has been recorded.
            if best_model_path:
                self.experiment.log_artifact(
                    best_model_path, name="best", type="model", cache=False
                )
    self.experiment.end()
163+
def _scan_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Record the checkpoints reported for ``checkpoint_callback``.

    The module-level ``_scan_checkpoints`` helper is called with the
    callback and the times recorded so far; each returned entry's time and
    path are then stored on the logger.

    Args:
        checkpoint_callback: Callback whose checkpoints are scanned.
    """
    # NOTE(review): presumably the helper only returns checkpoints newer
    # than the recorded time -- confirm against the Lightning utility.
    pending = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

    # Each entry has four fields; only the time and the path are used here.
    for saved_at, ckpt_path, _unused_a, _unused_b in pending:
        self._logged_model_time[ckpt_path] = saved_at
        self._all_checkpoint_paths.append(ckpt_path)
172+
def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Prune untracked files from the checkpoint directory, then log it.

    Unless the experiment's ``_resume`` flag is set, every file directly
    inside ``checkpoint_callback.dirpath`` that is not listed in
    ``self._all_checkpoint_paths`` is deleted. The directory is then logged
    as an artifact.

    Args:
        checkpoint_callback: Callback whose ``dirpath`` is pruned and logged.
    """
    dirpath = checkpoint_callback.dirpath
    if dirpath is None:
        # No checkpoint directory is configured: nothing to prune or log.
        return

    # drop unused checkpoints
    if not self._experiment._resume:  # noqa: SLF001
        # Compare as Path objects so that differently spelled but equal
        # paths (e.g. containing "./") are still recognised as tracked,
        # and build the set once instead of scanning the list per entry.
        tracked = {Path(p) for p in self._all_checkpoint_paths}
        for entry in Path(dirpath).iterdir():
            if entry in tracked:
                continue
            # `unlink` cannot remove directories; leave them untouched.
            if entry.is_file() or entry.is_symlink():
                entry.unlink(missing_ok=True)

    # save directory
    self.experiment.log_artifact(dirpath)
0 commit comments