Skip to content

Commit c9452c7

Browse files
dberenbaum authored and committed
lightning: drop unused checkpoints
1 parent cf1d6d5 commit c9452c7

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

src/dvclive/lightning.py

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
# ruff: noqa: ARG002
22
import inspect
3-
from typing import Any, Dict, Optional, Union
3+
from pathlib import Path
4+
from typing import Any, Dict, List, Optional, Union
45

56
try:
67
from lightning.fabric.utilities.logger import (
@@ -10,6 +11,7 @@
1011
)
1112
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
1213
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
14+
from lightning.pytorch.loggers.utilities import _scan_checkpoints
1315
from lightning.pytorch.utilities import rank_zero_only
1416
except ImportError:
1517
from lightning_fabric.utilities.logger import (
@@ -20,6 +22,7 @@
2022
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
2123
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
2224
from pytorch_lightning.utilities import rank_zero_only
25+
from pytorch_lightning.utilities.logger import _scan_checkpoints
2326
from torch import is_tensor
2427

2528
from dvclive import Live
@@ -45,7 +48,7 @@ def _should_call_next_step():
4548

4649

4750
class DVCLiveLogger(Logger):
48-
def __init__(
51+
def __init__( # noqa: PLR0913
4952
self,
5053
run_name: Optional[str] = "dvclive_run",
5154
prefix="",
@@ -75,7 +78,9 @@ def __init__(
7578
# Force Live instantiation
7679
self.experiment # noqa: B018
7780
self._log_model = log_model
81+
self._logged_model_time: Dict[str, float] = {}
7882
self._checkpoint_callback: Optional[ModelCheckpoint] = None
83+
self._all_checkpoint_paths: List[str] = []
7984

8085
@property
8186
def name(self):
@@ -140,18 +145,37 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
140145
if self._log_model == "all" or (
141146
self._log_model is True and checkpoint_callback.save_top_k == -1
142147
):
143-
self.experiment.log_artifact(checkpoint_callback.dirpath)
148+
self._save_checkpoints(checkpoint_callback)
144149

145150
@rank_zero_only
def finalize(self, status: str) -> None:
    """Log the saved checkpoints and the best model, then end the experiment.

    Args:
        status: Final status supplied by the caller; it is not used here.
    """
    checkpoint_callback = self._checkpoint_callback
    # ``_checkpoint_callback`` is initialised to ``None``. Without a recorded
    # callback there is nothing to log, but the experiment must still be
    # ended, so guard instead of dereferencing ``None``.
    if checkpoint_callback is not None:
        # Save model checkpoints.
        if self._log_model is True:
            self._save_checkpoints(checkpoint_callback)
        # Log best model.
        if self._log_model in (True, "all"):
            best_model_path = checkpoint_callback.best_model_path
            self.experiment.log_artifact(
                best_model_path, name="best", type="model", cache=False
            )
    self.experiment.end()
163+
164+
def _scan_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Record the checkpoints reported for ``checkpoint_callback``.

    Calls the imported ``_scan_checkpoints`` helper with the times logged so
    far, stores each returned time under its path in
    ``self._logged_model_time`` and remembers the path in
    ``self._all_checkpoint_paths``.

    Args:
        checkpoint_callback: Callback whose checkpoints are scanned.
    """
    # get checkpoints to be saved with associated score
    checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

    # update model time and append path to list of all checkpoints
    for t, p, _, _ in checkpoints:
        self._logged_model_time[p] = t
        # A path can be reported more than once (presumably when it is
        # re-saved -- confirm against the helper); track it only once so the
        # list does not accumulate duplicates.
        if p not in self._all_checkpoint_paths:
            self._all_checkpoint_paths.append(p)
172+
173+
def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Remove untracked files from the checkpoint directory and log it.

    Unless the experiment is resuming, every entry in
    ``checkpoint_callback.dirpath`` whose path is not listed in
    ``self._all_checkpoint_paths`` is deleted. The directory is then logged
    as an artifact.

    Args:
        checkpoint_callback: Callback whose ``dirpath`` is cleaned and logged.
    """
    dirpath = checkpoint_callback.dirpath

    # drop unused checkpoints
    if dirpath and not self._experiment._resume:  # noqa: SLF001
        # Build the lookup once instead of scanning the list per entry.
        tracked = set(self._all_checkpoint_paths)
        for p in Path(dirpath).iterdir():
            if str(p) in tracked:
                continue
            # ``unlink`` cannot remove a real directory and would raise;
            # leave sub-directories untouched.
            if p.is_dir() and not p.is_symlink():
                continue
            p.unlink(missing_ok=True)

    # save directory
    self.experiment.log_artifact(checkpoint_callback.dirpath)

0 commit comments

Comments
 (0)