`every_n_val_epochs` -> `every_n_epochs` (#8383)

Carlos Mocholí 2021-07-13 01:20:20 +02:00 committed by GitHub
parent f3e828426a
commit 733cdbb9ad
7 changed files with 91 additions and 72 deletions

View File

@@ -325,6 +325,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))
- Deprecated `ModelCheckpoint(every_n_val_epochs)` in favor of `ModelCheckpoint(every_n_epochs)` ([#8383](https://github.com/PyTorchLightning/pytorch-lightning/pull/8383))
- Deprecated `DDPPlugin.task_idx` in favor of `DDPPlugin.local_rank` ([#8203](https://github.com/PyTorchLightning/pytorch-lightning/pull/8203))

View File

@@ -61,7 +61,7 @@ class EarlyStopping(Callback):
stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
If this is ``False``, then the check runs at the end of the validation epoch.
If this is ``False``, then the check runs at the end of the validation.
Raises:
MisconfigurationException:
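Not part of this diff, but as a quick orientation for the ``check_on_train_epoch_end`` flag documented above, a minimal sketch (the ``val_loss`` key, ``patience`` value and ``model`` are placeholders, not taken from the commit):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# With check_on_train_epoch_end=False the stopping check runs after validation,
# so the monitored key must be logged from a validation hook.
early_stop = EarlyStopping(monitor="val_loss", patience=3, check_on_train_epoch_end=False)
trainer = Trainer(max_epochs=20, callbacks=[early_stop])
# trainer.fit(model)  # `model` is any LightningModule that logs "val_loss"
```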

View File

@@ -108,26 +108,33 @@ class ModelCheckpoint(Callback):
every_n_train_steps: Number of training steps between checkpoints.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``.
This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
train_time_interval: Checkpoints are monitored at the specified time interval.
For all practical purposes, this cannot be smaller than the amount
of time it takes to process a single training batch. This is not
guaranteed to execute at the exact time specified, but should be close.
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``.
every_n_val_epochs: Number of validation epochs between checkpoints.
If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
every_n_epochs: Number of epochs between checkpoints.
If ``every_n_epochs == None or every_n_epochs == 0``, we skip saving when the epoch ends.
To disable, set ``every_n_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
will only save checkpoints at epochs 0 < E <= N
where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
period: Interval (number of epochs) between checkpoints.
.. warning::
This argument has been deprecated in v1.3 and will be removed in v1.5.
Use ``every_n_val_epochs`` instead.
Use ``every_n_epochs`` instead.
every_n_val_epochs: Number of epochs between checkpoints.
.. warning::
This argument has been deprecated in v1.4 and will be removed in v1.6.
Use ``every_n_epochs`` instead.
Note:
For extra customization, ModelCheckpoint includes the following attributes:
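To illustrate the interplay described above between ``every_n_epochs`` and ``check_val_every_n_epoch``, a rough sketch with illustrative values (``dirpath``, the intervals and ``model`` are placeholders, not taken from the commit):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Validation runs every 3rd epoch and the callback saves every 2nd epoch, so
# checkpoints land only on epochs E that both values divide: 6, 12, 18.
ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_epochs=2)
trainer = Trainer(max_epochs=20, check_val_every_n_epoch=3, callbacks=[ckpt])
# trainer.fit(model)  # `model` is a placeholder LightningModule
```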
@@ -205,8 +212,9 @@ class ModelCheckpoint(Callback):
auto_insert_metric_name: bool = True,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
every_n_val_epochs: Optional[int] = None,
every_n_epochs: Optional[int] = None,
period: Optional[int] = None,
every_n_val_epochs: Optional[int] = None,
):
super().__init__()
self.monitor = monitor
@@ -224,9 +232,16 @@ class ModelCheckpoint(Callback):
self.best_model_path = ""
self.last_model_path = ""
if every_n_val_epochs is not None:
rank_zero_deprecation(
'`ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6.'
' Please use `every_n_epochs` instead.'
)
every_n_epochs = every_n_val_epochs
self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period)
self.__validate_init_configuration()
self._save_function = None
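A small migration sketch for the deprecation shim added in ``__init__`` above, assuming an arbitrary ``dirpath`` and interval; both constructions configure the same checkpointing cadence:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Old spelling: still accepted, but emits the deprecation warning added above
# and is remapped to every_n_epochs internally (slated for removal in v1.6).
ckpt_old = ModelCheckpoint(dirpath="checkpoints/", every_n_val_epochs=2)

# New spelling introduced by this commit, behaviourally equivalent.
ckpt_new = ModelCheckpoint(dirpath="checkpoints/", every_n_epochs=2)
```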
@@ -274,11 +289,10 @@ class ModelCheckpoint(Callback):
def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
""" Save a checkpoint at the end of the validation stage. """
skip = (
self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
)
if skip:
if (
self._should_skip_saving_checkpoint(trainer) or self._every_n_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_epochs != 0
):
return
self.save_checkpoint(trainer)
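To make the rewritten skip condition concrete, a tiny standalone arithmetic sketch (plain Python, no Lightning objects; the epoch counts are illustrative):

```python
# current_epoch is 0-based; a checkpoint is written when (current_epoch + 1)
# is a multiple of every_n_epochs, e.g. for every_n_epochs=2 over 5 epochs:
every_n_epochs = 2
saving_epochs = [e for e in range(5) if (e + 1) % every_n_epochs == 0]
print(saving_epochs)  # [1, 3]
```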
@@ -354,18 +368,16 @@ class ModelCheckpoint(Callback):
raise MisconfigurationException(
f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
)
if self._every_n_val_epochs < 0:
raise MisconfigurationException(
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
)
if self._every_n_epochs < 0:
raise MisconfigurationException(f'Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0')
every_n_train_steps_triggered = self._every_n_train_steps >= 1
every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
every_n_epochs_triggered = self._every_n_epochs >= 1
train_time_interval_triggered = self._train_time_interval is not None
if every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1:
if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
raise MisconfigurationException(
f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} "
f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
"should be mutually exclusive."
)
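The validation above keeps the three triggers mutually exclusive; a minimal sketch of what now fails, mirroring the tests further down (``dirpath`` and the values are illustrative):

```python
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Two of the triggers enabled at once is rejected by the validation above.
    ModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=1, every_n_epochs=2)
except MisconfigurationException as err:
    print(err)  # Combination of parameters every_n_train_steps=1, every_n_epochs=2 ...
```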
@@ -412,39 +424,41 @@ class ModelCheckpoint(Callback):
self.kth_value, self.mode = mode_dict[mode]
def __init_triggers(
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int],
train_time_interval: Optional[timedelta], period: Optional[int]
self,
every_n_train_steps: Optional[int],
every_n_epochs: Optional[int],
train_time_interval: Optional[timedelta],
period: Optional[int],
) -> None:
# Default to running once after each validation epoch if neither
# every_n_train_steps nor every_n_val_epochs is set
if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None:
every_n_val_epochs = 1
# every_n_train_steps nor every_n_epochs is set
if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None:
every_n_epochs = 1
every_n_train_steps = 0
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1")
else:
every_n_val_epochs = every_n_val_epochs or 0
every_n_epochs = every_n_epochs or 0
every_n_train_steps = every_n_train_steps or 0
self._train_time_interval: Optional[timedelta] = train_time_interval
self._every_n_val_epochs: int = every_n_val_epochs
self._every_n_epochs: int = every_n_epochs
self._every_n_train_steps: int = every_n_train_steps
# period takes precedence over every_n_val_epochs for backwards compatibility
# period takes precedence over every_n_epochs for backwards compatibility
if period is not None:
rank_zero_deprecation(
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.'
' Please use `every_n_epochs` instead.'
)
self._every_n_val_epochs = period
self._period = self._every_n_val_epochs
self._every_n_epochs = period
self._period = self._every_n_epochs
@property
def period(self) -> Optional[int]:
rank_zero_deprecation(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.'
' Please use `every_n_epochs` instead.'
)
return self._period
@@ -452,7 +466,7 @@ class ModelCheckpoint(Callback):
def period(self, value: Optional[int]) -> None:
rank_zero_deprecation(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.'
' Please use `every_n_epochs` instead.'
)
self._period = value
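Summing up the ``__init_triggers`` behaviour shown above as a hedged usage sketch (the private attribute names are those from this file; ``dirpath`` and the numbers are illustrative):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# With no trigger given, __init_triggers falls back to saving once per epoch.
default_ckpt = ModelCheckpoint(dirpath="checkpoints/")
assert default_ckpt._every_n_epochs == 1
assert default_ckpt._every_n_train_steps == 0

# The deprecated `period` argument still overrides every_n_epochs (and warns).
legacy_ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_epochs=4, period=2)
assert legacy_ckpt._every_n_epochs == 2
```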

View File

@@ -60,6 +60,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
max_epochs=epochs,
weights_summary=None,
val_check_interval=val_check_interval,
limit_val_batches=1,
progress_bar_refresh_rate=0,
)
trainer.fit(model)

View File

@@ -558,50 +558,48 @@ def test_none_monitor_save_last(tmpdir):
ModelCheckpoint(dirpath=tmpdir, save_last=False)
def test_invalid_every_n_val_epochs(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
def test_invalid_every_n_epochs(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=1)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2)
def test_invalid_every_n_train_steps(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
""" Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2)
def test_invalid_trigger_combination(tmpdir):
"""
Test that a MisconfigurationException is raised if more than one of
every_n_val_epochs, every_n_train_steps, and train_time_interval are enabled together.
every_n_epochs, every_n_train_steps, and train_time_interval are enabled together.
"""
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2)
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_val_epochs=2)
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_epochs=2)
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)
ModelCheckpoint(
dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=timedelta(minutes=1)
)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=3)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=0, train_time_interval=timedelta(minutes=1))
def test_none_every_n_train_steps_val_epochs(tmpdir):
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
assert checkpoint_callback.period == 1
assert checkpoint_callback._every_n_val_epochs == 1
assert checkpoint_callback._every_n_epochs == 1
assert checkpoint_callback._every_n_train_steps == 0
@@ -659,12 +657,12 @@ def test_model_checkpoint_period(tmpdir, period: int):
assert set(os.listdir(tmpdir)) == set(expected)
@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
@pytest.mark.parametrize("every_n_epochs", list(range(4)))
def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -677,22 +675,17 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
trainer.fit(model)
# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)
@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
""" Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
@pytest.mark.parametrize("every_n_epochs", list(range(4)))
def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs):
""" Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. """
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
filename='{epoch}',
save_top_k=-1,
every_n_val_epochs=(2 * every_n_val_epochs),
period=every_n_val_epochs
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -705,8 +698,7 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
trainer.fit(model)
# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)
@@ -719,7 +711,7 @@ def test_ckpt_every_n_train_steps(tmpdir):
epoch_length = 64
checkpoint_callback = ModelCheckpoint(
filename="{step}",
every_n_val_epochs=0,
every_n_epochs=0,
every_n_train_steps=every_n_train_steps,
dirpath=tmpdir,
save_top_k=-1,
@@ -892,6 +884,8 @@ def test_model_checkpoint_save_last_warning(
default_root_dir=tmpdir,
callbacks=[ckpt],
max_epochs=max_epochs,
limit_train_batches=1,
limit_val_batches=1,
)
with caplog.at_level(logging.INFO):
trainer.fit(model)
@@ -910,6 +904,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_epochs=num_epochs,
limit_train_batches=2,
limit_val_batches=2,
)
trainer.fit(model)

View File

@@ -15,6 +15,7 @@
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
@@ -303,3 +304,8 @@ def test_v1_6_0_deprecated_disable_validation():
trainer = Trainer()
with pytest.deprecated_call(match="disable_validation` is deprecated in v1.4"):
_ = trainer.disable_validation
def test_v1_6_0_every_n_val_epochs():
with pytest.deprecated_call(match="use `every_n_epochs` instead"):
_ = ModelCheckpoint(every_n_val_epochs=1)

View File

@@ -213,7 +213,6 @@ def test_wandb_log_model(wandb, tmpdir):
'save_top_k': 1,
'save_weights_only': False,
'_every_n_train_steps': 0,
'_every_n_val_epochs': 1
}
}
)