41 changes: 41 additions & 0 deletions src/lightning/pytorch/cli.py
@@ -560,6 +560,41 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
else:
self.config = parser.parse_args(args)

def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
"""Adapt checkpoint hyperparameters before instantiating the model class.

This method allows for customization of hyperparameters loaded from a checkpoint when
using a different model class than the one used for training. For example, when loading
a checkpoint from a TrainingModule to use with an InferenceModule that has different
``__init__`` parameters, you can remove or modify incompatible hyperparameters.

Args:
subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
This allows you to apply different hyperparameter adaptations depending on the context.
checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

Returns:
Dictionary of adapted hyperparameters to be used for model instantiation.

Example::

class MyCLI(LightningCLI):
def adapt_checkpoint_hparams(
self, subcommand: str, checkpoint_hparams: dict[str, Any]
) -> dict[str, Any]:
# Only remove training-specific hyperparameters for non-fit subcommands
if subcommand != "fit":
checkpoint_hparams.pop("lr", None)
checkpoint_hparams.pop("weight_decay", None)
return checkpoint_hparams

Note:
If subclass mode is enabled (``subclass_mode_model=True``) and ``_class_path`` is present in the
checkpoint hyperparameters, you may need to modify it as well so that it points to your new module class.

"""
return checkpoint_hparams

def _parse_ckpt_path(self) -> None:
"""If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
if not self.config.get("subcommand"):
@@ -571,6 +606,12 @@ def _parse_ckpt_path(self) -> None:
hparams.pop("_instantiator", None)
if not hparams:
return

# Allow customization of checkpoint hyperparameters via the adapt_checkpoint_hparams hook
hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams)
if not hparams:
return

if "_class_path" in hparams:
hparams = {
"class_path": hparams.pop("_class_path"),
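The Note in the new docstring mentions adjusting ``_class_path`` when subclass mode is in use. A minimal sketch of what such an override could look like, assuming hypothetical ``TrainingModule``/``InferenceModule`` classes in a user package ``my_pkg`` with training-only ``lr``/``weight_decay`` arguments (all of these names are illustrative, not part of this diff):

    from typing import Any

    from lightning.pytorch.cli import LightningCLI


    class AdaptedCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
            if subcommand != "fit":
                # Drop training-only arguments that InferenceModule.__init__ does not accept.
                checkpoint_hparams.pop("lr", None)
                checkpoint_hparams.pop("weight_decay", None)
                # With subclass mode (subclass_mode_model=True) the checkpoint stores the class
                # path of the training module; repoint it at the inference module instead.
                if "_class_path" in checkpoint_hparams:
                    checkpoint_hparams["_class_path"] = "my_pkg.modules.InferenceModule"
            return checkpoint_hparams


    # Hypothetical wiring, assuming both modules subclass a common my_pkg.modules.BaseModule:
    #     AdaptedCLI(my_pkg.modules.BaseModule, subclass_mode_model=True)

The hook only rewrites the values it is handed; the ``class_path``/``init_args`` conversion in ``_parse_ckpt_path`` above then runs on the adapted dictionary as before.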
83 changes: 83 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -495,6 +495,33 @@ def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
self.layer = torch.nn.Linear(32, out_dim)


class AdaptHparamsModel(LightningModule):
"""Simple model for testing adapt_checkpoint_hparams hook without dynamic neural network layers.

This model stores hyperparameters as attributes without creating layers that would cause size mismatches when
hyperparameters are changed between fit and predict phases.

"""

def __init__(self, out_dim: int = 8, hidden_dim: int = 16) -> None:
super().__init__()
self.save_hyperparameters()
self.out_dim = out_dim
self.hidden_dim = hidden_dim
# Add a simple layer that doesn't depend on hyperparameters
self.layer = torch.nn.Linear(10, 2)

def forward(self, x):
return self.layer(x)

def training_step(self, batch, batch_idx):
x, y = batch
return torch.nn.functional.mse_loss(self(x), y)

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.1)


def test_lightning_cli_ckpt_path_argument_hparams(cleandir):
class CkptPathCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
@@ -562,6 +589,62 @@ def add_arguments_to_parser(self, parser):
assert cli.model.layer.out_features == 4


def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):
"""Test that the adapt_checkpoint_hparams hook is called and modifications are applied."""

class AdaptHparamsCLI(LightningCLI):
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
"""Remove out_dim and hidden_dim for non-fit subcommands."""
if subcommand != "fit":
checkpoint_hparams.pop("out_dim", None)
checkpoint_hparams.pop("hidden_dim", None)
return checkpoint_hparams

# First, create a checkpoint by running fit
cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsCLI(AdaptHparamsModel)

assert cli.config.fit.model.out_dim == 3
assert cli.config.fit.model.hidden_dim == 16  # default value; hidden_dim was not overridden on the command line

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

# Test that predict uses adapted hparams (without out_dim and hidden_dim)
cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsCLI(AdaptHparamsModel)

# Since we removed out_dim and hidden_dim for predict, the CLI values should be used
assert cli.config.predict.model.out_dim == 5
assert cli.config.predict.model.hidden_dim == 10


def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir):
"""Test that returning empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading."""

class AdaptHparamsEmptyCLI(LightningCLI):
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
"""Disable checkpoint hyperparameter loading."""
return {}

# First, create a checkpoint
cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsEmptyCLI(AdaptHparamsModel)

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

# Test that predict uses the default values when the hook returns an empty dict
cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = AdaptHparamsEmptyCLI(AdaptHparamsModel)

# Model should use default values (out_dim=8, hidden_dim=16)
assert cli.config_init.predict.model.out_dim == 8
assert cli.config_init.predict.model.hidden_dim == 16


def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
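The tests above drive the CLI through ``mock.patch("sys.argv", ...)``. For a usage-level view, here is a minimal sketch of the same hook wired into a standalone script; the script name ``main.py``, the ``SmallModel`` class, and the use of ``BoringDataModule`` are illustrative assumptions, not part of this diff:

    # main.py (hypothetical):
    #     python main.py fit --model.out_dim=4 --model.learning_rate=0.1 --trainer.max_epochs=1
    #     python main.py predict --ckpt_path=<path/to.ckpt>
    import torch
    from lightning.pytorch import LightningModule
    from lightning.pytorch.cli import LightningCLI
    from lightning.pytorch.demos.boring_classes import BoringDataModule


    class SmallModel(LightningModule):
        def __init__(self, out_dim: int = 8, learning_rate: float = 0.02) -> None:
            super().__init__()
            self.save_hyperparameters()  # required so the checkpoint carries the hparams
            self.layer = torch.nn.Linear(32, out_dim)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            return self(batch).sum()

        def predict_step(self, batch, batch_idx):
            return self(batch)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
            # learning_rate only matters for training; outside of fit, drop it so the
            # command-line or default value is used instead of the checkpointed one.
            if subcommand != "fit":
                checkpoint_hparams.pop("learning_rate", None)
            return checkpoint_hparams


    if __name__ == "__main__":
        MyCLI(SmallModel, BoringDataModule)

Because ``out_dim`` is kept in the checkpoint hyperparameters, the ``predict`` run reconstructs the model with the layer size it was trained with, while the training-only ``learning_rate`` falls back to its default.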