`every_n_val_epochs` -> `every_n_epochs` (#8383)
parent f3e828426a
commit 733cdbb9ad
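In user code, the change amounts to renaming one ModelCheckpoint argument. A minimal sketch of the migration (the dirpath and Trainer settings are illustrative, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # new spelling introduced by this commit
    ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_epochs=2)

    # old spelling still works until v1.6, but now emits a deprecation warning
    # and is remapped to every_n_epochs internally
    legacy_ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_val_epochs=2)

    trainer = Trainer(max_epochs=10, callbacks=[ckpt])
    # trainer.fit(model)  # model definition omitted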
@@ -325,6 +325,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))
+- Deprecated `ModelCheckpoint(every_n_val_epochs)` in favor of `ModelCheckpoint(every_n_epochs)` ([#8383](https://github.com/PyTorchLightning/pytorch-lightning/pull/8383))
 - Deprecated `DDPPlugin.task_idx` in favor of `DDPPlugin.local_rank` ([#8203](https://github.com/PyTorchLightning/pytorch-lightning/pull/8203))
@@ -61,7 +61,7 @@ class EarlyStopping(Callback):
         stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
         divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
         check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
-            If this is ``False``, then the check runs at the end of the validation epoch.
+            If this is ``False``, then the check runs at the end of the validation.

     Raises:
         MisconfigurationException:
@@ -108,26 +108,33 @@ class ModelCheckpoint(Callback):
         every_n_train_steps: Number of training steps between checkpoints.
             If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
             To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``.
+            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
         train_time_interval: Checkpoints are monitored at the specified time interval.
             For all practical purposes, this cannot be smaller than the amount
             of time it takes to process a single training batch. This is not
             guaranteed to execute at the exact time specified, but should be close.
-            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``.
-        every_n_val_epochs: Number of validation epochs between checkpoints.
-            If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
-            To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
+        every_n_epochs: Number of epochs between checkpoints.
+            If ``every_n_epochs == None or every_n_epochs == 0``, we skip saving when the epoch ends.
+            To disable, set ``every_n_epochs = 0``. This value must be ``None`` or non-negative.
             This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
-            Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
+            Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
             ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
             will only save checkpoints at epochs 0 < E <= N
-            where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
+            where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
         period: Interval (number of epochs) between checkpoints.

             .. warning::
                 This argument has been deprecated in v1.3 and will be removed in v1.5.

-                Use ``every_n_val_epochs`` instead.
+                Use ``every_n_epochs`` instead.
+        every_n_val_epochs: Number of epochs between checkpoints.
+
+            .. warning::
+                This argument has been deprecated in v1.4 and will be removed in v1.6.
+
+                Use ``every_n_epochs`` instead.

     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
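The interaction described in that docstring can be made concrete with a small sketch (the specific values below are only an example, not taken from the diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Checkpoints are written only at epochs E (counting from 1) that both
    # values divide evenly: with every_n_epochs=2 and check_val_every_n_epoch=3,
    # that is E = 6 and E = 12.
    ckpt = ModelCheckpoint(every_n_epochs=2, save_on_train_epoch_end=False)
    trainer = Trainer(max_epochs=12, check_val_every_n_epoch=3, callbacks=[ckpt])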
@@ -205,8 +212,9 @@ class ModelCheckpoint(Callback):
         auto_insert_metric_name: bool = True,
         every_n_train_steps: Optional[int] = None,
         train_time_interval: Optional[timedelta] = None,
-        every_n_val_epochs: Optional[int] = None,
+        every_n_epochs: Optional[int] = None,
         period: Optional[int] = None,
+        every_n_val_epochs: Optional[int] = None,
     ):
         super().__init__()
         self.monitor = monitor
@@ -224,9 +232,16 @@ class ModelCheckpoint(Callback):
         self.best_model_path = ""
         self.last_model_path = ""

+        if every_n_val_epochs is not None:
+            rank_zero_deprecation(
+                '`ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6.'
+                ' Please use `every_n_epochs` instead.'
+            )
+            every_n_epochs = every_n_val_epochs
+
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename)
-        self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
+        self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period)
         self.__validate_init_configuration()
         self._save_function = None
@@ -274,11 +289,10 @@ class ModelCheckpoint(Callback):

     def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         """ Save a checkpoint at the end of the validation stage. """
-        skip = (
-            self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
-            or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
-        )
-        if skip:
+        if (
+            self._should_skip_saving_checkpoint(trainer) or self._every_n_epochs < 1
+            or (trainer.current_epoch + 1) % self._every_n_epochs != 0
+        ):
             return
         self.save_checkpoint(trainer)
@@ -354,18 +368,16 @@ class ModelCheckpoint(Callback):
             raise MisconfigurationException(
                 f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
             )
-        if self._every_n_val_epochs < 0:
-            raise MisconfigurationException(
-                f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
-            )
+        if self._every_n_epochs < 0:
+            raise MisconfigurationException(f'Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0')

         every_n_train_steps_triggered = self._every_n_train_steps >= 1
-        every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
+        every_n_epochs_triggered = self._every_n_epochs >= 1
         train_time_interval_triggered = self._train_time_interval is not None
-        if every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1:
+        if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
             raise MisconfigurationException(
                 f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
-                f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} "
+                f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
                 "should be mutually exclusive."
             )
@@ -412,39 +424,41 @@ class ModelCheckpoint(Callback):
         self.kth_value, self.mode = mode_dict[mode]

     def __init_triggers(
-        self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int],
-        train_time_interval: Optional[timedelta], period: Optional[int]
+        self,
+        every_n_train_steps: Optional[int],
+        every_n_epochs: Optional[int],
+        train_time_interval: Optional[timedelta],
+        period: Optional[int],
     ) -> None:

         # Default to running once after each validation epoch if neither
-        # every_n_train_steps nor every_n_val_epochs is set
-        if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None:
-            every_n_val_epochs = 1
+        # every_n_train_steps nor every_n_epochs is set
+        if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None:
+            every_n_epochs = 1
             every_n_train_steps = 0
-            log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
+            log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1")
         else:
-            every_n_val_epochs = every_n_val_epochs or 0
+            every_n_epochs = every_n_epochs or 0
             every_n_train_steps = every_n_train_steps or 0

         self._train_time_interval: Optional[timedelta] = train_time_interval
-        self._every_n_val_epochs: int = every_n_val_epochs
+        self._every_n_epochs: int = every_n_epochs
         self._every_n_train_steps: int = every_n_train_steps

-        # period takes precedence over every_n_val_epochs for backwards compatibility
+        # period takes precedence over every_n_epochs for backwards compatibility
         if period is not None:
             rank_zero_deprecation(
                 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
-                ' Please use `every_n_val_epochs` instead.'
+                ' Please use `every_n_epochs` instead.'
             )
-            self._every_n_val_epochs = period
-
-        self._period = self._every_n_val_epochs
+            self._every_n_epochs = period
+        self._period = self._every_n_epochs

     @property
     def period(self) -> Optional[int]:
         rank_zero_deprecation(
             'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
-            ' Please use `every_n_val_epochs` instead.'
+            ' Please use `every_n_epochs` instead.'
         )
         return self._period
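The trigger defaults and the period override implemented above boil down to the following behaviour (a sketch only; the variable names and values are made up):

    from pytorch_lightning.callbacks import ModelCheckpoint

    # No trigger arguments: __init_triggers falls back to every_n_epochs=1 and
    # every_n_train_steps=0, i.e. one checkpoint per validation epoch.
    default_ckpt = ModelCheckpoint(dirpath="ckpts/")

    # For backwards compatibility, a deprecated `period` still overrides
    # `every_n_epochs`; this behaves like every_n_epochs=2 and emits a
    # deprecation warning pointing at `every_n_epochs`.
    legacy_ckpt = ModelCheckpoint(dirpath="ckpts/", every_n_epochs=4, period=2)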
@@ -452,7 +466,7 @@ class ModelCheckpoint(Callback):
     def period(self, value: Optional[int]) -> None:
         rank_zero_deprecation(
             'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
-            ' Please use `every_n_val_epochs` instead.'
+            ' Please use `every_n_epochs` instead.'
         )
         self._period = value
@@ -60,6 +60,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval
         max_epochs=epochs,
         weights_summary=None,
         val_check_interval=val_check_interval,
+        limit_val_batches=1,
         progress_bar_refresh_rate=0,
     )
     trainer.fit(model)
@@ -558,50 +558,48 @@ def test_none_monitor_save_last(tmpdir):
         ModelCheckpoint(dirpath=tmpdir, save_last=False)


-def test_invalid_every_n_val_epochs(tmpdir):
-    """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
+def test_invalid_every_n_epochs(tmpdir):
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """
     with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
-        ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
+        ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-3)
     # These should not fail
-    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
-    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
-    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
+    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2)


 def test_invalid_every_n_train_steps(tmpdir):
-    """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """
     with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
         ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
     # These should not fail
     ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
     ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
-    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
+    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2)


 def test_invalid_trigger_combination(tmpdir):
     """
     Test that a MisconfigurationException is raised if more than one of
-    every_n_val_epochs, every_n_train_steps, and train_time_interval are enabled together.
+    every_n_epochs, every_n_train_steps, and train_time_interval are enabled together.
     """
     with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
-        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
+        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2)
     with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
-        ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_val_epochs=2)
+        ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_epochs=2)
     with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
         ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2)

     # These should not fail
-    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
-    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)
-    ModelCheckpoint(
-        dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=timedelta(minutes=1)
-    )
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=3)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_epochs=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=0, train_time_interval=timedelta(minutes=1))


 def test_none_every_n_train_steps_val_epochs(tmpdir):
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
     assert checkpoint_callback.period == 1
-    assert checkpoint_callback._every_n_val_epochs == 1
+    assert checkpoint_callback._every_n_epochs == 1
     assert checkpoint_callback._every_n_train_steps == 0
@@ -659,12 +657,12 @@ def test_model_checkpoint_period(tmpdir, period: int):
     assert set(os.listdir(tmpdir)) == set(expected)


-@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
-def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
+@pytest.mark.parametrize("every_n_epochs", list(range(4)))
+def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
     model = LogInTwoMethods()
     epochs = 5
     checkpoint_callback = ModelCheckpoint(
-        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
+        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -677,22 +675,17 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)


-@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
-def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
-    """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
+@pytest.mark.parametrize("every_n_epochs", list(range(4)))
+def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs):
+    """ Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. """
     model = LogInTwoMethods()
     epochs = 5
     checkpoint_callback = ModelCheckpoint(
-        dirpath=tmpdir,
-        filename='{epoch}',
-        save_top_k=-1,
-        every_n_val_epochs=(2 * every_n_val_epochs),
-        period=every_n_val_epochs
+        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -705,8 +698,7 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)

@@ -719,7 +711,7 @@ def test_ckpt_every_n_train_steps(tmpdir):
     epoch_length = 64
     checkpoint_callback = ModelCheckpoint(
         filename="{step}",
-        every_n_val_epochs=0,
+        every_n_epochs=0,
         every_n_train_steps=every_n_train_steps,
         dirpath=tmpdir,
         save_top_k=-1,
@@ -892,6 +884,8 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
+        limit_train_batches=1,
+        limit_val_batches=1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
@@ -910,6 +904,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
         default_root_dir=tmpdir,
         callbacks=[model_checkpoint],
         max_epochs=num_epochs,
+        limit_train_batches=2,
+        limit_val_batches=2,
     )
     trainer.fit(model)

@@ -15,6 +15,7 @@
 import pytest

 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
@@ -303,3 +304,8 @@ def test_v1_6_0_deprecated_disable_validation():
     trainer = Trainer()
     with pytest.deprecated_call(match="disable_validation` is deprecated in v1.4"):
         _ = trainer.disable_validation
+
+
+def test_v1_6_0_every_n_val_epochs():
+    with pytest.deprecated_call(match="use `every_n_epochs` instead"):
+        _ = ModelCheckpoint(every_n_val_epochs=1)
@@ -213,7 +213,6 @@ def test_wandb_log_model(wandb, tmpdir):
                 'save_top_k': 1,
                 'save_weights_only': False,
                 '_every_n_train_steps': 0,
-                '_every_n_val_epochs': 1
             }
         }
     )