Allow Callback instance as an argument of `callbacks` in `Trainer` (#5446)
* fix * Update CHANGELOG * add test * fix * pep * docs Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
d583d56169
commit
93de5c8a40
|
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))
|
||||
|
||||
|
||||
- Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
|
||||
|
|
|
@@ -14,7 +14,7 @@
|
|||
import os
|
||||
from typing import Union
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
|
||||
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
@@ -41,6 +41,8 @@ class CallbackConnector:
|
|||
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
|
||||
|
||||
# init callbacks
|
||||
if isinstance(callbacks, Callback):
|
||||
callbacks = [callbacks]
|
||||
self.trainer.callbacks = callbacks or []
|
||||
|
||||
# configure checkpoint callback
|
||||
|
|
|
@@ -85,7 +85,7 @@ class Trainer(
|
|||
self,
|
||||
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
|
||||
checkpoint_callback: bool = True,
|
||||
callbacks: Optional[List[Callback]] = None,
|
||||
callbacks: Optional[Union[List[Callback], Callback]] = None,
|
||||
default_root_dir: Optional[str] = None,
|
||||
gradient_clip_val: float = 0,
|
||||
process_position: int = 0,
|
||||
|
@@ -169,7 +169,7 @@ class Trainer(
|
|||
|
||||
benchmark: If true enables cudnn.benchmark.
|
||||
|
||||
callbacks: Add a list of callbacks.
|
||||
callbacks: Add a callback or list of callbacks.
|
||||
|
||||
checkpoint_callback: If ``True``, enable checkpointing.
|
||||
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
|
||||
|
|
|
@@ -11,11 +11,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning import Trainer, Callback
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import Callback, Trainer
|
||||
from tests.base.boring_model import BoringModel
|
||||
|
||||
|
||||
def test_train_step_no_return(tmpdir):
|
||||
@pytest.mark.parametrize("single_cb", [False, True])
|
||||
def test_train_step_no_return(tmpdir, single_cb):
|
||||
"""
|
||||
Tests that only training_step can be used
|
||||
"""
|
||||
|
@@ -53,7 +56,7 @@ def test_train_step_no_return(tmpdir):
|
|||
model = TestModel()
|
||||
|
||||
trainer = Trainer(
|
||||
callbacks=[CB()],
|
||||
callbacks=CB() if single_cb else [CB()],
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
|
@@ -62,4 +65,7 @@ def test_train_step_no_return(tmpdir):
|
|||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
assert any(isinstance(c, CB) for c in trainer.callbacks)
|
||||
|
||||
results = trainer.fit(model)
|
||||
assert results
|
||||
|
|
Loading…
Reference in New Issue