ddp pickle
This commit is contained in:
parent
013fd9886e
commit
5c0118fe9d
|
@ -240,6 +240,15 @@ def test_early_stopping_no_val_step(tmpdir):
|
||||||
assert trainer.current_epoch < trainer.max_epochs
|
assert trainer.current_epoch < trainer.max_epochs
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickling(tmpdir):
|
||||||
|
import pickle
|
||||||
|
early_stopping = EarlyStopping()
|
||||||
|
ckpt = ModelCheckpoint(tmpdir)
|
||||||
|
|
||||||
|
pickle.dumps(ckpt)
|
||||||
|
pickle.dumps(early_stopping)
|
||||||
|
|
||||||
|
|
||||||
def test_model_checkpoint_with_non_string_input(tmpdir):
|
def test_model_checkpoint_with_non_string_input(tmpdir):
|
||||||
""" Test that None in checkpoint callback is valid and that chkp_path is
|
""" Test that None in checkpoint callback is valid and that chkp_path is
|
||||||
set correctly """
|
set correctly """
|
||||||
|
|
|
@ -77,6 +77,8 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
|
||||||
# WandbLogger, # TODO: add this one
|
# WandbLogger, # TODO: add this one
|
||||||
])
|
])
|
||||||
def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
|
def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
|
||||||
|
import pickle
|
||||||
|
|
||||||
"""Verify that pickling trainer with logger works."""
|
"""Verify that pickling trainer with logger works."""
|
||||||
tutils.reset_seed()
|
tutils.reset_seed()
|
||||||
|
|
||||||
|
@ -88,6 +90,9 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
|
||||||
logger_args = _get_logger_args(logger_class, tmpdir)
|
logger_args = _get_logger_args(logger_class, tmpdir)
|
||||||
logger = logger_class(**logger_args)
|
logger = logger_class(**logger_args)
|
||||||
|
|
||||||
|
# test pickling loggers
|
||||||
|
pickle.dumps(logger)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
max_epochs=1,
|
max_epochs=1,
|
||||||
logger=logger
|
logger=logger
|
||||||
|
|
|
@ -24,6 +24,14 @@ from tests.base import (
|
||||||
LightTestDataloader,
|
LightTestDataloader,
|
||||||
LightValidationMixin,
|
LightValidationMixin,
|
||||||
)
|
)
|
||||||
|
from tests.base import TestModelBase
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_pickle(tmpdir):
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
model = TestModelBase()
|
||||||
|
pickle.dumps(model)
|
||||||
|
|
||||||
|
|
||||||
def test_hparams_save_load(tmpdir):
|
def test_hparams_save_load(tmpdir):
|
||||||
|
|
Loading…
Reference in New Issue