Throw MisconfigurationError on unknown mode (#5255)

* Throw MisconfigurationError on unknown mode

* Add tests

* Add match condition for deprecation message
This commit is contained in:
Alan Du 2021-01-12 02:31:26 -05:00 committed by GitHub
parent 059f4630c8
commit f6dc354349
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 12 deletions

View File

@ -25,6 +25,7 @@ import torch
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class EarlyStopping(Callback):
@ -96,15 +97,12 @@ class EarlyStopping(Callback):
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
def __init_monitor_mode(self):
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
if self.mode not in self.mode_dict and self.mode != 'auto':
if self.verbose > 0:
rank_zero_warn(
f'EarlyStopping mode={self.mode} is unknown, fallback to auto mode.',
RuntimeWarning,
)
self.mode = 'auto'
raise MisconfigurationException(
f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}"
)
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
if self.mode == 'auto':
rank_zero_warn(
"mode='auto' is deprecated in v1.1 and will be removed in v1.3."

View File

@ -287,14 +287,12 @@ class ModelCheckpoint(Callback):
"max": (-torch_inf, "max"),
}
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
if mode not in mode_dict and mode != 'auto':
rank_zero_warn(
f"ModelCheckpoint mode {mode} is unknown, fallback to auto mode",
RuntimeWarning,
raise MisconfigurationException(
f"`mode` can be auto, {', '.join(mode_dict.keys())}, got {mode}"
)
mode = "auto"
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
if mode == 'auto':
rank_zero_warn(
"mode='auto' is deprecated in v1.1 and will be removed in v1.3."

View File

@ -300,3 +300,8 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
(trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
_logger.disabled = False
def test_early_stopping_mode_options():
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
EarlyStopping(mode="unknown_option")

View File

@ -947,3 +947,8 @@ def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, ex
epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
def test_model_checkpoint_mode_options():
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
ModelCheckpoint(mode="unknown_option")