Add support for returning callback from `LightningModule.configure_callbacks` (#11060)

This commit is contained in:
Rohit Gupta 2021-12-18 16:16:35 +05:30 committed by GitHub
parent 2a5d05b562
commit 3461af0ddb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 9 deletions

View File

@@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))
- Added support for returning a single Callback from `LightningModule.configure_callbacks` without wrapping it into a list ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060))
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))

View File

@@ -21,7 +21,7 @@ import os
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
import torch
from torch import ScriptModule, Tensor
@@ -31,6 +31,7 @@ from torchmetrics import Metric
from typing_extensions import Literal
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
@@ -1119,15 +1120,16 @@ class LightningModule(
"""
return self(batch)
def configure_callbacks(self):
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
"""Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()``
gets called, the list returned here will be merged with the list of callbacks passed to the Trainer's
``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already
present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning
will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.
gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
already present in the Trainer's callbacks list, it will take priority and replace them. In addition,
Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
run last.
Return:
A list of callbacks which will extend the list of callbacks in the Trainer.
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
Example::

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import os
from datetime import timedelta
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union
from pytorch_lightning.callbacks import (
Callback,
@@ -272,6 +272,8 @@ class CallbackConnector:
model_callbacks = self.trainer._call_lightning_module_hook("configure_callbacks")
if not model_callbacks:
return
model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks
model_callback_types = {type(c) for c in model_callbacks}
trainer_callback_types = {type(c) for c in self.trainer.callbacks}
override_types = model_callback_types.intersection(trainer_callback_types)

View File

@@ -83,7 +83,7 @@ def test_configure_callbacks_hook_multiple_calls(tmpdir):
class TestModel(BoringModel):
def configure_callbacks(self):
return [model_callback_mock]
return model_callback_mock
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)