From 93de5c8a40870ccb30026bf0499d89a0c5a03d21 Mon Sep 17 00:00:00 2001 From: Wansoo Kim Date: Mon, 11 Jan 2021 20:36:32 +0900 Subject: [PATCH] Allow Callback instance as an argument of `callbacks` in `Trainer` (#5446) * fix * Update CHANGELOG * add test * fix * pep * docs Co-authored-by: Jirka Borovec Co-authored-by: Rohit Gupta --- CHANGELOG.md | 3 +++ .../trainer/connectors/callback_connector.py | 4 +++- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/callbacks/test_callback_hook_outputs.py | 14 ++++++++++---- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3983e3416e..eb1963d64d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704)) +- Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446)) + + ### Deprecated - `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 5277485e95..72a0641a08 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -14,7 +14,7 @@ import os from typing import Union -from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -41,6 +41,8 @@ class CallbackConnector: self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir # init callbacks + if isinstance(callbacks, Callback): + callbacks = [callbacks] self.trainer.callbacks = callbacks or [] # configure checkpoint callback diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b99ee451ea..b923ae9adc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -85,7 +85,7 @@ class Trainer( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, checkpoint_callback: bool = True, - callbacks: Optional[List[Callback]] = None, + callbacks: Optional[Union[List[Callback], Callback]] = None, default_root_dir: Optional[str] = None, gradient_clip_val: float = 0, process_position: int = 0, @@ -169,7 +169,7 @@ class Trainer( benchmark: If true enables cudnn.benchmark. - callbacks: Add a list of callbacks. + callbacks: Add a callback or list of callbacks. checkpoint_callback: If ``True``, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index 406f0c8253..d5538b5617 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Trainer, Callback +import pytest + +from pytorch_lightning import Callback, Trainer from tests.base.boring_model import BoringModel -def test_train_step_no_return(tmpdir): +@pytest.mark.parametrize("single_cb", [False, True]) +def test_train_step_no_return(tmpdir, single_cb): """ Tests that only training_step can be used """ @@ -53,7 +56,7 @@ def test_train_step_no_return(tmpdir): model = TestModel() trainer = Trainer( - callbacks=[CB()], + callbacks=CB() if single_cb else [CB()], default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, @@ -62,4 +65,7 @@ def test_train_step_no_return(tmpdir): weights_summary=None, ) - trainer.fit(model) + assert any(isinstance(c, CB) for c in trainer.callbacks) + + results = trainer.fit(model) + assert results