ddp pickle
This commit is contained in:
parent
013fd9886e
commit
5c0118fe9d
|
@ -240,6 +240,15 @@ def test_early_stopping_no_val_step(tmpdir):
|
|||
assert trainer.current_epoch < trainer.max_epochs
|
||||
|
||||
|
||||
def test_pickling(tmpdir):
|
||||
import pickle
|
||||
early_stopping = EarlyStopping()
|
||||
ckpt = ModelCheckpoint(tmpdir)
|
||||
|
||||
pickle.dumps(ckpt)
|
||||
pickle.dumps(early_stopping)
|
||||
|
||||
|
||||
def test_model_checkpoint_with_non_string_input(tmpdir):
|
||||
""" Test that None in checkpoint callback is valid and that chkp_path is
|
||||
set correctly """
|
||||
|
|
|
@ -77,6 +77,8 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
|
|||
# WandbLogger, # TODO: add this one
|
||||
])
|
||||
def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
|
||||
import pickle
|
||||
|
||||
"""Verify that pickling trainer with logger works."""
|
||||
tutils.reset_seed()
|
||||
|
||||
|
@ -88,6 +90,9 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
|
|||
logger_args = _get_logger_args(logger_class, tmpdir)
|
||||
logger = logger_class(**logger_args)
|
||||
|
||||
# test pickling loggers
|
||||
pickle.dumps(logger)
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
logger=logger
|
||||
|
|
|
@ -24,6 +24,14 @@ from tests.base import (
|
|||
LightTestDataloader,
|
||||
LightValidationMixin,
|
||||
)
|
||||
from tests.base import TestModelBase
|
||||
|
||||
|
||||
def test_model_pickle(tmpdir):
|
||||
import pickle
|
||||
|
||||
model = TestModelBase()
|
||||
pickle.dumps(model)
|
||||
|
||||
|
||||
def test_hparams_save_load(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue