# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import platform
import re
from argparse import Namespace
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import cloudpickle
import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf

import pytorch_lightning as pl
import tests.base.develop_utils as tutils
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel


class LogInTwoMethods(BoringModel):
    """BoringModel that logs ``early_stop_on`` in the training step and ``epoch``/``val_acc`` at validation end."""

    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        self.log('early_stop_on', out['loss'])
        return out

    def validation_epoch_end(self, outputs):
        outs = torch.stack([x['x'] for x in outputs]).mean()
        self.log('epoch', self.current_epoch, on_epoch=True)
        self.log('val_acc', outs, on_epoch=True)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize('save_top_k', [-1])
def test_model_checkpoint_correct_score(tmpdir, save_top_k):
    """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path"""
    tutils.reset_seed()

    model = LogInTwoMethods()

    filename = "{val_acc:.4f}-{epoch}"

    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor='val_acc', save_top_k=save_top_k)

    trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2)
    trainer.fit(model)

    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))

    # every checkpoint on disk must be named after one of the logged (val_acc, epoch) pairs
    metrics = trainer.dev_debugger.logged_metrics
    expected_filenames = {f'val_acc={metric["val_acc"]:.4f}-epoch={metric["epoch"]}.ckpt' for metric in metrics}
    for ckpt_file in ckpt_files:
        assert os.path.basename(ckpt_file) in expected_filenames


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """Test that dirpath=None in checkpoint callback is valid and that ckpt_path is set correctly"""
    tutils.reset_seed()
    model = LogInTwoMethods()

    checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}', save_top_k=save_top_k)
    max_epochs = 2
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint],
        overfit_batches=0.20,
        max_epochs=max_epochs,
    )
    trainer.fit(model)
    assert (
        checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints"
    )

    if save_top_k == -1:
        ckpt_files = os.listdir(checkpoint.dirpath)
        expected_ckpt_files = [f'epoch={i}.ckpt' for i in range(max_epochs)]
        assert len(ckpt_files) == len(expected_ckpt_files) == max_epochs
        assert set(ckpt_files) == set(expected_ckpt_files)


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
    """Test that ``to_yaml`` writes ``best_k_models`` to a file that loads back to an equal mapping"""
    tutils.reset_seed()
    model = LogInTwoMethods()

    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=save_top_k)

    trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2)
    trainer.fit(model)

    path_yaml = os.path.join(tmpdir, 'best_k_models.yaml')
    checkpoint.to_yaml(path_yaml)
    # context manager so the handle is closed even if loading raises
    with open(path_yaml, 'r') as yaml_file:
        d = yaml.full_load(yaml_file)
    best_k = dict(checkpoint.best_k_models)
    assert d == best_k


@pytest.mark.parametrize(
    "logger_version,expected",
    [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
    tutils.reset_seed()
    model = LogInTwoMethods()
    logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_batches=0.2,
        max_epochs=2,
        logger=logger
    )
    trainer.fit(model)

    ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
    assert ckpt_version == expected


def test_pickling(tmpdir):
    """Test that a ModelCheckpoint round-trips through pickle and cloudpickle with identical attributes"""
    ckpt = ModelCheckpoint(dirpath=tmpdir)

    ckpt_pickled = pickle.dumps(ckpt)
    ckpt_loaded = pickle.loads(ckpt_pickled)
    assert vars(ckpt) == vars(ckpt_loaded)

    ckpt_pickled = cloudpickle.dumps(ckpt)
    ckpt_loaded = cloudpickle.loads(ckpt_pickled)
    assert vars(ckpt) == vars(ckpt_loaded)


class ModelCheckpointTestInvocations(ModelCheckpoint):
    """ModelCheckpoint that counts ``on_save_checkpoint`` calls and checks the totals in ``on_train_end``."""
    # this class has to be defined outside the test function, otherwise we get pickle error
    # due to the way ddp process is launched

    def __init__(self, expected_count, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_count = expected_count
        self.on_save_checkpoint_count = 0

    def on_train_start(self, trainer, pl_module):
        # NOTE(review): this rebinds `torch.save` globally and never restores it -- presumably acceptable
        # because the hook runs inside spawned ddp worker processes; confirm before reusing elsewhere
        torch.save = Mock(wraps=torch.save)

    def on_save_checkpoint(self, trainer, pl_module):
        # expect all ranks to run but only rank 0 will actually write the checkpoint file
        super().on_save_checkpoint(trainer, pl_module)
        self.on_save_checkpoint_count += 1

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        assert self.best_model_path
        assert self.best_model_score
        assert self.on_save_checkpoint_count == self.expected_count
        if trainer.is_global_zero:
            # twice the calls expected because ddp broadcast also uses torch.save
            assert torch.save.call_count == self.expected_count * 2
        else:
            assert torch.save.call_count == 0


@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="Distributed training is not supported on Windows",
)
def test_model_checkpoint_no_extraneous_invocations(tmpdir):
    """Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
    model = LogInTwoMethods()
    num_epochs = 4
    model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1)
    trainer = Trainer(
        accelerator="ddp_cpu",
        num_processes=2,
        default_root_dir=tmpdir,
        callbacks=[model_checkpoint],
        max_epochs=num_epochs,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Test the checkpoint names produced for various filename templates, prefixes, dirpaths and versions"""
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
    assert ckpt_name == 'epoch=3-step=2'

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-epoch=3-step=2'

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-ckpt'

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
    assert ckpt_name == 'epoch=003-acc=0.03'

    # prefix
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test')
        assert ckpt_name == 'test@epoch=3,acc=0.03000'
    finally:
        # restore the class attribute even when the assertion fails, so later tests are unaffected
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {})
    assert ckpt_name == 'epoch=3-step=2.ckpt'
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {})
    assert ckpt_name == 'epoch=5-step=4.ckpt'

    # CWD
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

    # with version
    ckpt_name = ModelCheckpoint(
        monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test'
    ).format_checkpoint_name(3, 2, {}, ver=3)
    assert ckpt_name == tmpdir / 'test-name-v3.ckpt'

    # using slashes
    ckpt_name = ModelCheckpoint(
        monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}'
    ).format_checkpoint_name(4, 3, {'val/loss': 0.03})
    assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'


class ModelCheckpointExtensionTest(ModelCheckpoint):
    """ModelCheckpoint subclass that overrides the checkpoint file extension."""
    FILE_EXTENSION = '.tpkc'


def test_model_checkpoint_file_extension(tmpdir):
    """
    Test ModelCheckpoint with different file extension.
    """

    model = LogInTwoMethods()
    model_checkpoint = ModelCheckpointExtensionTest(
        monitor='early_stop_on',
        dirpath=tmpdir,
        save_top_k=1,
        save_last=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[model_checkpoint],
        max_steps=1,
        logger=False,
    )
    trainer.fit(model)

    expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
    assert set(expected) == set(os.listdir(tmpdir))


def test_model_checkpoint_save_last(tmpdir):
    """Tests that save_last produces only one last checkpoint."""
    seed_everything()
    model = LogInTwoMethods()
    epochs = 3
    name_last_org = ModelCheckpoint.CHECKPOINT_NAME_LAST
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    try:
        model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=-1, save_last=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            callbacks=[model_checkpoint],
            max_epochs=epochs,
            limit_train_batches=10,
            limit_val_batches=10,
            logger=False,
        )
        trainer.fit(model)
        last_filename = model_checkpoint._format_checkpoint_name(
            ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {}
        )
        last_filename = last_filename + '.ckpt'
        assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
        assert set(os.listdir(tmpdir)) == set(
            [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename]
        )
    finally:
        # restore the class attribute even when the test fails, so later tests are unaffected
        ModelCheckpoint.CHECKPOINT_NAME_LAST = name_last_org


def test_invalid_top_k(tmpdir):
    """ Make sure that a MisconfigurationException is raised for a negative save_top_k argument. """
    with pytest.raises(MisconfigurationException, match=r'.*Must be None or >= -1'):
        ModelCheckpoint(dirpath=tmpdir, save_top_k=-3)


def test_none_monitor_top_k(tmpdir):
    """ Test that a MisconfigurationException is raised for positive top_k with monitor=None. """
    with pytest.raises(
        MisconfigurationException, match=r'ModelCheckpoint\(save_top_k=3, monitor=None\) is not a valid*'
    ):
        ModelCheckpoint(dirpath=tmpdir, save_top_k=3)
    # These should not fail
    ModelCheckpoint(dirpath=tmpdir, save_top_k=None)
    ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
    ModelCheckpoint(dirpath=tmpdir, save_top_k=0)


def test_none_monitor_save_last(tmpdir):
    """ Test that a warning appears for save_last=True with monitor=None. """
    with pytest.warns(
        UserWarning, match=r'ModelCheckpoint\(save_last=True, monitor=None\) is a redundant.*'
    ):
        ModelCheckpoint(dirpath=tmpdir, save_last=True)
    # These should not fail
    ModelCheckpoint(dirpath=tmpdir, save_last=None)
    ModelCheckpoint(dirpath=tmpdir, save_last=False)


def test_model_checkpoint_none_monitor(tmpdir):
    """ Test that it is possible to save all checkpoints when monitor=None. """
    seed_everything()
    model = LogInTwoMethods()

    epochs = 2
    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_callback],
        limit_train_batches=10,
        limit_val_batches=10,
        max_epochs=epochs,
        logger=False,
    )
    trainer.fit(model)

    # these should not be set if monitor is None
    assert checkpoint_callback.monitor is None
    expected_path = tmpdir / 'epoch=1-step=19.ckpt'
    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == expected_path
    assert checkpoint_callback.best_model_score is None
    assert checkpoint_callback.best_k_models == {}
    assert checkpoint_callback.kth_best_model_path == ''

    # check that the correct ckpts were created
    expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])]
    assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("period", list(range(4)))
def test_model_checkpoint_period(tmpdir, period):
    """Test that with ``period=N`` a checkpoint is written every N-th epoch, and none when ``period=0``"""
    model = LogInTwoMethods()
    epochs = 5
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_callback],
        max_epochs=epochs,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        val_check_interval=1.0,
        logger=False,
    )
    trainer.fit(model)

    # check that the correct ckpts were created
    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
    assert set(os.listdir(tmpdir)) == set(expected)


def test_model_checkpoint_topk_zero(tmpdir):
    """ Test that no checkpoints are saved when save_top_k=0. """
    model = LogInTwoMethods()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_callback],
        max_epochs=2,
        logger=False,
    )
    trainer.fit(model)
    # these should not be set if monitor is None
    assert checkpoint_callback.monitor is None
    assert checkpoint_callback.best_model_path == ''
    assert checkpoint_callback.best_model_score is None
    assert checkpoint_callback.best_k_models == {}
    assert checkpoint_callback.kth_best_model_path == ''
    # check that no ckpts were created
    assert len(os.listdir(tmpdir)) == 0


def test_model_checkpoint_topk_all(tmpdir):
    """ Test that save_top_k=-1 tracks the best models when monitor key is provided. """
    seed_everything(1000)
    epochs = 3

    class CustomModel(LogInTwoMethods):
        def validation_epoch_end(self, outputs):
            return {'epoch': self.current_epoch}

    model = CustomModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmpdir,
        filename="{epoch}",
        monitor="epoch",
        mode='max',
        save_top_k=-1,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_callback],
        max_epochs=epochs,
        logger=False,
        val_check_interval=1.0,
    )
    trainer.fit(model)

    assert checkpoint_callback.monitor == 'epoch'
    assert checkpoint_callback.best_model_path == tmpdir / "epoch=2.ckpt"
    assert checkpoint_callback.best_model_score == epochs - 1
    assert len(os.listdir(tmpdir)) == len(checkpoint_callback.best_k_models) == epochs
    expected_paths = {str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs)}
    assert set(checkpoint_callback.best_k_models) == expected_paths
    assert checkpoint_callback.kth_best_model_path == tmpdir / 'epoch=0.ckpt'


def test_ckpt_metric_names(tmpdir):
    """Test that a metric referenced in ``filename`` shows up in the saved checkpoint's name"""
    model = LogInTwoMethods()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename="{val_loss:.2f}")],
    )

    trainer.fit(model)

    # make sure the checkpoint we saved has the metric in the name
    ckpts = os.listdir(tmpdir)
    ckpts = [x for x in ckpts if "val_loss" in x]
    assert len(ckpts) == 1
    # keep only digits and dots from the file name and require more than three of them
    val = re.sub("[^0-9.]", "", ckpts[0])
    assert len(val) > 3


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_default_checkpoint_behavior(tmpdir):
    """Test that a Trainer without explicit callbacks keeps a single checkpoint in the logger's directory"""
    seed_everything(1234)

    # PL_DEV_DEBUG is already set for the duration of the test by the `mock.patch.dict` decorator
    model = LogInTwoMethods()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        progress_bar_refresh_rate=0,
        limit_train_batches=5,
        limit_val_batches=5,
    )

    trainer.fit(model)
    results = trainer.test()

    assert len(results) == 1
    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3

    # make sure only the checkpoint of the final epoch is left on disk
    ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
    assert len(ckpts) == 1
    assert ckpts[0] == 'epoch=2-step=14.ckpt'


@pytest.mark.parametrize('max_epochs', [1, 2])
@pytest.mark.parametrize('should_validate', [True, False])
@pytest.mark.parametrize('save_last', [True, False])
def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_validate, save_last):
    """Tests 'Saving latest checkpoint...' log"""
    model = LogInTwoMethods()
    if not should_validate:
        model.validation_step = None
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last)],
        max_epochs=max_epochs,
    )
    trainer.fit(model)
    # the message is expected exactly once when save_last is True and never otherwise
    assert caplog.messages.count('Saving latest checkpoint...') == save_last


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """ Tests that the save_last checkpoint contains the latest information. """
    seed_everything(100)
    model = LogInTwoMethods()
    num_epochs = 3
    model_checkpoint = ModelCheckpoint(
        monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[model_checkpoint],
        max_epochs=num_epochs,
    )
    trainer.fit(model)

    path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
    path_last = str(tmpdir / "last.ckpt")
    assert path_last == model_checkpoint.last_model_path
    assert os.path.isfile(path_last_epoch)

    ckpt_last_epoch = torch.load(path_last_epoch)
    ckpt_last = torch.load(path_last)
    assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

    ch_type = type(model_checkpoint)
    assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]

    # it is easier to load the model objects than to iterate over the raw dict of tensors
    model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
    model_last = LogInTwoMethods.load_from_checkpoint(
        model_checkpoint.last_model_path
    )
    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
        assert w0.eq(w1).all()


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize('mode', ['min', 'max'])
def test_checkpointing_with_nan_as_first(tmpdir, mode):
    """Test that checkpointing keeps working when the first monitored value is NaN and later ones improve"""
    monitor = [float('nan')]
    monitor += [5, 7, 8] if mode == 'max' else [8, 7, 5]

    class CurrentModel(LogInTwoMethods):
        def validation_epoch_end(self, outputs):
            val_loss = monitor[self.current_epoch]
            self.log('abc', val_loss)

    model = CurrentModel()

    trainer = Trainer(
        callbacks=[ModelCheckpoint(monitor='abc', mode=mode, save_top_k=1, dirpath=tmpdir)],
        default_root_dir=tmpdir,
        val_check_interval=1.0,
        max_epochs=len(monitor),
    )
    trainer.fit(model)

    # check that last one is also the best one
    assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_checkpoint_repeated_strategy(tmpdir):
    """
    This test validates that the checkpoint can be called when provided to callbacks list
    """
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    model.validation_epoch_end = None
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint_callback],
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

    for _ in range(4):
        # load from checkpoint
        model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path)
        trainer = pl.Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=2,
            limit_val_batches=2,
            limit_test_batches=2,
            resume_from_checkpoint=checkpoint_callback.best_model_path,
            weights_summary=None,
            progress_bar_refresh_rate=0,
        )
        trainer.fit(model)
        trainer.test(model, verbose=False)
    assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'}
    assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)}


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_checkpoint_repeated_strategy_extended(tmpdir):
    """
    This test validates checkpoint can be called several times without
    increasing internally its global step if nothing run.
    """

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

        def validation_epoch_end(self, *_):
            ...

    # the helpers below close over `epochs` and `ckpt_dir`, which are bound further down before first use

    def assert_trainer_init(trainer):
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0

    def get_last_checkpoint(ckpt_dir):
        last = ckpt_dir.listdir(sort=True)[-1]
        return str(last)

    def assert_checkpoint_content(ckpt_dir):
        chk = pl_load(get_last_checkpoint(ckpt_dir))
        assert chk["epoch"] == epochs
        assert chk["global_step"] == 4

    def assert_checkpoint_log_dir(idx):
        lightning_logs = tmpdir / 'lightning_logs'
        actual = [d.basename for d in lightning_logs.listdir(sort=True)]
        assert actual == [f'version_{i}' for i in range(idx + 1)]
        assert len(ckpt_dir.listdir()) == epochs

    ckpt_dir = tmpdir / 'checkpoints'
    checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
    epochs = 2
    limit_train_batches = 2
    trainer_config = dict(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        limit_train_batches=limit_train_batches,
        limit_val_batches=3,
        limit_test_batches=4,
        callbacks=[checkpoint_cb],
    )
    trainer = pl.Trainer(**trainer_config)
    assert_trainer_init(trainer)

    model = ExtendedBoringModel()
    trainer.fit(model)
    assert trainer.checkpoint_connector.has_trained
    assert trainer.global_step == epochs * limit_train_batches
    assert trainer.current_epoch == epochs - 1
    assert_checkpoint_log_dir(0)
    assert_checkpoint_content(ckpt_dir)

    trainer.test(model)
    assert trainer.current_epoch == epochs - 1

    for idx in range(1, 5):
        chk = get_last_checkpoint(ckpt_dir)
        assert_checkpoint_content(ckpt_dir)

        # load from checkpoint
        trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
        trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk)
        assert_trainer_init(trainer)

        model = ExtendedBoringModel()
        trainer.test(model)
        assert not trainer.checkpoint_connector.has_trained
        # resume_from_checkpoint is resumed when calling `.fit`
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0
        trainer.fit(model)
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs
        assert_checkpoint_log_dir(idx)


def test_configure_model_checkpoint(tmpdir):
    """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """
    kwargs = dict(default_root_dir=tmpdir)
    callback1 = ModelCheckpoint()
    callback2 = ModelCheckpoint()

    # no callbacks
    trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs)
    assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks)
    assert trainer.checkpoint_callback is None

    # default configuration
    trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs)
    assert len([c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)]) == 1
    assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)

    # custom callback passed to callbacks list, checkpoint_callback=True is ignored
    trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs)
    assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
    assert trainer.checkpoint_callback == callback1

    # multiple checkpoint callbacks
    trainer = Trainer(callbacks=[callback1, callback2], **kwargs)
    assert trainer.checkpoint_callback == callback1
    assert trainer.checkpoint_callbacks == [callback1, callback2]

    # passing a callback instance through `checkpoint_callback` is deprecated
    with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.3'):
        trainer = Trainer(checkpoint_callback=callback1, **kwargs)
        assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
        assert trainer.checkpoint_callback == callback1

    with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.3"):
        trainer = Trainer(checkpoint_callback=callback1, callbacks=[callback2], **kwargs)
        assert trainer.checkpoint_callback == callback2
        assert trainer.checkpoint_callbacks == [callback2, callback1]

    with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
        Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)


def test_val_check_interval_checkpoint_files(tmpdir):
    """ Test correct checkpoint naming when validating/checkpointing multiple times per epoch. """
    model = LogInTwoMethods()
    model_checkpoint = ModelCheckpoint(
        dirpath=tmpdir,
        save_top_k=-1,
        monitor="val_acc",
        mode="max",
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        val_check_interval=0.2,
        max_epochs=1,
        limit_train_batches=10,
        callbacks=[model_checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model)
    files = {p.basename for p in tmpdir.listdir()}
    assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]}


def test_current_score(tmpdir):
    """ Check that the current_score value is correct and was saved """

    class TestModel(BoringModel):
        def training_step(self, *args):
            self.log("foo", (self.current_epoch + 1) / 10)
            return super().training_step(*args)

    model_checkpoint = ModelCheckpoint(
        dirpath=tmpdir,
        save_top_k=3,
        monitor="foo",
        mode="min",
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[model_checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(TestModel())
    assert model_checkpoint.current_score == 0.3
    ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
    ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
    assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]


@pytest.mark.parametrize("mode", ["min", "max"])
def test_current_score_when_nan(tmpdir, mode):
    """ Check that ModelCheckpoint handles NaN values correctly """

    class TestModel(BoringModel):
        def training_step(self, *args):
            self.log("foo", float("nan"))
            return super().training_step(*args)

    model_checkpoint = ModelCheckpoint(
        dirpath=tmpdir,
        save_top_k=1,
        monitor="foo",
        mode=mode,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[model_checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(TestModel())
    # a NaN score is expected to be reported as the worst possible value for the given mode
    expected = float("inf" if mode == "min" else "-inf")
    assert model_checkpoint.best_model_score == expected
    assert model_checkpoint.current_score == expected


@pytest.mark.parametrize("hparams_type", [dict, Container])
def test_hparams_type(tmpdir, hparams_type):
    """Test the type of the hyper-parameters stored in a dumped checkpoint for Namespace and OmegaConf inputs"""

    class TestModel(BoringModel):
        def __init__(self, hparams):
            super().__init__()
            self.save_hyperparameters(hparams)

    model_checkpoint = ModelCheckpoint(
        dirpath=tmpdir,
        save_top_k=1,
        monitor="foo",
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[model_checkpoint],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    hp = {"test_hp_0": 1, "test_hp_1": 2}
    hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
    model = TestModel(hp)
    trainer.fit(model)
    ckpt = trainer.checkpoint_connector.dump_checkpoint()
    if hparams_type == Container:
        assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
    else:
        # make sure it's not AttributeDict: an exact type check is intended here, not isinstance
        assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is hparams_type


def test_ckpt_version_after_rerun_new_trainer(tmpdir):
    """
    Check that previous checkpoints are renamed to have the correct
    version suffix when new trainer instances are used
    """
    epochs = 2
    for i in range(epochs):
        mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}")
        trainer = Trainer(
            max_epochs=epochs,
            limit_train_batches=1,
            limit_val_batches=1,
            default_root_dir=tmpdir,
            callbacks=[mc],
            logger=False,
            weights_summary=None,
            progress_bar_refresh_rate=0,
        )
        trainer.fit(BoringModel())

        # check best_k_models state
        expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} if i else {"epoch=0.ckpt", "epoch=1.ckpt"}
        assert {Path(f).name for f in mc.best_k_models} == expected

    # check created ckpts
    assert {f.basename for f in tmpdir.listdir()} == {
        "epoch=0.ckpt",
        "epoch=1.ckpt",
        "epoch=0-v1.ckpt",
        "epoch=1-v1.ckpt",
    }


def test_ckpt_version_after_rerun_same_trainer(tmpdir):
    """
    Check that previous checkpoints are renamed to have the correct
    version suffix when the same trainer instance is used
    """
    mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test")
    # set on the instance only, so the class-level value is left untouched
    mc.STARTING_VERSION = 9
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=1,
        default_root_dir=tmpdir,
        callbacks=[mc],
        logger=False,
        weights_summary=None,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(BoringModel())
    trainer.max_epochs = 4
    trainer.fit(BoringModel())

    ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION)
    expected = {'test.ckpt', *[f"test-v{i}.ckpt" for i in ckpt_range]}
    # check best_k_models state
    assert {Path(f).name for f in mc.best_k_models} == expected
    # check created ckpts
    assert set(os.listdir(tmpdir)) == expected


def test_model_checkpoint_mode_options():
    """Test that an unknown ``mode`` value raises a MisconfigurationException"""
    with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
        ModelCheckpoint(mode="unknown_option")