[warning] Add a warning for missing callbacks when using resume_from_checkpoint (#7254)

* add a warning

* add changelog
thomas chaton 2021-04-29 13:39:45 +01:00 committed by GitHub
parent e272bea4dc
commit 848288c8d8
3 changed files with 39 additions and 2 deletions
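
For context before the diffs, a minimal sketch of the scenario the new warning targets (not part of this commit; `TinyModel` and `StatefulCallback` are illustrative stand-ins, written against the 1.3-era API used here): a first run writes callback state into the checkpoint, and a later run resumes from that checkpoint without re-passing the callback.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint


class TinyModel(LightningModule):
    # Minimal module so the sketch can run end to end.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


class StatefulCallback(Callback):
    # Returning a dict here stores state under checkpoint['callbacks'].
    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict:
        return {"seen_steps": trainer.global_step}


chk = ModelCheckpoint(dirpath="checkpoints", save_last=True)

# First run: the last.ckpt written here contains state for StatefulCallback.
Trainer(max_epochs=1, callbacks=[chk, StatefulCallback()]).fit(TinyModel())

# Resumed run: StatefulCallback is not passed again, so its saved state cannot be
# restored -- this is exactly the case the new rank_zero_warn message points out.
Trainer(max_epochs=2, resume_from_checkpoint=chk.last_model_path).fit(TinyModel())
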

CHANGELOG.md

@@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `tpu_distributed` check for TPU Spawn barrier ([#7241](https://github.com/PyTorchLightning/pytorch-lightning/pull/7241))
+- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))

 ### Changed

 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

pytorch_lightning/trainer/callback_hook.py

@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -293,10 +293,22 @@ class TrainerCallbackHookMixin(ABC):
     def on_load_checkpoint(self, checkpoint):
         """Called when loading a model checkpoint."""
-        callback_states = checkpoint.get('callbacks')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
+        callback_states = checkpoint.get('callbacks')
+
+        current_callbacks_type = {type(cb) for cb in self.callbacks}
+        saved_callbacks_type = set(callback_states.keys())
+        difference = saved_callbacks_type.difference(current_callbacks_type)
+        if difference:
+            rank_zero_warn(
+                "Be aware that when using ``resume_from_checkpoint``, "
+                "callbacks used to create the checkpoint need to be provided. "
+                f"Please, add the following callbacks: {list(difference)}. ", UserWarning
+            )
+
         if callback_states is not None:
             for callback in self.callbacks:
                 state = callback_states.get(type(callback))
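
The new check is just a set difference over callback types; a self-contained sketch of the same logic with stand-in classes (not the real Lightning callbacks):

# Stand-in callback classes, purely for illustration.
class ModelCheckpoint: ...
class EarlyStopping: ...
class MyCustomCallback: ...

# Roughly what checkpoint.get('callbacks') returns: a mapping of callback type -> saved state.
callback_states = {ModelCheckpoint: {"best_model_score": 0.25}, MyCustomCallback: {"a": None}}

# Callbacks attached to the Trainer that is resuming.
current_callbacks = [ModelCheckpoint(), EarlyStopping()]

current_callbacks_type = {type(cb) for cb in current_callbacks}
saved_callbacks_type = set(callback_states.keys())
difference = saved_callbacks_type.difference(current_callbacks_type)

# difference == {MyCustomCallback}: its saved state cannot be restored, so the
# hook above emits a UserWarning listing this type.
print(difference)

Note that the hook only warns and then continues: state stored for the missing callback types is simply not restored, while the remaining callbacks are loaded as before.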

tests/trainer/test_trainer.py

@@ -2043,6 +2043,28 @@ def test_fit_test_synchronization(tmpdir):
     trainer.test()

+
+class CustomCallbackOnLoadCheckpoint(Callback):
+
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict:
+        return {"a": None}
+
+
+def test_on_load_checkpoint_missing_callbacks(tmpdir):
+    """ Test a warning appears when callbacks in the checkpoint don't match callbacks provided when resuming. """
+    model = BoringModel()
+    chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, callbacks=[chk, CustomCallbackOnLoadCheckpoint()])
+    trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=5, resume_from_checkpoint=chk.last_model_path, progress_bar_refresh_rate=1
+    )
+
+    with pytest.warns(UserWarning, match="CustomCallbackOnLoadCheckpoint"):
+        trainer.fit(model)
+

 def test_module_current_fx_attributes_reset(tmpdir):
     """ Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
     model = BoringModel()