Skip to content

Commit 864ddb5

Browse files
committed
integrations: Standardize metric name to group under subfolders.
1 parent 163da5f commit 864ddb5

File tree

9 files changed

+49
-16
lines changed

9 files changed

+49
-16
lines changed

dvclive/fastai.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from fastai.callback.core import Callback
22

33
from dvclive import Live
4+
from dvclive.utils import standardize_metric_name
45

56

67
class DvcLiveCallback(Callback):
@@ -13,8 +14,9 @@ def after_epoch(self):
1314
for key, value in zip(
1415
self.learn.recorder.metric_names, self.learn.recorder.log
1516
):
16-
key = key.replace("_", "/")
17-
self.dvclive.log(f"{key}", float(value))
17+
self.dvclive.log(
18+
standardize_metric_name(key, __name__), float(value)
19+
)
1820

1921
if self.model_file:
2022
self.learn.save(self.model_file)

dvclive/huggingface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
)
77

88
from dvclive import Live
9+
from dvclive.utils import standardize_metric_name
910

1011

1112
class DvcLiveCallback(TrainerCallback):
@@ -23,7 +24,7 @@ def on_log(
2324
):
2425
logs = kwargs["logs"]
2526
for key, value in logs.items():
26-
self.dvclive.log(key, value)
27+
self.dvclive.log(standardize_metric_name(key, __name__), value)
2728
self.dvclive.next_step()
2829

2930
def on_epoch_end(

dvclive/keras.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99

1010
from dvclive import Live
11+
from dvclive.utils import standardize_metric_name
1112

1213

1314
class DvcLiveCallback(Callback):
@@ -39,7 +40,7 @@ def on_epoch_end(
3940
): # pylint: disable=unused-argument
4041
logs = logs or {}
4142
for metric, value in logs.items():
42-
self.dvclive.log(metric, value)
43+
self.dvclive.log(standardize_metric_name(metric, __name__), value)
4344
if self.model_file:
4445
if self.save_weights_only:
4546
self.model.save_weights(self.model_file)

dvclive/lightning.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,13 @@
33
from pytorch_lightning.loggers import LightningLoggerBase
44
from pytorch_lightning.loggers.base import rank_zero_experiment
55
from pytorch_lightning.utilities import rank_zero_only
6-
from pytorch_lightning.utilities.logger import _add_prefix
76
from torch import is_tensor
87

98
from dvclive import Live
9+
from dvclive.utils import standardize_metric_name
1010

1111

1212
class DvcLiveLogger(LightningLoggerBase):
13-
14-
LOGGER_JOIN_CHAR = "-"
15-
1613
def __init__(
1714
self,
1815
run_name: Optional[str] = "dvclive_run",
@@ -70,9 +67,9 @@ def log_metrics(
7067
rank_zero_only.rank == 0
7168
), "experiment tried to log from global_rank != 0"
7269

73-
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
7470
for metric_name, metric_val in metrics.items():
7571
if is_tensor(metric_val):
7672
metric_val = metric_val.cpu().detach().item()
73+
metric_name = standardize_metric_name(metric_name, __name__)
7774
self.experiment.log(name=metric_name, val=metric_val)
7875
self.experiment.next_step()

dvclive/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,32 @@ def env2bool(var, undefined=False):
7373
if var is None:
7474
return undefined
7575
return bool(re.search("1|y|yes|true", var, flags=re.I))
76+
77+
78+
def standardize_metric_name(metric_name: str, framework: str) -> str:
79+
"""Map framework-specific format to DVCLive standard.
80+
81+
Use `{split}/` as prefix in order to seperate by subfolders.
82+
Use `{train|eval}` as split name.
83+
"""
84+
if framework == "dvclive.fastai":
85+
metric_name = metric_name.replace("train_", "train/")
86+
metric_name = metric_name.replace("valid_", "eval/")
87+
88+
elif framework == "dvclive.huggingface":
89+
for split in {"train", "eval"}:
90+
metric_name = metric_name.replace(f"{split}_", f"{split}/")
91+
92+
elif framework == "dvclive.keras":
93+
if "val_" in metric_name:
94+
metric_name = metric_name.replace("val_", "eval/")
95+
else:
96+
metric_name = f"train/{metric_name}"
97+
98+
elif framework == "dvclive.lightning":
99+
parts = metric_name.split("_")
100+
if len(parts) > 2:
101+
split, *rest, freq = parts
102+
metric_name = f"{split}/{freq}/{'_'.join(rest)}"
103+
104+
return metric_name

tests/test_fastai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_fastai_callback(tmp_dir, data_loader):
4848
assert os.path.exists("dvclive")
4949

5050
train_path = tmp_dir / "dvclive" / Scalar.subfolder / "train"
51-
valid_path = tmp_dir / "dvclive" / Scalar.subfolder / "valid"
51+
valid_path = tmp_dir / "dvclive" / Scalar.subfolder / "eval"
5252

5353
assert train_path.is_dir()
5454
assert valid_path.is_dir()

tests/test_huggingface.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,11 @@ def test_huggingface_integration(tmp_dir, model, args, data, tokenizer):
8282
logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
8383

8484
assert len(logs) == 10
85-
assert "eval_matthews_correlation" in logs
86-
assert "eval_loss" in logs
85+
assert os.path.join("eval", "matthews_correlation") in logs
86+
assert os.path.join("eval", "loss") in logs
87+
assert os.path.join("train", "loss") in logs
8788
assert len(logs["epoch"]) == 3
88-
assert len(logs["eval_loss"]) == 2
89+
assert len(logs[os.path.join("eval", "loss")]) == 2
8990

9091

9192
def test_huggingface_model_file(tmp_dir, model, args, data, tokenizer, mocker):

tests/test_keras.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ def test_keras_callback(tmp_dir, xor_model, capture_wrap):
4242
y,
4343
epochs=1,
4444
batch_size=1,
45+
validation_split=0.2,
4546
callbacks=[DvcLiveCallback()],
4647
)
4748

4849
assert os.path.exists("dvclive")
4950
logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
5051

51-
assert "accuracy" in logs
52+
assert os.path.join("train", "accuracy") in logs
53+
assert os.path.join("eval", "accuracy") in logs
5254

5355

5456
@pytest.mark.parametrize("save_weights_only", (True, False))

tests/test_lightning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,6 @@ def test_lightning_integration(tmp_dir):
9696
logs, _ = read_logs(tmp_dir / "logs" / Scalar.subfolder)
9797

9898
assert len(logs) == 3
99-
assert "train_loss_step" in logs
100-
assert "train_loss_epoch" in logs
99+
assert os.path.join("train", "epoch", "loss") in logs
100+
assert os.path.join("train", "step", "loss") in logs
101101
assert "epoch" in logs

0 commit comments

Comments
 (0)