Add support for returning callback from `LightningModule.configure_callbacks` (#11060)
This commit is contained in:
parent
2a5d05b562
commit
3461af0ddb
|
@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))
|
||||
|
||||
|
||||
- Added support for returning a single Callback from `LightningModule.configure_callbacks` without wrapping it into a list ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060))
|
||||
|
||||
|
||||
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import os
|
|||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import ScriptModule, Tensor
|
||||
|
@ -31,6 +31,7 @@ from torchmetrics import Metric
|
|||
from typing_extensions import Literal
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.callbacks.progress import base as progress_base
|
||||
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
|
||||
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
|
||||
|
@ -1119,15 +1120,16 @@ class LightningModule(
|
|||
"""
|
||||
return self(batch)
|
||||
|
||||
def configure_callbacks(self):
|
||||
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
||||
"""Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()``
|
||||
gets called, the list returned here will be merged with the list of callbacks passed to the Trainer's
|
||||
``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already
|
||||
present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning
|
||||
will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.
|
||||
gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
|
||||
Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
|
||||
already present in the Trainer's callbacks list, it will take priority and replace them. In addition,
|
||||
Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
|
||||
run last.
|
||||
|
||||
Return:
|
||||
A list of callbacks which will extend the list of callbacks in the Trainer.
|
||||
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
|
||||
|
||||
Example::
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
from pytorch_lightning.callbacks import (
|
||||
Callback,
|
||||
|
@ -272,6 +272,8 @@ class CallbackConnector:
|
|||
model_callbacks = self.trainer._call_lightning_module_hook("configure_callbacks")
|
||||
if not model_callbacks:
|
||||
return
|
||||
|
||||
model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks
|
||||
model_callback_types = {type(c) for c in model_callbacks}
|
||||
trainer_callback_types = {type(c) for c in self.trainer.callbacks}
|
||||
override_types = model_callback_types.intersection(trainer_callback_types)
|
||||
|
|
|
@ -83,7 +83,7 @@ def test_configure_callbacks_hook_multiple_calls(tmpdir):
|
|||
|
||||
class TestModel(BoringModel):
|
||||
def configure_callbacks(self):
|
||||
return [model_callback_mock]
|
||||
return model_callback_mock
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)
|
||||
|
|
Loading…
Reference in New Issue