Allow Callback instance as an argument of `callbacks` in `Trainer` (#5446)

* fix

* Update CHANGELOG

* add test

* fix

* pep

* docs

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Wansoo Kim 2021-01-11 20:36:32 +09:00 committed by GitHub
parent d583d56169
commit 93de5c8a40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 7 deletions

View File

@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))
- Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446))
### Deprecated
- `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

View File

@ -14,7 +14,7 @@
import os
from typing import Union
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -41,6 +41,8 @@ class CallbackConnector:
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
# init callbacks
if isinstance(callbacks, Callback):
callbacks = [callbacks]
self.trainer.callbacks = callbacks or []
# configure checkpoint callback

View File

@ -85,7 +85,7 @@ class Trainer(
self,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
checkpoint_callback: bool = True,
callbacks: Optional[List[Callback]] = None,
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: float = 0,
process_position: int = 0,
@ -169,7 +169,7 @@ class Trainer(
benchmark: If true enables cudnn.benchmark.
callbacks: Add a list of callbacks.
callbacks: Add a callback or list of callbacks.
checkpoint_callback: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in

View File

@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer, Callback
import pytest
from pytorch_lightning import Callback, Trainer
from tests.base.boring_model import BoringModel
def test_train_step_no_return(tmpdir):
@pytest.mark.parametrize("single_cb", [False, True])
def test_train_step_no_return(tmpdir, single_cb):
"""
Tests that only training_step can be used
"""
@ -53,7 +56,7 @@ def test_train_step_no_return(tmpdir):
model = TestModel()
trainer = Trainer(
callbacks=[CB()],
callbacks=CB() if single_cb else [CB()],
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
@ -62,4 +65,7 @@ def test_train_step_no_return(tmpdir):
weights_summary=None,
)
trainer.fit(model)
assert any(isinstance(c, CB) for c in trainer.callbacks)
results = trainer.fit(model)
assert results