Raise better error when calling `Trainer.save_checkpoint` without a model attached ()

* add error message

* add test

* changelog

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Adrian Wälchli 2022-04-26 12:16:41 +02:00 committed by GitHub
parent bb81802bff
commit ab60cdbdcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 0 deletions
CHANGELOG.md
pytorch_lightning/trainer
tests/trainer

View File

@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added a friendly error message when attempting to call `Trainer.save_checkpoint()` without a model attached ([#12772](https://github.com/PyTorchLightning/pytorch-lightning/pull/12772))
- Enabled `torch.inference_mode` for evaluation and prediction ([#12715](https://github.com/PyTorchLightning/pytorch-lightning/pull/12715))

View File

@ -2364,6 +2364,11 @@ class Trainer(
storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
"""
if self.model is None:
raise AttributeError(
"Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
" `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
)
self._checkpoint_connector.save_checkpoint(filepath, weights_only=weights_only, storage_options=storage_options)
"""

View File

@ -2133,3 +2133,10 @@ def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_
trainer = Trainer(**trainer_kwargs)
assert trainer.device_ids == expected_device_ids
assert trainer.num_devices == len(expected_device_ids)
def test_trainer_save_checkpoint_no_model_attached():
trainer = Trainer()
assert trainer.model is None
with pytest.raises(AttributeError, match="Saving a checkpoint is only possible if a model is attached"):
trainer.save_checkpoint("checkpoint.ckpt")