Add check for verbose attribute of ModelCheckpoint (#6419)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Authored by Łukasz Zalewski on 2021-03-09 00:24:29 +01:00, committed by GitHub
parent e1f5eacab9
commit 9eded7fd73
2 changed files with 9 additions and 4 deletions

@@ -154,7 +154,7 @@ class TrainLoop:
         if should_update and self.trainer.checkpoint_connector.has_trained:
             callbacks = self.trainer.checkpoint_callbacks
 
-            if is_last and any(cb.save_last for cb in callbacks):
+            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
                 rank_zero_info("Saving latest checkpoint...")
 
             model = self.trainer.lightning_module

@@ -672,19 +672,24 @@ def test_default_checkpoint_behavior(tmpdir):
 @pytest.mark.parametrize('max_epochs', [1, 2])
 @pytest.mark.parametrize('should_validate', [True, False])
 @pytest.mark.parametrize('save_last', [True, False])
-def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_validate, save_last):
+@pytest.mark.parametrize('verbose', [True, False])
+def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_validate, save_last, verbose):
     """Tests 'Saving latest checkpoint...' log"""
     model = LogInTwoMethods()
     if not should_validate:
         model.validation_step = None
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last)],
+        callbacks=[
+            ModelCheckpoint(
+                monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose
+            )
+        ],
         max_epochs=max_epochs,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    assert caplog.messages.count('Saving latest checkpoint...') == save_last
+    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
 
 
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
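A note on the updated assertion: pytest's caplog.messages is a list of captured log strings, so count() returns an int, while (verbose and save_last) evaluates to one of the two booleans. Because bool is a subclass of int in Python, the comparison holds exactly when the message appears once and both parametrized flags are True. A standalone illustration (not part of the patch):

# bool compares equal to int, so the count comparison reduces to a truth table:
assert 1 == (True and True)
assert 0 == (True and False)
assert 0 == (False and True)   # `and` short-circuits to the first falsy operand
assert 0 == (False and False)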