Allow access to ckpt_path within context of fit() (#11696)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

parent 7da931d1ca
commit 3bc2407239
@@ -258,6 +258,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Deprecated

+- Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced with read-only property `Trainer.ckpt_path` set when checkpoints loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696))
+
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))
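For orientation before the code changes, this is roughly what the new surface looks like from user code. A minimal sketch assuming PyTorch Lightning 1.6 semantics; `TinyModel` and the synthetic dataloader are invented for the example:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(TinyModel(), data, data)

trainer.validate(dataloaders=data, ckpt_path="best")  # loads the best ModelCheckpoint weights
print(trainer.ckpt_path)            # unified read-only property: path of the loaded checkpoint
print(trainer.validated_ckpt_path)  # still works, but emits a DeprecationWarning until v1.8
```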
@@ -99,7 +99,7 @@ class BaseFinetuning(Callback):
         if self._restarting:
             named_parameters = dict(pl_module.named_parameters())
             for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self.__apply_mapping_to_param_groups(
+                param_groups = self._apply_mapping_to_param_groups(
                     self._internal_optimizer_metadata[opt_idx], named_parameters
                 )
                 optimizer.param_groups = param_groups
@@ -245,7 +245,7 @@ class BaseFinetuning(Callback):
         self.freeze_before_training(pl_module)

     @staticmethod
-    def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
+    def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
         output = []
         for g in param_groups:
             # skip params to save memory
@@ -263,13 +263,13 @@ class BaseFinetuning(Callback):
     ) -> None:
         mapping = {p: n for n, p in pl_module.named_parameters()}
         if opt_idx not in self._internal_optimizer_metadata:
-            self._internal_optimizer_metadata[opt_idx] = self.__apply_mapping_to_param_groups(
+            self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
                 current_param_groups, mapping
             )
         elif num_param_groups != len(current_param_groups):
             # save new param_groups possibly created by the users.
             self._internal_optimizer_metadata[opt_idx].extend(
-                self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
+                self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
             )

     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
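The rename from `__apply_mapping_to_param_groups` to `_apply_mapping_to_param_groups` drops Python's name mangling: a double leading underscore is rewritten to `_ClassName__name` at compile time, which blocks access from subclasses and makes the helper awkward to patch in tests, while a single underscore is merely a privacy convention. A quick standalone illustration:

```python
class Base:
    @staticmethod
    def __mangled() -> str:  # stored on the class as `_Base__mangled`
        return "hidden"

    @staticmethod
    def _plain() -> str:  # conventionally private, but resolved normally
        return "visible"


class Child(Base):
    def call_both(self) -> tuple:
        # `self.__mangled()` here would compile to `self._Child__mangled()` and
        # raise AttributeError; the single-underscore helper is inherited as-is.
        return self._plain(), self._Base__mangled()


print(Child().call_both())  # ('visible', 'hidden')
```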
@@ -480,10 +480,14 @@ class Trainer(
         # default .predict() loop
         self.predict_loop = PredictionLoop()

-        # .validate() and .test() set this when they load a checkpoint
-        self.validated_ckpt_path: Optional[str] = None
-        self.tested_ckpt_path: Optional[str] = None
-        self.predicted_ckpt_path: Optional[str] = None
+        # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
+        self._ckpt_path: Optional[str] = None
+
+        # .validate(), predict() and .test() set these when they load a checkpoint. They will be removed in favor of
+        # the unified read-only `Trainer.ckpt_path` attribute in v1.8
+        self._validated_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._tested_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._predicted_ckpt_path: Optional[str] = None  # TODO: remove in v1.8

         # todo: remove in v1.7
         self._weights_summary: Optional[str] = None
@@ -758,7 +762,10 @@ class Trainer(

         # TODO: ckpt_path only in v2.0
         ckpt_path = ckpt_path or self.resume_from_checkpoint
-        results = self._run(model, ckpt_path=ckpt_path)
+        self._ckpt_path = self.__set_ckpt_path(
+            ckpt_path, model_provided=model, model_connected=self.lightning_module is not None
+        )
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.training = False
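Resolving the path through `__set_ckpt_path` before `_run` means `trainer.ckpt_path` is already populated while `fit()` executes, so hooks running inside the fit call can read it—the behavior named in the commit title. A hedged sketch of that use (the callback is invented for illustration):

```python
import pytorch_lightning as pl


class ReportResume(pl.Callback):
    """Logs the checkpoint a fit() call resumed from, now visible mid-run."""

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.ckpt_path is not None:
            print(f"resuming from checkpoint: {trainer.ckpt_path}")


# usage sketch: pl.Trainer(callbacks=[ReportResume()]).fit(model, ckpt_path="last.ckpt")
```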
@@ -837,12 +844,14 @@ class Trainer(
         # links data to the trainer
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

-        self.validated_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )

+        self._validated_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
         # run validate
-        results = self._run(model, ckpt_path=self.validated_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.validating = False
@@ -923,12 +932,14 @@ class Trainer(
         # links data to the trainer
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

-        self.tested_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )

+        self._tested_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
         # run test
-        results = self._run(model, ckpt_path=self.tested_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.testing = False
@@ -1009,11 +1020,13 @@ class Trainer(
         # links data to the trainer
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

-        self.predicted_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )

-        results = self._run(model, ckpt_path=self.predicted_ckpt_path)
+        self._predicted_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.predicting = False
@@ -2219,6 +2232,74 @@ class Trainer(

         return resume_from_checkpoint

+    @property
+    def ckpt_path(self) -> Optional[str]:
+        """Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
+        return self._ckpt_path
+
+    @property
+    def validated_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._validated_ckpt_path
+
+    @validated_ckpt_path.setter
+    def validated_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path`.",
+            stacklevel=5,
+        )
+        self._validated_ckpt_path = ckpt_path
+
+    @property
+    def tested_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._tested_ckpt_path
+
+    @tested_ckpt_path.setter
+    def tested_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._tested_ckpt_path = ckpt_path
+
+    @property
+    def predicted_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._predicted_ckpt_path
+
+    @predicted_ckpt_path.setter
+    def predicted_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._predicted_ckpt_path = ckpt_path
+
     def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         r"""
         Runs routine to create a checkpoint.
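The shape of this block—a private field, a read-only canonical property, and warn-on-access aliases kept only for the deprecation window—can be reproduced with the standard library alone; `rank_zero_deprecation` is essentially a rank-zero-only deprecation warning. A simplified sketch (names shortened, not Lightning's actual class):

```python
import warnings
from typing import Optional


class CheckpointPaths:
    def __init__(self) -> None:
        self._ckpt_path: Optional[str] = None
        self._validated_ckpt_path: Optional[str] = None

    @property
    def ckpt_path(self) -> Optional[str]:
        # read-only: defining no setter makes `obj.ckpt_path = ...` raise AttributeError
        return self._ckpt_path

    @property
    def validated_ckpt_path(self) -> Optional[str]:
        warnings.warn("use `ckpt_path` instead", DeprecationWarning, stacklevel=2)
        return self._validated_ckpt_path

    @validated_ckpt_path.setter
    def validated_ckpt_path(self, value: Optional[str]) -> None:
        warnings.warn("use the read-only `ckpt_path` instead", DeprecationWarning, stacklevel=2)
        self._validated_ckpt_path = value


paths = CheckpointPaths()
paths.validated_ckpt_path = "x.ckpt"  # warns, but still assigns during the deprecation window
print(paths.ckpt_path)  # None; only the canonical entry points populate the private field
```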
@@ -287,35 +287,37 @@ def test_on_before_accelerator_backend_setup(tmpdir):
     trainer.fit(model)


+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 3)
+        self.act = nn.ReLU()
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.act(x)
+        return self.bn(x)
+
+
+class ConvBlockParam(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
+        # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+        self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x):
+        x = self.module_dict["conv"](x)
+        x = self.module_dict["act"](x)
+        return self.bn(x)
+
+
 def test_complex_nested_model():
     """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
     directly themselves rather than exclusively their submodules containing parameters."""

-    class ConvBlock(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
-
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)
-
-    class ConvBlockParam(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
-            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
-            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
-            self.bn = nn.BatchNorm2d(out_channels)
-
-        def forward(self, x):
-            x = self.module_dict["conv"](x)
-            x = self.module_dict["act"](x)
-            return self.bn(x)
-
     model = nn.Sequential(
         OrderedDict(
             [("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), ("decoder", ConvBlock(128, 10))]
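What `ConvBlockParam` exercises is a parameter owned directly by a non-leaf module, which `BaseFinetuning`'s flattening must not drop; hoisting the classes to module level lets other tests reuse them. The distinction is visible in plain PyTorch (standalone illustration):

```python
import torch
from torch import nn


class ConvBlockParam(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
        # a parameter registered on the parent module itself, not on any submodule
        self.parent_param = nn.Parameter(torch.zeros(1))
        self.bn = nn.BatchNorm2d(out_channels)


block = ConvBlockParam(3, 8)
print([name for name, _ in block.named_parameters()])
# ['parent_param', 'module_dict.conv.weight', 'module_dict.conv.bias', 'bn.weight', 'bn.bias']
```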
@@ -146,6 +146,16 @@ def test_v1_8_0_trainer_verbose_evaluate():
     trainer.verbose_evaluate = False


+@pytest.mark.parametrize("fn_prefix", ["validated", "tested", "predicted"])
+def test_v1_8_0_trainer_ckpt_path_attributes(fn_prefix: str):
+    test_attr = f"{fn_prefix}_ckpt_path"
+    trainer = Trainer()
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        _ = getattr(trainer, test_attr)
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        setattr(trainer, test_attr, "v")
+
+
 def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
     trainer = Trainer()
     with pytest.deprecated_call(
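`pytest.deprecated_call` only passes when its body actually raises a `DeprecationWarning` (or `PendingDeprecationWarning`) whose message matches `match`, which is why the test wraps the `getattr` and the `setattr` separately—each access path must warn on its own. A minimal self-contained version of the mechanism:

```python
import warnings

import pytest


def old_attribute() -> int:
    warnings.warn("`old_attribute` was deprecated in v1.6 and will be removed in v1.8", DeprecationWarning)
    return 42


def test_old_attribute_warns() -> None:
    with pytest.deprecated_call(match="deprecated in v1.6 and will be removed in v1.8"):
        assert old_attribute() == 42
```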
@@ -686,8 +686,7 @@ def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
     trainer.fit(model)

     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None

     if ckpt_path == "best":
         # ckpt_path is 'best', meaning we load the best weights
@@ -698,20 +697,20 @@ def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
                 trainer_fn(model, ckpt_path=ckpt_path)
         else:
             trainer_fn(ckpt_path=ckpt_path)
-            assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+            assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path

             trainer_fn(model, ckpt_path=ckpt_path)
-            assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+            assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     elif ckpt_path is None:
         # ckpt_path is None, meaning we don't load any checkpoints and use the provided model
         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) is None
+        assert getattr(trainer, "ckpt_path") is None

         if save_top_k > 0:
             # ckpt_path is None with no model provided means load the best weights
             with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"):
                 trainer_fn(ckpt_path=ckpt_path)
-                assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+                assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         # specific checkpoint, pick one from saved ones
         if save_top_k == 0:
@@ -724,10 +723,10 @@ def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
                 ].absolute()
             )
         trainer_fn(ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path

         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path


 @pytest.mark.parametrize("enable_checkpointing", (False, True))
@@ -758,15 +757,14 @@ def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn):
     trainer.fit(model)

     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None

     if enable_checkpointing:
         trainer_fn(ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path

         trainer_fn(model, ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
             trainer_fn(ckpt_path="best")