From 9eded7fd7323ef474923f14b9a016f4a0f0d8ca3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=81ukasz=20Zalewski?=
Date: Tue, 9 Mar 2021 00:24:29 +0100
Subject: [PATCH] Add check for verbose attribute of ModelCheckpoint (#6419)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/trainer/training_loop.py   |  2 +-
 tests/checkpointing/test_model_checkpoint.py | 11 ++++++++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 7ad035e1f5..88b87afcb9 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -154,7 +154,7 @@ class TrainLoop:
         if should_update and self.trainer.checkpoint_connector.has_trained:
             callbacks = self.trainer.checkpoint_callbacks
 
-            if is_last and any(cb.save_last for cb in callbacks):
+            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
                 rank_zero_info("Saving latest checkpoint...")
 
             model = self.trainer.lightning_module
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 3d5cddc453..1b33123d6d 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -672,19 +672,24 @@ def test_default_checkpoint_behavior(tmpdir):
 @pytest.mark.parametrize('max_epochs', [1, 2])
 @pytest.mark.parametrize('should_validate', [True, False])
 @pytest.mark.parametrize('save_last', [True, False])
-def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_validate, save_last):
+@pytest.mark.parametrize('verbose', [True, False])
+def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_validate, save_last, verbose):
     """Tests 'Saving latest checkpoint...' log"""
     model = LogInTwoMethods()
     if not should_validate:
         model.validation_step = None
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last)],
+        callbacks=[
+            ModelCheckpoint(
+                monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose
+            )
+        ],
         max_epochs=max_epochs,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    assert caplog.messages.count('Saving latest checkpoint...') == save_last
+    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
 
 
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):