diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 4731d43516..c6c36ca5f5 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -240,6 +240,15 @@ def test_early_stopping_no_val_step(tmpdir): assert trainer.current_epoch < trainer.max_epochs +def test_pickling(tmpdir): + import pickle + early_stopping = EarlyStopping() + ckpt = ModelCheckpoint(tmpdir) + + pickle.dumps(ckpt) + pickle.dumps(early_stopping) + + def test_model_checkpoint_with_non_string_input(tmpdir): """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index d9bb804bb4..857512042b 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -77,6 +77,8 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class): # WandbLogger, # TODO: add this one ]) def test_loggers_pickle(tmpdir, monkeypatch, logger_class): + import pickle + """Verify that pickling trainer with logger works.""" tutils.reset_seed() @@ -88,6 +90,9 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class): logger_args = _get_logger_args(logger_class, tmpdir) logger = logger_class(**logger_args) + # test pickling loggers + pickle.dumps(logger) + trainer = Trainer( max_epochs=1, logger=logger diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6876a69383..9d8217f9cd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -24,6 +24,14 @@ from tests.base import ( LightTestDataloader, LightValidationMixin, ) +from tests.base import TestModelBase + + +def test_model_pickle(tmpdir): + import pickle + + model = TestModelBase() + pickle.dumps(model) def test_hparams_save_load(tmpdir):