ddp pickle

This commit is contained in:
William Falcon 2020-04-27 07:41:30 -04:00
parent 013fd9886e
commit 5c0118fe9d
3 changed files with 22 additions and 0 deletions

View File

@ -240,6 +240,15 @@ def test_early_stopping_no_val_step(tmpdir):
assert trainer.current_epoch < trainer.max_epochs assert trainer.current_epoch < trainer.max_epochs
def test_pickling(tmpdir):
import pickle
early_stopping = EarlyStopping()
ckpt = ModelCheckpoint(tmpdir)
pickle.dumps(ckpt)
pickle.dumps(early_stopping)
def test_model_checkpoint_with_non_string_input(tmpdir): def test_model_checkpoint_with_non_string_input(tmpdir):
""" Test that None in checkpoint callback is valid and that chkp_path is """ Test that None in checkpoint callback is valid and that chkp_path is
set correctly """ set correctly """

View File

@ -77,6 +77,8 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
# WandbLogger, # TODO: add this one # WandbLogger, # TODO: add this one
]) ])
def test_loggers_pickle(tmpdir, monkeypatch, logger_class): def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
import pickle
"""Verify that pickling trainer with logger works.""" """Verify that pickling trainer with logger works."""
tutils.reset_seed() tutils.reset_seed()
@ -88,6 +90,9 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
logger_args = _get_logger_args(logger_class, tmpdir) logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args) logger = logger_class(**logger_args)
# test pickling loggers
pickle.dumps(logger)
trainer = Trainer( trainer = Trainer(
max_epochs=1, max_epochs=1,
logger=logger logger=logger

View File

@ -24,6 +24,14 @@ from tests.base import (
LightTestDataloader, LightTestDataloader,
LightValidationMixin, LightValidationMixin,
) )
from tests.base import TestModelBase
def test_model_pickle(tmpdir):
import pickle
model = TestModelBase()
pickle.dumps(model)
def test_hparams_save_load(tmpdir): def test_hparams_save_load(tmpdir):