From f6dc354349094e3692fe25fea5ccb14439c318ff Mon Sep 17 00:00:00 2001 From: Alan Du Date: Tue, 12 Jan 2021 02:31:26 -0500 Subject: [PATCH] Throw MisconfigurationError on unknown mode (#5255) * Throw MisconfigurationError on unknown mode * Add tests * Add match condition for deprecation message --- pytorch_lightning/callbacks/early_stopping.py | 12 +++++------- pytorch_lightning/callbacks/model_checkpoint.py | 8 +++----- tests/callbacks/test_early_stopping.py | 5 +++++ tests/checkpointing/test_model_checkpoint.py | 5 +++++ 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index fca39036c9..ec44a1eeb4 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -25,6 +25,7 @@ import torch from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class EarlyStopping(Callback): @@ -96,15 +97,12 @@ class EarlyStopping(Callback): self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf def __init_monitor_mode(self): - # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 if self.mode not in self.mode_dict and self.mode != 'auto': - if self.verbose > 0: - rank_zero_warn( - f'EarlyStopping mode={self.mode} is unknown, fallback to auto mode.', - RuntimeWarning, - ) - self.mode = 'auto' + raise MisconfigurationException( + f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}" + ) + # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 if self.mode == 'auto': rank_zero_warn( "mode='auto' is deprecated in v1.1 and will be removed in v1.3." diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3fc2b54d98..8a89cd2bef 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -287,14 +287,12 @@ class ModelCheckpoint(Callback): "max": (-torch_inf, "max"), } - # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 if mode not in mode_dict and mode != 'auto': - rank_zero_warn( - f"ModelCheckpoint mode {mode} is unknown, fallback to auto mode", - RuntimeWarning, + raise MisconfigurationException( + f"`mode` can be auto, {', '.join(mode_dict.keys())}, got {mode}" ) - mode = "auto" + # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 if mode == 'auto': rank_zero_warn( "mode='auto' is deprecated in v1.1 and will be removed in v1.3." diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 4bb328cb88..925f296d0a 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -300,3 +300,8 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs) _logger.disabled = False + + +def test_early_stopping_mode_options(): + with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"): + EarlyStopping(mode="unknown_option") diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 32a32d7527..58a202d573 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -947,3 +947,8 @@ def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, ex epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files] assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs)) + + +def test_model_checkpoint_mode_options(): + with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"): + ModelCheckpoint(mode="unknown_option")