Allow access to ckpt_path within context of fit() (#11696)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Dan Dale 2022-02-04 20:23:16 -08:00 committed by GitHub
parent 7da931d1ca
commit 3bc2407239
6 changed files with 145 additions and 52 deletions

View File

@@ -258,6 +258,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Deprecated

+- Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced with read-only property `Trainer.ckpt_path` set when checkpoints loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696))
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))
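For context, a minimal sketch of the replacement attribute in use — the toy module, loaders, and output path below are illustrative, not part of this diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Illustrative stand-in; any LightningModule behaves the same way."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loader():
    return DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)


trainer = pl.Trainer(max_epochs=1, default_root_dir="/tmp/ckpt_demo")
trainer.fit(TinyModel(), train_dataloaders=loader())
assert trainer.ckpt_path is None  # fit() loaded no checkpoint here

# validate(ckpt_path="best") loads the best checkpoint saved during fit()
# and records its path on the unified read-only property:
trainer.validate(dataloaders=loader(), ckpt_path="best")
print(trainer.ckpt_path)  # e.g. ".../checkpoints/epoch=0-step=4.ckpt"

# Deprecated in v1.6, removed in v1.8 (emits a deprecation warning):
# trainer.validated_ckpt_path
```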

View File

@@ -99,7 +99,7 @@ class BaseFinetuning(Callback):
         if self._restarting:
             named_parameters = dict(pl_module.named_parameters())
             for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self.__apply_mapping_to_param_groups(
+                param_groups = self._apply_mapping_to_param_groups(
                     self._internal_optimizer_metadata[opt_idx], named_parameters
                 )
                 optimizer.param_groups = param_groups
@@ -245,7 +245,7 @@ class BaseFinetuning(Callback):
         self.freeze_before_training(pl_module)

     @staticmethod
-    def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
+    def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
         output = []
         for g in param_groups:
             # skip params to save memory
@@ -263,13 +263,13 @@ class BaseFinetuning(Callback):
     ) -> None:
         mapping = {p: n for n, p in pl_module.named_parameters()}
         if opt_idx not in self._internal_optimizer_metadata:
-            self._internal_optimizer_metadata[opt_idx] = self.__apply_mapping_to_param_groups(
+            self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
                 current_param_groups, mapping
             )
         elif num_param_groups != len(current_param_groups):
             # save new param_groups possibly created by the users.
             self._internal_optimizer_metadata[opt_idx].extend(
-                self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
+                self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
             )

     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
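A note on the rename above: a double-underscore prefix triggers Python name mangling, which binds the helper to `BaseFinetuning` specifically and hides it from subclasses, while a single underscore keeps it conventionally private but reachable. A generic sketch of the difference (class and method names here are illustrative):

```python
class Base:
    @staticmethod
    def __mangled():
        # Compiled into `_Base__mangled`; only code written inside the body
        # of `Base` can refer to it as `__mangled`.
        return "base"

    @staticmethod
    def _plain():
        # Single underscore: private by convention only, visible to subclasses.
        return "base"


class Child(Base):
    def use(self):
        assert self._plain() == "base"  # resolves normally through inheritance
        try:
            self.__mangled()  # compiled into `self._Child__mangled` -> missing
        except AttributeError:
            pass
        assert self._Base__mangled() == "base"  # only the mangled name works


Child().use()
```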

View File

@@ -480,10 +480,14 @@ class Trainer(
         # default .predict() loop
         self.predict_loop = PredictionLoop()

-        # .validate() and .test() set this when they load a checkpoint
-        self.validated_ckpt_path: Optional[str] = None
-        self.tested_ckpt_path: Optional[str] = None
-        self.predicted_ckpt_path: Optional[str] = None
+        # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
+        self._ckpt_path: Optional[str] = None
+
+        # .validate(), predict() and .test() set these when they load a checkpoint. They will be removed in favor of
+        # the unified read-only `Trainer.ckpt_path` attribute in v1.8
+        self._validated_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._tested_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._predicted_ckpt_path: Optional[str] = None  # TODO: remove in v1.8

         # todo: remove in v1.7
         self._weights_summary: Optional[str] = None
@@ -758,7 +762,10 @@ class Trainer(
         # TODO: ckpt_path only in v2.0
         ckpt_path = ckpt_path or self.resume_from_checkpoint
-        results = self._run(model, ckpt_path=ckpt_path)
+        self._ckpt_path = self.__set_ckpt_path(
+            ckpt_path, model_provided=model, model_connected=self.lightning_module is not None
+        )
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.training = False
@@ -837,12 +844,14 @@ class Trainer(
         # links data to the trainer
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

-        self.validated_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
+        self._validated_ckpt_path = self.ckpt_path  # TODO: remove in v1.8

         # run validate
-        results = self._run(model, ckpt_path=self.validated_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.validating = False
@@ -923,12 +932,14 @@ class Trainer(
         # links data to the trainer
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

-        self.tested_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
+        self._tested_ckpt_path = self.ckpt_path  # TODO: remove in v1.8

         # run test
-        results = self._run(model, ckpt_path=self.tested_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.testing = False
@@ -1009,11 +1020,13 @@ class Trainer(
         # links data to the trainer
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

-        self.predicted_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
+        self._predicted_ckpt_path = self.ckpt_path  # TODO: remove in v1.8

-        results = self._run(model, ckpt_path=self.predicted_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)

         assert self.state.stopped
         self.predicting = False
@@ -2219,6 +2232,74 @@ class Trainer(
         return resume_from_checkpoint

+    @property
+    def ckpt_path(self) -> Optional[str]:
+        """Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
+        return self._ckpt_path
+
+    @property
+    def validated_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._validated_ckpt_path
+
+    @validated_ckpt_path.setter
+    def validated_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path`.",
+            stacklevel=5,
+        )
+        self._validated_ckpt_path = ckpt_path
+
+    @property
+    def tested_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._tested_ckpt_path
+
+    @tested_ckpt_path.setter
+    def tested_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._tested_ckpt_path = ckpt_path
+
+    @property
+    def predicted_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._predicted_ckpt_path
+
+    @predicted_ckpt_path.setter
+    def predicted_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._predicted_ckpt_path = ckpt_path
+
     def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         r"""
         Runs routine to create a checkpoint.
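The block above is a conventional deprecation shim: one private backing field per legacy name, a unified read-only property, and warning getters/setters for the legacy names. A standalone sketch of the same pattern, with `warnings.warn` standing in for Lightning's `rank_zero_deprecation`:

```python
import warnings
from typing import Optional


class Shim:
    def __init__(self) -> None:
        self._ckpt_path: Optional[str] = None  # single source of truth
        self._validated_ckpt_path: Optional[str] = None  # legacy mirror

    @property
    def ckpt_path(self) -> Optional[str]:
        # Read-only: no setter is defined for the unified name.
        return self._ckpt_path

    @property
    def validated_ckpt_path(self) -> Optional[str]:
        warnings.warn("use `ckpt_path` instead", DeprecationWarning, stacklevel=2)
        return self._validated_ckpt_path

    @validated_ckpt_path.setter
    def validated_ckpt_path(self, value: Optional[str]) -> None:
        warnings.warn("use `ckpt_path` instead", DeprecationWarning, stacklevel=2)
        self._validated_ckpt_path = value


s = Shim()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    s.validated_ckpt_path = "old.ckpt"  # warns but still assigns
assert caught and issubclass(caught[-1].category, DeprecationWarning)
assert s.ckpt_path is None  # the legacy setter does not touch the new field
```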

View File

@@ -287,35 +287,37 @@ def test_on_before_accelerator_backend_setup(tmpdir):
     trainer.fit(model)


+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 3)
+        self.act = nn.ReLU()
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.act(x)
+        return self.bn(x)
+
+
+class ConvBlockParam(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
+        # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+        self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x):
+        x = self.module_dict["conv"](x)
+        x = self.module_dict["act"](x)
+        return self.bn(x)
+
+
 def test_complex_nested_model():
     """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
     directly themselves rather than exclusively their submodules containing parameters."""

-    class ConvBlock(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
-
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)
-
-    class ConvBlockParam(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
-            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
-            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
-            self.bn = nn.BatchNorm2d(out_channels)
-
-        def forward(self, x):
-            x = self.module_dict["conv"](x)
-            x = self.module_dict["act"](x)
-            return self.bn(x)
-
     model = nn.Sequential(
         OrderedDict(
             [("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), ("decoder", ConvBlock(128, 10))]

View File

@@ -146,6 +146,16 @@ def test_v1_8_0_trainer_verbose_evaluate():
         trainer.verbose_evaluate = False


+@pytest.mark.parametrize("fn_prefix", ["validated", "tested", "predicted"])
+def test_v1_8_0_trainer_ckpt_path_attributes(fn_prefix: str):
+    test_attr = f"{fn_prefix}_ckpt_path"
+    trainer = Trainer()
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        _ = getattr(trainer, test_attr)
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        setattr(trainer, test_attr, "v")
+
+
 def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
     trainer = Trainer()
     with pytest.deprecated_call(
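For reference on the assertion style used above: `pytest.deprecated_call(match=...)` passes when the body emits a `DeprecationWarning` (or `PendingDeprecationWarning`) whose message matches the regex. A minimal standalone sketch, independent of Lightning:

```python
import warnings

import pytest


def legacy_accessor():
    # Stand-in for a deprecated attribute access.
    warnings.warn("`tested_ckpt_path` attribute was deprecated in v1.6", DeprecationWarning)


def test_legacy_accessor_warns():
    # `match` is searched as a regex against the warning message.
    with pytest.deprecated_call(match="deprecated in v1.6"):
        legacy_accessor()
```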

View File

@@ -686,8 +686,7 @@ def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
     trainer.fit(model)

     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None

     if ckpt_path == "best":
         # ckpt_path is 'best', meaning we load the best weights
@@ -698,20 +697,20 @@ def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
                 trainer_fn(model, ckpt_path=ckpt_path)
         else:
             trainer_fn(ckpt_path=ckpt_path)
-            assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+            assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path

             trainer_fn(model, ckpt_path=ckpt_path)
-            assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+            assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     elif ckpt_path is None:
         # ckpt_path is None, meaning we don't load any checkpoints and use the provided model
         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) is None
+        assert getattr(trainer, "ckpt_path") is None

         if save_top_k > 0:
             # ckpt_path is None with no model provided means load the best weights
             with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"):
                 trainer_fn(ckpt_path=ckpt_path)
-                assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+                assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         # specific checkpoint, pick one from saved ones
         if save_top_k == 0:
@@ -724,10 +723,10 @@ def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
             ].absolute()
         )
         trainer_fn(ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path

         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path


 @pytest.mark.parametrize("enable_checkpointing", (False, True))
@@ -758,15 +757,14 @@ def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn):
     trainer.fit(model)

     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None

     if enable_checkpointing:
         trainer_fn(ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path

         trainer_fn(model, ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
             trainer_fn(ckpt_path="best")