Throw MisconfigurationError on unknown mode (#5255)
* Throw MisconfigurationError on unknown mode * Add tests * Add match condition for deprecation message
This commit is contained in:
parent
059f4630c8
commit
f6dc354349
|
@ -25,6 +25,7 @@ import torch
|
|||
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
|
@ -96,15 +97,12 @@ class EarlyStopping(Callback):
|
|||
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
|
||||
|
||||
def __init_monitor_mode(self):
|
||||
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
|
||||
if self.mode not in self.mode_dict and self.mode != 'auto':
|
||||
if self.verbose > 0:
|
||||
rank_zero_warn(
|
||||
f'EarlyStopping mode={self.mode} is unknown, fallback to auto mode.',
|
||||
RuntimeWarning,
|
||||
)
|
||||
self.mode = 'auto'
|
||||
raise MisconfigurationException(
|
||||
f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}"
|
||||
)
|
||||
|
||||
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
|
||||
if self.mode == 'auto':
|
||||
rank_zero_warn(
|
||||
"mode='auto' is deprecated in v1.1 and will be removed in v1.3."
|
||||
|
|
|
@ -287,14 +287,12 @@ class ModelCheckpoint(Callback):
|
|||
"max": (-torch_inf, "max"),
|
||||
}
|
||||
|
||||
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
|
||||
if mode not in mode_dict and mode != 'auto':
|
||||
rank_zero_warn(
|
||||
f"ModelCheckpoint mode {mode} is unknown, fallback to auto mode",
|
||||
RuntimeWarning,
|
||||
raise MisconfigurationException(
|
||||
f"`mode` can be auto, {', '.join(mode_dict.keys())}, got {mode}"
|
||||
)
|
||||
mode = "auto"
|
||||
|
||||
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
|
||||
if mode == 'auto':
|
||||
rank_zero_warn(
|
||||
"mode='auto' is deprecated in v1.1 and will be removed in v1.3."
|
||||
|
|
|
@ -300,3 +300,8 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
|
|||
(trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
|
||||
|
||||
_logger.disabled = False
|
||||
|
||||
|
||||
def test_early_stopping_mode_options():
|
||||
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
|
||||
EarlyStopping(mode="unknown_option")
|
||||
|
|
|
@ -947,3 +947,8 @@ def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, ex
|
|||
|
||||
epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
|
||||
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
|
||||
|
||||
|
||||
def test_model_checkpoint_mode_options():
|
||||
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
|
||||
ModelCheckpoint(mode="unknown_option")
|
||||
|
|
Loading…
Reference in New Issue