# lightning/tests/checkpointing/test_model_checkpoint.py
#
# 1311 lines
# 47 KiB
# Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import pickle
import re
import time
from argparse import Namespace
from datetime import timedelta
from logging import INFO
from pathlib import Path
from typing import Union
from unittest import mock
from unittest.mock import Mock
import cloudpickle
import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch import optim
import pytorch_lightning as pl
import tests.helpers.utils as tutils
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
class LogInTwoMethods(BoringModel):
    """BoringModel that logs ``early_stop_on`` every training step and ``val_acc`` every validation epoch."""

    def training_step(self, batch, batch_idx):
        # Reuse the parent step and additionally log its loss under the key the tests monitor.
        step_output = super().training_step(batch, batch_idx)
        self.log('early_stop_on', step_output['loss'])
        return step_output

    def validation_epoch_end(self, outputs):
        # Each validation step output is a dict with an ``'x'`` entry; log the mean across steps.
        per_step = [step_out['x'] for step_out in outputs]
        self.log('val_acc', torch.stack(per_step).mean())
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize(
    "validation_step_none,val_dataloaders_none,monitor",
    [
        (False, False, 'val_log'),
        (True, False, 'train_log_epoch'),
        (False, True, 'val_log'),
    ],
)
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
def test_model_checkpoint_score_and_ckpt(
    tmpdir, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool
):
    """
    Test that when a model checkpoint is saved, it saves with
    the correct score appended to ckpt_path and checkpoint data

    The model logs pre-generated random values, so the expected monitored score of every
    epoch can be recomputed after training. It is then compared against the checkpoint
    filename, the checkpoint's stored callback state, the dev-debugger's recorded metrics
    and (for StepLR) the stored scheduler state. ``PL_DEV_DEBUG=1`` is presumably what
    enables ``trainer.dev_debugger`` recording — TODO confirm.
    """
    max_epochs = 3
    limit_train_batches = 5
    limit_val_batches = 7
    lr, gamma = 1e-1, 2

    class CustomBoringModel(BoringModel):

        def __init__(self):
            super().__init__()
            # Random values to log. The attribute names are looked up below as f'{monitor}s'
            # ('train_log_epochs' / 'val_logs') to recompute the expected per-epoch scores.
            self.train_log_epochs = torch.randn(max_epochs, limit_train_batches)
            self.val_logs = torch.randn(max_epochs, limit_val_batches)

        def training_step(self, batch, batch_idx):
            log_value = self.train_log_epochs[self.current_epoch, batch_idx]
            # on_epoch=True is what produces the 'train_log_epoch' monitor key; its value is compared
            # against the per-epoch mean below.
            self.log('train_log', log_value, on_epoch=True)
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            log_value = self.val_logs[self.current_epoch, batch_idx]
            self.log('val_log', log_value)
            self.log('epoch', self.current_epoch, on_epoch=True)
            return super().validation_step(batch, batch_idx)

        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=lr)
            if reduce_lr_on_plateau:
                # ReduceLROnPlateau is driven by the same monitored metric as the checkpoint.
                lr_scheduler = {
                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                    'monitor': monitor,
                    'strict': True,
                }
            else:
                # Multiplies the LR by gamma on each step; checked via '_step_count'/'_last_lr' below.
                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
            return [optimizer], [lr_scheduler]

    # e.g. '{val_log:.4f}-{epoch}', which is saved as 'val_log=<score>-epoch=<n>.ckpt'
    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

    model = CustomBoringModel()
    if validation_step_none:
        model.validation_step = None
    if val_dataloaders_none:
        # NOTE(review): this assigns an attribute named ``val_dataloaders`` rather than overriding the
        # ``val_dataloader`` hook, so it may not actually disable validation. The
        # (False, True, 'val_log') case still expects 'val_log' scores. TODO confirm the intent.
        model.val_dataloaders = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
    # Every value logged under the monitored key, in logging order.
    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
    # save_top_k=-1 keeps every checkpoint: one per epoch, each matching one logged score.
    assert len(ckpt_files) == len(scores) == max_epochs
    assert len(lr_scheduler_debug) == max_epochs

    for epoch in range(max_epochs):
        score = scores[epoch]
        expected_score = getattr(model, f'{monitor}s')[epoch].mean().item()
        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
        assert math.isclose(score, expected_score, rel_tol=1e-4)

        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
        assert chk['epoch'] == epoch + 1
        assert chk['global_step'] == limit_train_batches * (epoch + 1)

        # State persisted by the ModelCheckpoint callback itself.
        mc_specific_data = chk['callbacks'][type(checkpoint)]
        assert mc_specific_data['dirpath'] == checkpoint.dirpath
        assert mc_specific_data['monitor'] == monitor
        assert mc_specific_data['current_score'] == score

        if not reduce_lr_on_plateau:
            actual_step_count = chk['lr_schedulers'][0]['_step_count']
            actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]

            # if validation_step_none, the checkpoint gets saved after the learning rate update
            # so we need to increase the count by one
            assert actual_step_count == epoch + 1 + validation_step_none
            assert actual_lr == lr * gamma**(epoch + validation_step_none)

        # The dev debugger only records a monitor key/value for ReduceLROnPlateau.
        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize(
    "val_check_interval,reduce_lr_on_plateau,epoch_aligned",
    [
        (0.25, True, True),
        (0.25, False, True),
        (0.42, False, False),
    ],
)
def test_model_checkpoint_score_and_ckpt_val_check_interval(
    tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned
):
    """
    Test that when a model checkpoint is saved, it saves with the correct
    score appended to ckpt_path and checkpoint data with val_check_interval

    ``epoch_aligned`` states whether the validation runs line up with the end of each epoch.
    When they do not, one extra checkpoint (with a ``-v1`` suffix) is expected.
    """
    max_epochs = 3
    limit_train_batches = 12
    limit_val_batches = 7
    lr, gamma = 1e-1, 2
    monitor = 'val_log'
    # Training batches between two validation runs, the number of full validation runs per epoch,
    # and the training batches left over after the last in-epoch validation run.
    per_val_train_batches = int(limit_train_batches * val_check_interval)
    per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)

    class CustomBoringModel(BoringModel):

        def __init__(self):
            super().__init__()
            # One row of random values to log per validation run over the whole fit.
            self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
            # Index of the current validation run; advanced in validation_epoch_end.
            self.val_loop_count = 0

        def validation_step(self, batch, batch_idx):
            log_value = self.val_logs[self.val_loop_count, batch_idx]
            self.log('val_log', log_value)
            return super().validation_step(batch, batch_idx)

        def validation_epoch_end(self, outputs):
            self.val_loop_count += 1
            super().validation_epoch_end(outputs)

        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=lr)
            if reduce_lr_on_plateau:
                lr_scheduler = {
                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                    'monitor': monitor,
                    'strict': True,
                }
            else:
                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
            return [optimizer], [lr_scheduler]

    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

    model = CustomBoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        val_check_interval=val_check_interval,
        progress_bar_refresh_rate=0,
        num_sanity_val_steps=0,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates

    # The checkpoint callback also runs in on_train_end. It creates an additional ckpt when none was
    # created at the end of the last epoch. So if val_check_interval doesn't align with the training
    # steps, we get one additional ckpt.
    additional_ckpt, additional_ckpt_path = False, None
    if not epoch_aligned:
        additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
        additional_ckpt = True
    # additional_ckpt is a bool, so it adds 0 or 1 to the expected counts.
    assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
    assert len(lr_scheduler_debug) == max_epochs

    def _make_assertions(epoch, ix, version=''):
        # Check the checkpoint of the ``ix``-th validation run within ``epoch``. A non-empty
        # ``version`` (e.g. '-v1') selects the extra checkpoint written in on_train_end.
        global_ix = ix + per_epoch_val_checks * epoch
        duplicated = bool(version)

        score = scores[global_ix]
        expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
        assert math.isclose(score, expected_score, rel_tol=1e-4)

        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
        assert chk['epoch'] == epoch + 1
        epoch_num = epoch + duplicated
        expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
        assert chk['global_step'] == expected_global_step

        mc_specific_data = chk['callbacks'][type(checkpoint)]
        assert mc_specific_data['dirpath'] == checkpoint.dirpath
        assert mc_specific_data['monitor'] == monitor
        assert mc_specific_data['current_score'] == score

        if not reduce_lr_on_plateau:
            actual_step_count = chk['lr_schedulers'][0]['_step_count']
            actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
            assert actual_step_count == epoch + 1 + duplicated
            assert actual_lr == lr * gamma**(epoch + duplicated)

        return score

    for epoch in range(max_epochs):
        for i in range(per_epoch_val_checks):
            score = _make_assertions(epoch, i)

        # `score` now holds the last validation run of this epoch, which is what the scheduler saw.
        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)

    # check the ckpt file saved on_train_end
    if additional_ckpt_path:
        _make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int):
    """``dirpath=None`` is accepted, and the checkpoint directory is derived from the logger."""
    tutils.reset_seed()
    n_epochs = 2
    model = LogInTwoMethods()
    ckpt_cb = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}', save_top_k=save_top_k)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[ckpt_cb], overfit_batches=0.20, max_epochs=n_epochs)
    trainer.fit(model)

    # With no explicit dirpath, checkpoints land in <root>/<logger name>/version_0/checkpoints.
    assert ckpt_cb.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints"

    # Only when every checkpoint is kept is the full set of filenames predictable.
    if save_top_k != -1:
        return
    saved = os.listdir(ckpt_cb.dirpath)
    wanted = [f'epoch={e}.ckpt' for e in range(n_epochs)]
    assert len(saved) == len(wanted) == n_epochs
    assert set(saved) == set(wanted)
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_to_yaml(tmpdir, save_top_k: int):
    """Test that ``ModelCheckpoint.to_yaml`` dumps exactly the ``best_k_models`` mapping."""
    tutils.reset_seed()
    model = LogInTwoMethods()
    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=save_top_k)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2)
    trainer.fit(model)

    path_yaml = os.path.join(tmpdir, 'best_k_models.yaml')
    checkpoint.to_yaml(path_yaml)
    # Close the file deterministically instead of leaking the handle until garbage collection.
    with open(path_yaml, 'r') as f:
        d = yaml.full_load(f)
    assert d == dict(checkpoint.best_k_models)
@pytest.mark.parametrize(
    "logger_version,expected",
    [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")],
)
def test_model_checkpoint_path(tmpdir, logger_version: Union[None, int, str], expected: str):
    """The checkpoint's version directory gets a ``version_`` prefix for integer versions; strings are kept as-is."""
    tutils.reset_seed()
    model = LogInTwoMethods()
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)
    trainer = Trainer(default_root_dir=tmpdir, overfit_batches=0.2, max_epochs=2, logger=tb_logger)
    trainer.fit(model)

    # The checkpoint dir sits directly inside the logger's version directory.
    version_dir_name = Path(trainer.checkpoint_callback.dirpath).parent.name
    assert version_dir_name == expected
def test_pickling(tmpdir):
    """A ModelCheckpoint must round-trip unchanged through both ``pickle`` and ``cloudpickle``."""
    ckpt = ModelCheckpoint(dirpath=tmpdir)
    for serializer in (pickle, cloudpickle):
        restored = serializer.loads(serializer.dumps(ckpt))
        assert vars(ckpt) == vars(restored)
class ModelCheckpointTestInvocations(ModelCheckpoint):
    """ModelCheckpoint that counts ``on_save_checkpoint`` calls and spies on ``torch.save``.

    This class has to be defined outside the test function, otherwise we get a pickle error
    due to the way the ddp process is launched.
    """

    def __init__(self, expected_count, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_count = expected_count
        self.on_save_checkpoint_count = 0
        # Real ``torch.save``, captured in ``on_train_start`` so it can be restored in ``on_train_end``.
        self._original_torch_save = None

    def on_train_start(self, trainer, pl_module):
        # Wrap ``torch.save`` so calls are counted while files are still actually written.
        self._original_torch_save = torch.save
        torch.save = Mock(wraps=self._original_torch_save)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # only rank 0 will call ``torch.save``
        # Forward the parent's return value so the callback state still ends up in the checkpoint.
        state = super().on_save_checkpoint(trainer, pl_module, checkpoint)
        self.on_save_checkpoint_count += 1
        return state

    def on_train_end(self, trainer, pl_module):
        try:
            super().on_train_end(trainer, pl_module)
            assert self.best_model_path
            assert self.best_model_score
            assert self.on_save_checkpoint_count == self.expected_count
            if trainer.is_global_zero:
                assert torch.save.call_count == self.expected_count
            else:
                assert torch.save.call_count == 0
        finally:
            # Undo the global monkeypatch (even when an assertion fails) so later code sees the real function.
            if self._original_torch_save is not None:
                torch.save = self._original_torch_save
                self._original_torch_save = None
@RunIf(skip_windows=True)
def test_model_checkpoint_no_extraneous_invocations(tmpdir):
    """In distributed mode the checkpoint callback must write each checkpoint exactly once."""
    epochs = 4
    model = LogInTwoMethods()
    counting_ckpt = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=epochs, save_top_k=-1)
    trainer = Trainer(
        accelerator="ddp_cpu",
        num_processes=2,
        default_root_dir=tmpdir,
        callbacks=[counting_ckpt],
        max_epochs=epochs,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Test checkpoint-name formatting.

    Covers filename templates, prefixes, the join character, dirpath handling, versions and
    ``auto_insert_metric_name``.
    """
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('', {'epoch': 3, 'step': 2})
    assert ckpt_name == 'epoch=3-step=2'

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {'epoch': 3, 'step': 2}, prefix='test')
    assert ckpt_name == 'test-epoch=3-step=2'

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', {}, prefix='test')
    assert ckpt_name == 'test-ckpt'

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', {'epoch': 3, 'acc': 0.03})
    assert ckpt_name == 'epoch=003-acc=0.03'

    # prefix with a custom join character. CHECKPOINT_JOIN_CHAR is a class attribute shared by every
    # ModelCheckpoint, so it is restored even if the assertion fails, to avoid leaking into other tests.
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name(
            '{epoch},{acc:.5f}', {'epoch': 3, 'acc': 0.03}, prefix='test'
        )
        assert ckpt_name == 'test@epoch=3,acc=0.03000'
    finally:
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name({'epoch': 3, 'step': 2})
    assert ckpt_name == 'epoch=3-step=2.ckpt'
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name({'epoch': 5, 'step': 4})
    assert ckpt_name == 'epoch=5-step=4.ckpt'

    # CWD
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name({'epoch': 3, 'step': 4})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

    # with version
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name')
    ckpt_name = ckpt.format_checkpoint_name({}, ver=3)
    assert ckpt_name == tmpdir / 'name-v3.ckpt'

    # using slashes
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')
    ckpt_name = ckpt.format_checkpoint_name({'epoch': 4, 'val/loss': 0.03})
    assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

    # auto_insert_metric_name=False
    ckpt_name = ModelCheckpoint._format_checkpoint_name(
        'epoch={epoch:03d}-val_acc={val/acc}', {
            'epoch': 3,
            'val/acc': 0.03
        }, auto_insert_metric_name=False
    )
    assert ckpt_name == 'epoch=003-val_acc=0.03'
class ModelCheckpointExtensionTest(ModelCheckpoint):
    """ModelCheckpoint variant that writes its files with a custom ``.tpkc`` extension."""

    FILE_EXTENSION = '.tpkc'
def test_model_checkpoint_file_extension(tmpdir):
    """Every checkpoint, including ``last``, must use the subclass's ``FILE_EXTENSION``."""
    model = LogInTwoMethods()
    ext_ckpt = ModelCheckpointExtensionTest(
        monitor='early_stop_on',
        dirpath=tmpdir,
        save_top_k=1,
        save_last=True,
    )
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[ext_ckpt], max_steps=1, logger=False)
    trainer.fit(model)

    assert set(os.listdir(tmpdir)) == {'epoch=0-step=0.tpkc', 'last.tpkc'}
def test_model_checkpoint_save_last(tmpdir):
    """Tests that save_last produces only one last checkpoint, named from a templated CHECKPOINT_NAME_LAST."""
    seed_everything()
    model = LogInTwoMethods()
    epochs = 3
    # CHECKPOINT_NAME_LAST is a class attribute shared by all ModelCheckpoint instances. Remember the
    # current value and restore it in ``finally`` so a failing assertion cannot leak into other tests.
    original_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    try:
        model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=-1, save_last=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            callbacks=[model_checkpoint],
            max_epochs=epochs,
            limit_train_batches=10,
            limit_val_batches=10,
            logger=False,
        )
        trainer.fit(model)

        last_filename = model_checkpoint._format_checkpoint_name(
            ModelCheckpoint.CHECKPOINT_NAME_LAST, {'epoch': trainer.current_epoch}
        )
        last_filename = last_filename + '.ckpt'
        assert str(tmpdir / last_filename) == model_checkpoint.last_model_path

        # one checkpoint per epoch (10 steps each, so steps 9, 19, 29) plus exactly one "last" file
        expected = {f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])}
        expected.add(last_filename)
        assert set(os.listdir(tmpdir)) == expected
    finally:
        ModelCheckpoint.CHECKPOINT_NAME_LAST = original_name_last
def test_invalid_top_k(tmpdir):
    """A ``save_top_k`` below -1 must be rejected with a MisconfigurationException."""
    expected_msg = r'.*Must be None or >= -1'
    with pytest.raises(MisconfigurationException, match=expected_msg):
        ModelCheckpoint(dirpath=tmpdir, save_top_k=-3)
def test_none_monitor_top_k(tmpdir):
    """ Test that a MisconfigurationException is raised for a positive save_top_k with monitor=None. """
    with pytest.raises(
        MisconfigurationException, match=r'ModelCheckpoint\(save_top_k=3, monitor=None\) is not a valid'
    ):
        ModelCheckpoint(dirpath=tmpdir, save_top_k=3)

    # These should not fail
    ModelCheckpoint(dirpath=tmpdir, save_top_k=None)
    ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
    ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
def test_none_monitor_save_last(tmpdir):
    """``save_last=True`` without a monitor should emit a 'redundant' UserWarning."""
    with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'):
        ModelCheckpoint(dirpath=tmpdir, save_last=True)

    # Falsy save_last values must be accepted.
    for save_last in (None, False):
        ModelCheckpoint(dirpath=tmpdir, save_last=save_last)
def test_invalid_every_n_val_epochs(tmpdir):
    """A negative ``every_n_val_epochs`` is rejected; zero and positive values are accepted."""
    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
        ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)

    for valid_value in (0, 1, 2):
        ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=valid_value)
def test_invalid_every_n_train_steps(tmpdir):
    """ Make sure that a MisconfigurationException is raised for a negative every_n_train_steps argument. """
    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)

    # These should not fail
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2)
def test_invalid_trigger_combination(tmpdir):
    """
    Enabling more than one of every_n_val_epochs, every_n_train_steps and train_time_interval
    at once must raise a MisconfigurationException.
    """
    one_minute = timedelta(minutes=1)
    conflicting_kwargs = [
        dict(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2),
        dict(train_time_interval=one_minute, every_n_val_epochs=2),
        dict(train_time_interval=one_minute, every_n_train_steps=2),
    ]
    for kwargs in conflicting_kwargs:
        with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
            ModelCheckpoint(**kwargs)

    # These should not fail
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=one_minute)
def test_none_every_n_train_steps_val_epochs(tmpdir):
    """With no trigger arguments: period=1, every_n_val_epochs=1, step-based saving off (every_n_train_steps=0)."""
    ckpt_cb = ModelCheckpoint(dirpath=tmpdir)
    assert ckpt_cb.period == 1
    assert ckpt_cb._every_n_val_epochs == 1
    assert ckpt_cb._every_n_train_steps == 0
def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
    """With ``monitor=None``, every epoch can still be checkpointed alongside ``last.ckpt``."""
    seed_everything()
    n_epochs = 2
    model = LogInTwoMethods()
    ckpt_cb = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[ckpt_cb],
        limit_train_batches=10,
        limit_val_batches=10,
        max_epochs=n_epochs,
        logger=False,
    )
    with caplog.at_level(INFO):
        trainer.fit(model)
        assert "will duplicate the last checkpoint saved" in caplog.text

    # No monitor: best_model_path points at the final epoch's file and no scores are tracked.
    assert ckpt_cb.monitor is None
    assert ckpt_cb.best_model_path == tmpdir / 'epoch=1-step=19.ckpt'
    assert ckpt_cb.last_model_path == tmpdir / 'last.ckpt'
    assert ckpt_cb.best_model_score is None
    assert ckpt_cb.best_k_models == {}
    assert ckpt_cb.kth_best_model_path == ''

    # One checkpoint per epoch (10 steps each, so steps 9 and 19) plus last.ckpt.
    expected = {f'epoch={e}-step={s}.ckpt' for e, s in zip(range(n_epochs), [9, 19])}
    expected.add('last.ckpt')
    assert set(os.listdir(tmpdir)) == expected
@pytest.mark.parametrize("period", list(range(4)))
def test_model_checkpoint_period(tmpdir, period: int):
    """A checkpoint is written every ``period`` epochs; ``period=0`` disables saving."""
    n_epochs = 5
    model = LogInTwoMethods()
    ckpt_cb = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[ckpt_cb],
        max_epochs=n_epochs,
        limit_train_batches=1,
        limit_val_batches=1,
        logger=False,
    )
    trainer.fit(model)

    if period > 0:
        expected = {f'epoch={e}.ckpt' for e in range(n_epochs) if (e + 1) % period == 0}
    else:
        expected = set()
    assert set(os.listdir(tmpdir)) == expected
@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
    """A checkpoint is written after every ``every_n_val_epochs``-th epoch; 0 disables saving."""
    n_epochs = 5
    model = LogInTwoMethods()
    ckpt_cb = ModelCheckpoint(
        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[ckpt_cb],
        max_epochs=n_epochs,
        limit_train_batches=1,
        limit_val_batches=1,
        logger=False,
    )
    trainer.fit(model)

    if every_n_val_epochs > 0:
        expected = {f'epoch={e}.ckpt' for e in range(n_epochs) if (e + 1) % every_n_val_epochs == 0}
    else:
        expected = set()
    assert set(os.listdir(tmpdir)) == expected
@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
    """When both are given, ``period`` wins over ``every_n_val_epochs`` (backwards compatibility)."""
    n_epochs = 5
    model = LogInTwoMethods()
    # every_n_val_epochs is deliberately set to twice the period so the two settings disagree.
    ckpt_cb = ModelCheckpoint(
        dirpath=tmpdir,
        filename='{epoch}',
        save_top_k=-1,
        every_n_val_epochs=(2 * every_n_val_epochs),
        period=every_n_val_epochs,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[ckpt_cb],
        max_epochs=n_epochs,
        limit_train_batches=1,
        limit_val_batches=1,
        logger=False,
    )
    trainer.fit(model)

    if every_n_val_epochs > 0:
        expected = {f'epoch={e}.ckpt' for e in range(n_epochs) if (e + 1) % every_n_val_epochs == 0}
    else:
        expected = set()
    assert set(os.listdir(tmpdir)) == expected
def test_ckpt_every_n_train_steps(tmpdir):
    """ Tests that the checkpoints are saved every n training steps. """
    model = LogInTwoMethods()
    every_n_train_steps = 16
    max_epochs = 2
    # presumably the number of training batches per epoch of BoringModel's default dataloader — TODO confirm
    epoch_length = 64
    checkpoint_callback = ModelCheckpoint(
        filename="{step}",
        every_n_val_epochs=0,
        every_n_train_steps=every_n_train_steps,
        dirpath=tmpdir,
        save_top_k=-1,
        save_last=False,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        # use the same variable as the expectation below so the two cannot drift apart
        max_epochs=max_epochs,
        progress_bar_refresh_rate=0,
        callbacks=[checkpoint_callback],
        logger=False,
    )
    trainer.fit(model)

    # Step numbers in the filenames are 0-based, so the first save is step=n-1, then every n steps.
    expected = [
        f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps)
    ]
    assert set(os.listdir(tmpdir)) == set(expected)
@mock.patch("pytorch_lightning.callbacks.model_checkpoint.time")
def test_model_checkpoint_train_time_interval(mock_time, tmpdir) -> None:
    """Tests that the checkpoints are saved at the specified time interval."""
    seconds_per_batch = 7
    start_time = time.monotonic()
    batches_per_epoch = 64
    num_epochs = 2
    max_batches = batches_per_epoch * num_epochs + 1
    train_time_interval = timedelta(minutes=1)
    # The patched ``time.monotonic`` advances by ``seconds_per_batch`` on each successive call.
    mock_time.monotonic.side_effect = [start_time + seconds_per_batch * i for i in range(max_batches)]

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        min_epochs=num_epochs,
        max_epochs=num_epochs,
        progress_bar_refresh_rate=0,
        callbacks=[
            ModelCheckpoint(
                filename="{epoch}-{step}",
                dirpath=tmpdir,
                train_time_interval=train_time_interval,
                save_top_k=-1,
                save_last=False,
            )
        ],
        logger=False,
    )
    trainer.fit(model)

    # Each batch takes 7 sec and we checkpoint every minute. There are 64 batches per epoch,
    # so the total run time is 7 * 64 * 2 = 896 sec (~14.93 minutes). That is 14 full intervals,
    # so we should have 14 checkpoints.
    total_seconds = seconds_per_batch * batches_per_epoch * num_epochs
    expected_num_ckpts = int(total_seconds // train_time_interval.total_seconds())
    assert len(os.listdir(tmpdir)) == expected_num_ckpts
def test_model_checkpoint_topk_zero(tmpdir):
    """With ``save_top_k=0`` only ``last.ckpt`` is written and no best-model state is tracked."""
    model = LogInTwoMethods()
    ckpt_cb = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=True)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[ckpt_cb], max_epochs=2, logger=False)
    trainer.fit(model)

    # no monitor was given, so none of the "best" bookkeeping may be populated
    assert ckpt_cb.monitor is None
    assert ckpt_cb.best_model_path == ''
    assert ckpt_cb.best_model_score is None
    assert ckpt_cb.best_k_models == {}
    assert ckpt_cb.kth_best_model_path == ''

    # the only file on disk is the "last" checkpoint
    assert os.listdir(tmpdir) == ['last.ckpt']
    assert ckpt_cb.last_model_path == tmpdir / 'last.ckpt'
def test_model_checkpoint_topk_all(tmpdir):
    """``save_top_k=-1`` with a monitor keeps, and ranks, every epoch's checkpoint."""
    seed_everything(1000)
    n_epochs = 3
    model = BoringModel()
    ckpt_cb = ModelCheckpoint(dirpath=tmpdir, filename="{epoch}", monitor="epoch", mode='max', save_top_k=-1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[ckpt_cb],
        max_epochs=n_epochs,
        logger=False,
        val_check_interval=1.0,
    )
    trainer.fit(model)

    # Monitoring 'epoch' in max mode: the last epoch is the best and the first is the k-th best.
    assert ckpt_cb.monitor == 'epoch'
    assert ckpt_cb.best_model_path == tmpdir / "epoch=2.ckpt"
    assert ckpt_cb.best_model_score == n_epochs - 1
    assert len(os.listdir(tmpdir)) == len(ckpt_cb.best_k_models) == n_epochs
    assert set(ckpt_cb.best_k_models) == {str(tmpdir / f"epoch={e}.ckpt") for e in range(n_epochs)}
    assert ckpt_cb.kth_best_model_path == tmpdir / 'epoch=0.ckpt'
def test_ckpt_metric_names(tmpdir):
    """The metric referenced in ``filename`` must appear, with a formatted value, in the checkpoint name."""
    model = LogInTwoMethods()
    ckpt_cb = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename="{val_loss:.2f}")
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        callbacks=[ckpt_cb],
    )
    trainer.fit(model)

    # make sure the checkpoint we saved has the metric in the name
    matching = [name for name in os.listdir(tmpdir) if "val_loss" in name]
    assert len(matching) == 1
    # keep only digits and dots; a formatted number must leave more than 3 characters
    digits = re.sub("[^0-9.]", "", matching[0])
    assert len(digits) > 3
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_default_checkpoint_behavior(tmpdir):
    """Without an explicit ModelCheckpoint, the default callback leaves only the final epoch's file."""
    seed_everything(1234)
    model = LogInTwoMethods()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        progress_bar_refresh_rate=0,
        limit_train_batches=5,
        limit_val_batches=5,
    )
    trainer.fit(model)
    results = trainer.test()
    assert len(results) == 1

    # the checkpoint callback ran once per epoch
    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3

    ckpt_dir = os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints')
    assert os.listdir(ckpt_dir) == ['epoch=2-step=14.ckpt']
@pytest.mark.parametrize('max_epochs', [1, 2])
@pytest.mark.parametrize('should_validate', [True, False])
@pytest.mark.parametrize('save_last', [True, False])
@pytest.mark.parametrize('verbose', [True, False])
def test_model_checkpoint_save_last_warning(
    tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool
):
    """'Saving latest checkpoint...' is logged exactly once iff both ``verbose`` and ``save_last`` are set."""
    model = LogInTwoMethods()
    if not should_validate:
        model.validation_step = None
    ckpt = ModelCheckpoint(
        monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose
    )
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs)
    with caplog.at_level(logging.INFO):
        trainer.fit(model)

    expected_count = 1 if (verbose and save_last) else 0
    assert caplog.messages.count('Saving latest checkpoint...') == expected_count
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """``last.ckpt`` must hold the same state as the final epoch's checkpoint."""
    seed_everything(100)
    model = LogInTwoMethods()
    num_epochs = 3
    model_checkpoint = ModelCheckpoint(
        monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True
    )
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[model_checkpoint], max_epochs=num_epochs)
    trainer.fit(model)

    path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
    path_last = str(tmpdir / "last.ckpt")
    assert path_last == model_checkpoint.last_model_path
    assert os.path.isfile(path_last_epoch)

    # Compare the raw checkpoint dicts: progress counters and the callback's own state.
    ckpt_last_epoch = torch.load(path_last_epoch)
    ckpt_last = torch.load(path_last)
    for key in ("epoch", "global_step"):
        assert ckpt_last_epoch[key] == ckpt_last[key]
    cb_key = type(model_checkpoint)
    assert ckpt_last["callbacks"][cb_key] == ckpt_last_epoch["callbacks"][cb_key]

    # it is easier to load the model objects than to iterate over the raw dict of tensors
    model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
    model_last = LogInTwoMethods.load_from_checkpoint(model_checkpoint.last_model_path)
    for w_epoch, w_last in zip(model_last_epoch.parameters(), model_last.parameters()):
        assert w_epoch.eq(w_last).all()
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize('mode', ['min', 'max'])
def test_checkpointing_with_nan_as_first(tmpdir, mode: str):
    """Test that an initial NaN monitored value does not stop later improvements from being checkpointed.

    After the NaN, the monitored values improve monotonically for ``mode``, so with ``save_top_k=1``
    the last checkpoint recorded must come from the final epoch.
    """
    monitor = [float('nan')]
    monitor += [5, 7, 8] if mode == 'max' else [8, 7, 5]

    class CurrentModel(LogInTwoMethods):

        def validation_epoch_end(self, outputs):
            val_loss = monitor[self.current_epoch]
            self.log('abc', val_loss)

    model = CurrentModel()

    trainer = Trainer(
        callbacks=[ModelCheckpoint(monitor='abc', mode=mode, save_top_k=1, dirpath=tmpdir)],
        default_root_dir=tmpdir,
        val_check_interval=1.0,
        max_epochs=len(monitor),
    )
    trainer.fit(model)

    # check that last one is also the best one
    assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1
def test_checkpoint_repeated_strategy(tmpdir):
    """
    This test validates that the checkpoint can be called when provided to callbacks list
    """
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            self.log("val_loss", loss)

    model = ExtendedBoringModel()
    model.validation_epoch_end = None
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint_callback],
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

    # Resume from the same best checkpoint several times. Each run must add only a new logger
    # version directory and never a new checkpoint file in tmpdir.
    num_resumes = 4
    for _ in range(num_resumes):
        # load from checkpoint
        model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path)
        trainer = pl.Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=2,
            limit_val_batches=2,
            limit_test_batches=2,
            resume_from_checkpoint=checkpoint_callback.best_model_path,
            weights_summary=None,
            progress_bar_refresh_rate=0,
        )
        trainer.fit(model)
        trainer.test(model, verbose=False)

    assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'}
    assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(num_resumes)}
def test_checkpoint_repeated_strategy_extended(tmpdir):
    """
    This test validates that the checkpoint can be loaded several times without
    internally increasing its global step when nothing is run.

    A first ``fit`` writes one checkpoint per epoch into ``ckpt_dir``. Then, four
    times in a row, a fresh Trainer is built with ``resume_from_checkpoint``
    pointing at the latest checkpoint. ``test`` must leave the trainer counters at
    their initial values, because the checkpoint is only restored on ``fit``. The
    following ``fit``/``validate`` must not add checkpoint files or push the global
    step past ``epochs * limit_train_batches``.
    """

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

        def validation_epoch_end(self, *_):
            ...

    def assert_trainer_init(trainer):
        # A freshly constructed trainer: nothing trained yet, counters at zero.
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0

    def get_last_checkpoint(ckpt_dir):
        # NOTE(review): uses lexicographic order and assumes filenames sort chronologically
        # (holds for the single-digit epochs used here).
        last = ckpt_dir.listdir(sort=True)[-1]
        return str(last)

    def assert_checkpoint_content(ckpt_dir):
        chk = pl_load(get_last_checkpoint(ckpt_dir))
        assert chk["epoch"] == epochs
        # Derived from the config rather than hard-coded, so it tracks `epochs`/`limit_train_batches`.
        assert chk["global_step"] == epochs * limit_train_batches

    def assert_checkpoint_log_dir(idx):
        # One lightning_logs version per Trainer instance created so far (version_0 .. version_idx)...
        lightning_logs = tmpdir / 'lightning_logs'
        actual = [d.basename for d in lightning_logs.listdir(sort=True)]
        assert actual == [f'version_{i}' for i in range(idx + 1)]
        # ...and still exactly one checkpoint file per epoch in ckpt_dir.
        assert len(ckpt_dir.listdir()) == epochs

    ckpt_dir = tmpdir / 'checkpoints'
    checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
    epochs = 2
    limit_train_batches = 2

    trainer_config = dict(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        limit_train_batches=limit_train_batches,
        limit_val_batches=3,
        limit_test_batches=4,
        callbacks=[checkpoint_cb],
    )

    trainer = pl.Trainer(**trainer_config)
    assert_trainer_init(trainer)

    model = ExtendedBoringModel()
    trainer.fit(model)
    assert trainer.checkpoint_connector.has_trained
    assert trainer.global_step == epochs * limit_train_batches
    assert trainer.current_epoch == epochs - 1
    assert_checkpoint_log_dir(0)
    assert_checkpoint_content(ckpt_dir)

    trainer.validate(model)
    assert trainer.current_epoch == epochs - 1

    trainer.test(model)
    assert trainer.current_epoch == epochs - 1

    for idx in range(1, 5):
        ckpt_path = get_last_checkpoint(ckpt_dir)
        assert_checkpoint_content(ckpt_dir)

        # load from checkpoint
        trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
        trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=ckpt_path)
        assert_trainer_init(trainer)

        model = ExtendedBoringModel()

        trainer.test(model)
        assert not trainer.checkpoint_connector.has_trained
        # resume_from_checkpoint is resumed when calling `.fit`
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0

        trainer.fit(model)
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs
        assert_checkpoint_log_dir(idx)

        trainer.validate(model)
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs
def test_configure_model_checkpoint(tmpdir):
    """Cover every valid and invalid pairing of ``checkpoint_callback`` with ``callbacks`` on the Trainer."""
    common_kwargs = dict(default_root_dir=tmpdir)
    first_cb = ModelCheckpoint()
    second_cb = ModelCheckpoint()

    def registered_checkpoints(trainer):
        # Every ModelCheckpoint instance the trainer ended up with, in registration order.
        return [cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)]

    # Checkpointing disabled and nothing passed: no ModelCheckpoint at all.
    trainer = Trainer(checkpoint_callback=False, callbacks=[], **common_kwargs)
    assert not registered_checkpoints(trainer)
    assert trainer.checkpoint_callback is None

    # Checkpointing enabled with no explicit callback: a single default one is created.
    trainer = Trainer(checkpoint_callback=True, callbacks=[], **common_kwargs)
    assert len(registered_checkpoints(trainer)) == 1
    assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)

    # An explicit callback in the list takes the place of the default one.
    trainer = Trainer(checkpoint_callback=True, callbacks=[first_cb], **common_kwargs)
    assert registered_checkpoints(trainer) == [first_cb]
    assert trainer.checkpoint_callback == first_cb

    # Several checkpoint callbacks: the first is the primary, all are listed.
    trainer = Trainer(callbacks=[first_cb, second_cb], **common_kwargs)
    assert trainer.checkpoint_callback == first_cb
    assert trainer.checkpoint_callbacks == [first_cb, second_cb]

    # Disabling checkpointing while passing a ModelCheckpoint is contradictory.
    with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
        Trainer(checkpoint_callback=False, callbacks=[first_cb], **common_kwargs)
def test_val_check_interval_checkpoint_files(tmpdir):
    """Checkpoint names include the step when validation/checkpointing happens several times per epoch."""
    model = LogInTwoMethods()
    checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="val_acc", mode="max")
    trainer = Trainer(
        default_root_dir=tmpdir,
        val_check_interval=0.2,
        max_epochs=1,
        limit_train_batches=10,
        callbacks=[checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model)

    # 10 train batches with val_check_interval=0.2 -> one checkpoint every 2 batches.
    expected_names = {f"epoch=0-step={step}.ckpt" for step in (1, 3, 5, 7, 9)}
    written_names = {entry.basename for entry in tmpdir.listdir()}
    assert written_names == expected_names
def test_current_score(tmpdir):
    """``current_score`` holds the latest monitored value, and each saved checkpoint stores its own score."""

    class TestModel(BoringModel):

        def training_step(self, *args):
            # Logs 0.1, 0.2, 0.3 in epochs 0, 1, 2.
            self.log("foo", (self.current_epoch + 1) / 10)
            return super().training_step(*args)

    model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=3, monitor="foo", mode="min")
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[model_checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(TestModel())
    assert model_checkpoint.current_score == 0.3

    # Pull the callback's saved state out of every checkpoint file in tmpdir.
    callback_key = type(model_checkpoint)
    saved_scores = []
    for ckpt_file in tmpdir.listdir():
        callback_state = torch.load(str(ckpt_file))["callbacks"][callback_key]
        saved_scores.append(callback_state["current_score"])
    assert sorted(saved_scores) == [0.1, 0.2, 0.3]
@pytest.mark.parametrize("mode", ["min", "max"])
def test_current_score_when_nan(tmpdir, mode: str):
    """A NaN monitored value is mapped to the worst possible score for the given ``mode``."""

    class TestModel(BoringModel):

        def training_step(self, *args):
            # Always log NaN for the monitored key.
            self.log("foo", float("nan"))
            return super().training_step(*args)

    model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo", mode=mode)
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[model_checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(TestModel())

    # "min" treats NaN as +inf, "max" as -inf.
    if mode == "min":
        worst = float("inf")
    else:
        worst = float("-inf")
    assert model_checkpoint.best_model_score == worst
    assert model_checkpoint.current_score == worst
@pytest.mark.parametrize("hparams_type", [dict, Container])
def test_hparams_type(tmpdir, hparams_type):
    """Hyperparameters are dumped into the checkpoint with the expected container type.

    OmegaConf inputs must stay an OmegaConf ``Container``. ``Namespace`` inputs must
    come out as a plain ``dict``.
    """

    class TestModel(BoringModel):

        def __init__(self, hparams):
            super().__init__()
            self.save_hyperparameters(hparams)

    model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo")
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[model_checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )

    raw_hparams = {"test_hp_0": 1, "test_hp_1": 2}
    use_omegaconf = hparams_type == Container
    if use_omegaconf:
        hparams = OmegaConf.create(raw_hparams)
    else:
        hparams = Namespace(**raw_hparams)

    model = TestModel(hparams)
    trainer.fit(model)

    saved_hparams = trainer.checkpoint_connector.dump_checkpoint()[model.CHECKPOINT_HYPER_PARAMS_KEY]
    if use_omegaconf:
        assert isinstance(saved_hparams, hparams_type)
    else:
        # Exact type on purpose: a dict subclass such as AttributeDict must not pass.
        assert type(saved_hparams) is hparams_type
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
    """
    Check that re-running into the same ``dirpath`` with a new Trainer instance keeps the
    previous checkpoints and saves the new ones with a ``-v1`` version suffix.
    """
    epochs = 2
    for run in range(epochs):
        mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}")
        trainer = Trainer(
            max_epochs=epochs,
            limit_train_batches=1,
            limit_val_batches=1,
            default_root_dir=tmpdir,
            callbacks=[mc],
            logger=False,
            weights_summary=None,
            progress_bar_refresh_rate=0,
        )
        trainer.fit(BoringModel())

        # best_k_models only tracks the files written by this run.
        if run:
            expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"}
        else:
            expected = {"epoch=0.ckpt", "epoch=1.ckpt"}
        tracked_names = {Path(ckpt_path).name for ckpt_path in mc.best_k_models.keys()}
        assert tracked_names == expected

    # Both runs' checkpoints remain on disk.
    on_disk = {entry.basename for entry in tmpdir.listdir()}
    assert on_disk == {
        "epoch=0.ckpt",
        "epoch=1.ckpt",
        "epoch=0-v1.ckpt",
        "epoch=1-v1.ckpt",
    }
def test_ckpt_version_after_rerun_same_trainer(tmpdir):
    """
    Check that fitting the same Trainer instance twice with a fixed ``filename`` gives
    the extra checkpoints version suffixes counting up from ``STARTING_VERSION``.
    """
    mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test")
    mc.STARTING_VERSION = 9
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=1,
        default_root_dir=tmpdir,
        callbacks=[mc],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(BoringModel())
    # Extend training on the same trainer and fit again.
    trainer.train_loop.max_epochs = 4
    trainer.fit(BoringModel())

    first_version = mc.STARTING_VERSION
    last_version = trainer.max_epochs + mc.STARTING_VERSION
    versioned = [f"test-v{v}.ckpt" for v in range(first_version, last_version)]
    expected = {'test.ckpt', *versioned}

    # check best_k_models state
    assert {Path(ckpt_path).name for ckpt_path in mc.best_k_models.keys()} == expected
    # check created ckpts
    assert set(os.listdir(tmpdir)) == expected
def test_model_checkpoint_mode_options():
    """An unsupported ``mode`` string is rejected when the callback is constructed."""
    invalid_mode = "unknown_option"
    with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"):
        ModelCheckpoint(mode=invalid_mode)
def test_trainer_checkpoint_callback_bool(tmpdir):
    """Passing a ModelCheckpoint instance as the Trainer's ``checkpoint_callback`` flag is rejected."""
    checkpoint = ModelCheckpoint(dirpath=tmpdir)
    expected_error = "Invalid type provided for checkpoint_callback"
    with pytest.raises(MisconfigurationException, match=expected_error):
        Trainer(checkpoint_callback=checkpoint)