[warning] Add a warning with missing callback with resume_from_checkpoint (#7254)
* add a warning * add changelog
This commit is contained in:
parent
e272bea4dc
commit
848288c8d8
|
@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `tpu_distributed` check for TPU Spawn barrier ([#7241](https://github.com/PyTorchLightning/pytorch-lightning/pull/7241))
|
||||
|
||||
|
||||
- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
|
||||
|
|
|
@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
|
|||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
@ -293,10 +293,22 @@ class TrainerCallbackHookMixin(ABC):
|
|||
|
||||
def on_load_checkpoint(self, checkpoint):
|
||||
"""Called when loading a model checkpoint."""
|
||||
callback_states = checkpoint.get('callbacks')
|
||||
|
||||
# Todo: the `callback_states` are dropped with TPUSpawn as they
|
||||
# can't be saved using `xm.save`
|
||||
# https://github.com/pytorch/xla/issues/2773
|
||||
callback_states = checkpoint.get('callbacks')
|
||||
|
||||
current_callbacks_type = {type(cb) for cb in self.callbacks}
|
||||
saved_callbacks_type = set(callback_states.keys())
|
||||
difference = saved_callbacks_type.difference(current_callbacks_type)
|
||||
if difference:
|
||||
rank_zero_warn(
|
||||
"Be aware that when using ``resume_from_checkpoint``, "
|
||||
"callbacks used to create the checkpoint need to be provided. "
|
||||
f"Please, add the following callbacks: {list(difference)}. ", UserWarning
|
||||
)
|
||||
|
||||
if callback_states is not None:
|
||||
for callback in self.callbacks:
|
||||
state = callback_states.get(type(callback))
|
||||
|
|
|
@ -2043,6 +2043,28 @@ def test_fit_test_synchronization(tmpdir):
|
|||
trainer.test()
|
||||
|
||||
|
||||
class CustomCallbackOnLoadCheckpoint(Callback):
|
||||
|
||||
def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict:
|
||||
return {"a": None}
|
||||
|
||||
|
||||
def test_on_load_checkpoint_missing_callbacks(tmpdir):
|
||||
""" Test a warning appears when callbacks in the checkpoint don't match callbacks provided when resuming. """
|
||||
|
||||
model = BoringModel()
|
||||
chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, callbacks=[chk, CustomCallbackOnLoadCheckpoint()])
|
||||
trainer.fit(model)
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir, max_epochs=5, resume_from_checkpoint=chk.last_model_path, progress_bar_refresh_rate=1
|
||||
)
|
||||
with pytest.warns(UserWarning, match="CustomCallbackOnLoadCheckpoint"):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_module_current_fx_attributes_reset(tmpdir):
|
||||
""" Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
|
||||
model = BoringModel()
|
||||
|
|
Loading…
Reference in New Issue