2020-06-29 01:36:46 +00:00
|
|
|
import os
|
|
|
|
import pickle
|
|
|
|
import platform
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import cloudpickle
|
|
|
|
import pytest
|
2020-08-08 10:02:43 +00:00
|
|
|
import torch
|
2020-06-29 01:36:46 +00:00
|
|
|
|
|
|
|
import tests.base.develop_utils as tutils
|
2020-08-08 10:02:43 +00:00
|
|
|
from pytorch_lightning import Trainer, seed_everything
|
2020-06-29 01:36:46 +00:00
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
|
|
from pytorch_lightning.loggers import TensorBoardLogger
|
|
|
|
from tests.base import EvalModelTemplate
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """Check that ``ModelCheckpoint(filepath=None)`` is accepted and its dirpath is filled in.

    After fitting, the callback's ``dirpath`` must equal
    ``<tmpdir>/<logger name>/version_0/checkpoints``.
    """
    tutils.reset_seed()
    template_model = EvalModelTemplate()

    ckpt_callback = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=ckpt_callback,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer.fit(template_model)

    expected_dir = tmpdir / trainer.logger.name / 'version_0' / 'checkpoints'
    assert ckpt_callback.dirpath == expected_dir
|
2020-06-29 01:36:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    'logger_version,expected',
    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Check that the "version_" prefix is only added when the logger's version is an integer.

    The folder that contains the checkpoint directory is named after the logger
    version: ``None`` and integers give ``version_<n>``, strings are used as-is.
    """
    tutils.reset_seed()
    net = EvalModelTemplate()
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(default_root_dir=tmpdir, overfit_batches=0.2, max_epochs=2, logger=tb_logger)
    trainer.fit(net)

    # the parent of the checkpoint directory is the per-version folder
    version_dir_name = Path(trainer.checkpoint_callback.dirpath).parent.name
    assert version_dir_name == expected
|
|
|
|
|
|
|
|
|
|
|
|
def test_pickling(tmpdir):
    """Check that a ModelCheckpoint round-trips unchanged through pickle and cloudpickle."""
    callback = ModelCheckpoint(tmpdir)

    # pickle first, then cloudpickle; both must reproduce the instance attributes exactly
    for serializer in (pickle, cloudpickle):
        restored = serializer.loads(serializer.dumps(callback))
        assert vars(callback) == vars(restored)
|
|
|
|
|
|
|
|
|
|
|
|
class ModelCheckpointTestInvocations(ModelCheckpoint):
    """ModelCheckpoint that counts its saves and checks that count when training ends.

    This class has to be defined at module level rather than inside the test
    function; otherwise pickling fails because of the way the DDP processes are
    launched.
    """

    def __init__(self, expected_count, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # number of _save_model calls made on this process so far
        self.count = 0
        # number of saves rank 0 must have made by the end of training
        self.expected_count = expected_count

    def _save_model(self, filepath, trainer, pl_module):
        # the same checkpoint file must never be written twice
        assert not os.path.isfile(filepath)
        self.count += 1
        super()._save_model(filepath, trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        # rank 0 must have saved exactly expected_count files; every other rank none
        rank = trainer.global_rank
        main_process_ok = rank == 0 and self.count == self.expected_count
        worker_process_ok = rank > 0 and self.count == 0
        assert main_process_ok or worker_process_ok
|
2020-06-29 01:36:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
def test_model_checkpoint_no_extraneous_invocations(tmpdir):
    """Test to ensure that the model callback saves the checkpoints only once in distributed mode.

    Two ``ddp_cpu`` processes train for several epochs with ``save_top_k=-1``. The
    counting callback asserts at train end that rank 0 saved once per epoch and
    that the other rank saved nothing.
    """
    model = EvalModelTemplate()
    epochs = 4
    counting_checkpoint = ModelCheckpointTestInvocations(expected_count=epochs, save_top_k=-1)

    trainer_kwargs = dict(
        distributed_backend='ddp_cpu',
        num_processes=2,
        default_root_dir=tmpdir,
        early_stop_callback=False,
        checkpoint_callback=counting_checkpoint,
        max_epochs=epochs,
    )
    trainer = Trainer(**trainer_kwargs)

    fit_result = trainer.fit(model)
    assert fit_result == 1
|
2020-08-08 10:02:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """ Tests that the checkpoint saved as 'last.ckpt' contains the latest information.

    Trains for ``num_epochs`` epochs while keeping every epoch checkpoint plus
    ``last.ckpt``. The test then checks that ``last.ckpt`` matches the final
    epoch's checkpoint in its trainer-state keys and in every model weight.
    """
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3
    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=False,
        checkpoint_callback=model_checkpoint,
        max_epochs=num_epochs,
    )
    trainer.fit(model)

    # the final epoch is formatted with index num_epochs - 1 (== 2 here)
    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})  # epoch=2.ckpt
    path_last = str(tmpdir / ModelCheckpoint.CHECKPOINT_NAME_LAST)  # last.ckpt
    assert path_last_epoch != path_last

    ckpt_last_epoch = torch.load(path_last_epoch)
    ckpt_last = torch.load(path_last)
    matching_keys = (
        "epoch",
        "global_step",
        ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE,
        ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH,
    )
    for key in matching_keys:
        # report which key differs on failure
        assert ckpt_last_epoch[key] == ckpt_last[key], key

    # it is easier to load the model objects than to iterate over the raw dict of tensors
    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
    model_last = EvalModelTemplate.load_from_checkpoint(path_last)
    params_last_epoch = list(model_last_epoch.parameters())
    params_last = list(model_last.parameters())
    # zip() would silently stop at the shorter list, so check the counts first
    assert len(params_last_epoch) == len(params_last)
    for w0, w1 in zip(params_last_epoch, params_last):
        # torch.equal also requires matching shapes, unlike eq().all() which broadcasts
        assert torch.equal(w0, w1)
|