From 733cdbb9ad11dfaa20d5cfc6a85314c57a1ad2f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 13 Jul 2021 01:20:20 +0200 Subject: [PATCH] `every_n_val_epochs` -> `every_n_epochs` (#8383) --- CHANGELOG.md | 3 + pytorch_lightning/callbacks/early_stopping.py | 2 +- .../callbacks/model_checkpoint.py | 88 +++++++++++-------- .../test_checkpoint_callback_frequency.py | 1 + tests/checkpointing/test_model_checkpoint.py | 62 ++++++------- tests/deprecated_api/test_remove_1-6.py | 6 ++ tests/loggers/test_wandb.py | 1 - 7 files changed, 91 insertions(+), 72 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83c302677a..0b35c08bdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -325,6 +325,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) +- Deprecated `ModelCheckpoint(every_n_val_epochs)` in favor of `ModelCheckpoint(every_n_epochs)` ([#8383](https://github.com/PyTorchLightning/pytorch-lightning/pull/8383)) + + - Deprecated `DDPPlugin.task_idx` in favor of `DDPPlugin.local_rank` ([#8203](https://github.com/PyTorchLightning/pytorch-lightning/pull/8203)) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index c28e5cec5b..0015ac47f0 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -61,7 +61,7 @@ class EarlyStopping(Callback): stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold. divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold. check_on_train_epoch_end: whether to run early stopping at the end of the training epoch. - If this is ``False``, then the check runs at the end of the validation epoch. + If this is ``False``, then the check runs at the end of the validation. Raises: MisconfigurationException: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 46778547ac..4e006d81ce 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -108,26 +108,33 @@ class ModelCheckpoint(Callback): every_n_train_steps: Number of training steps between checkpoints. If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training. To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative. - This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``. + This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``. train_time_interval: Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. This is not guaranteed to execute at the exact time specified, but should be close. - This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``. - every_n_val_epochs: Number of validation epochs between checkpoints. - If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end. - To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. 
+ every_n_epochs: Number of epochs between checkpoints. + If ``every_n_epochs == None or every_n_epochs == 0``, we skip saving when the epoch ends. + To disable, set ``every_n_epochs = 0``. This value must be ``None`` or non-negative. This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``. - Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and + Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` will only save checkpoints at epochs 0 < E <= N - where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E. + where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. period: Interval (number of epochs) between checkpoints. .. warning:: This argument has been deprecated in v1.3 and will be removed in v1.5. - Use ``every_n_val_epochs`` instead. + Use ``every_n_epochs`` instead. + every_n_val_epochs: Number of epochs between checkpoints. + + .. warning:: + This argument has been deprecated in v1.4 and will be removed in v1.6. + + Use ``every_n_epochs`` instead. + Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -205,8 +212,9 @@ class ModelCheckpoint(Callback): auto_insert_metric_name: bool = True, every_n_train_steps: Optional[int] = None, train_time_interval: Optional[timedelta] = None, - every_n_val_epochs: Optional[int] = None, + every_n_epochs: Optional[int] = None, period: Optional[int] = None, + every_n_val_epochs: Optional[int] = None, ): super().__init__() self.monitor = monitor @@ -224,9 +232,16 @@ class ModelCheckpoint(Callback): self.best_model_path = "" self.last_model_path = "" + if every_n_val_epochs is not None: + rank_zero_deprecation( + '`ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6.' + ' Please use `every_n_epochs` instead.' + ) + every_n_epochs = every_n_val_epochs + self.__init_monitor_mode(mode) self.__init_ckpt_dir(dirpath, filename) - self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period) + self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period) self.__validate_init_configuration() self._save_function = None @@ -274,11 +289,10 @@ class ModelCheckpoint(Callback): def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """ Save a checkpoint at the end of the validation stage. """ - skip = ( - self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 - or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 - ) - if skip: + if ( + self._should_skip_saving_checkpoint(trainer) or self._every_n_epochs < 1 + or (trainer.current_epoch + 1) % self._every_n_epochs != 0 + ): return self.save_checkpoint(trainer) @@ -354,18 +368,16 @@ class ModelCheckpoint(Callback): raise MisconfigurationException( f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0' ) - if self._every_n_val_epochs < 0: - raise MisconfigurationException( - f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0' - ) + if self._every_n_epochs < 0: + raise MisconfigurationException(f'Invalid value for every_n_epochs={self._every_n_epochs}. 
Must be >= 0') every_n_train_steps_triggered = self._every_n_train_steps >= 1 - every_n_val_epochs_triggered = self._every_n_val_epochs >= 1 + every_n_epochs_triggered = self._every_n_epochs >= 1 train_time_interval_triggered = self._train_time_interval is not None - if every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1: + if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1: raise MisconfigurationException( f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, " - f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} " + f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} " "should be mutually exclusive." ) @@ -412,39 +424,41 @@ class ModelCheckpoint(Callback): self.kth_value, self.mode = mode_dict[mode] def __init_triggers( - self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], - train_time_interval: Optional[timedelta], period: Optional[int] + self, + every_n_train_steps: Optional[int], + every_n_epochs: Optional[int], + train_time_interval: Optional[timedelta], + period: Optional[int], ) -> None: # Default to running once after each validation epoch if neither - # every_n_train_steps nor every_n_val_epochs is set - if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None: - every_n_val_epochs = 1 + # every_n_train_steps nor every_n_epochs is set + if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None: + every_n_epochs = 1 every_n_train_steps = 0 - log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") + log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1") else: - every_n_val_epochs = every_n_val_epochs or 0 + every_n_epochs = every_n_epochs or 0 every_n_train_steps = every_n_train_steps or 0 self._train_time_interval: Optional[timedelta] = train_time_interval - self._every_n_val_epochs: int = every_n_val_epochs + self._every_n_epochs: int = every_n_epochs self._every_n_train_steps: int = every_n_train_steps - # period takes precedence over every_n_val_epochs for backwards compatibility + # period takes precedence over every_n_epochs for backwards compatibility if period is not None: rank_zero_deprecation( 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.' + ' Please use `every_n_epochs` instead.' ) - self._every_n_val_epochs = period - - self._period = self._every_n_val_epochs + self._every_n_epochs = period + self._period = self._every_n_epochs @property def period(self) -> Optional[int]: rank_zero_deprecation( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.' + ' Please use `every_n_epochs` instead.' ) return self._period @@ -452,7 +466,7 @@ class ModelCheckpoint(Callback): def period(self, value: Optional[int]) -> None: rank_zero_deprecation( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.' + ' Please use `every_n_epochs` instead.' 
) self._period = value diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 8617a9f8f7..9e85d0cb0c 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -60,6 +60,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter max_epochs=epochs, weights_summary=None, val_check_interval=val_check_interval, + limit_val_batches=1, progress_bar_refresh_rate=0, ) trainer.fit(model) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 2193d65fcd..a9e8f05785 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -558,50 +558,48 @@ def test_none_monitor_save_last(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_last=False) -def test_invalid_every_n_val_epochs(tmpdir): - """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ +def test_invalid_every_n_epochs(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """ with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): - ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-3) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0) - ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1) - ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=1) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2) def test_invalid_every_n_train_steps(tmpdir): - """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ + """ Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """ with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3) # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0) ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1) - ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2) def test_invalid_trigger_combination(tmpdir): """ Test that a MisconfigurationException is raised if more than one of - every_n_val_epochs, every_n_train_steps, and train_time_interval are enabled together. + every_n_epochs, every_n_train_steps, and train_time_interval are enabled together. 
""" with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'): - ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2) with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'): - ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_val_epochs=2) + ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_epochs=2) with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'): ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3) - ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0) - ModelCheckpoint( - dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=timedelta(minutes=1) - ) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=0, train_time_interval=timedelta(minutes=1)) def test_none_every_n_train_steps_val_epochs(tmpdir): checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) assert checkpoint_callback.period == 1 - assert checkpoint_callback._every_n_val_epochs == 1 + assert checkpoint_callback._every_n_epochs == 1 assert checkpoint_callback._every_n_train_steps == 0 @@ -659,12 +657,12 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) -def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): +@pytest.mark.parametrize("every_n_epochs", list(range(4))) +def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -677,22 +675,17 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) -def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): - """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """ +@pytest.mark.parametrize("every_n_epochs", list(range(4))) +def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs): + """ Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. 
""" model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, - filename='{epoch}', - save_top_k=-1, - every_n_val_epochs=(2 * every_n_val_epochs), - period=every_n_val_epochs + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -705,8 +698,7 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -719,7 +711,7 @@ def test_ckpt_every_n_train_steps(tmpdir): epoch_length = 64 checkpoint_callback = ModelCheckpoint( filename="{step}", - every_n_val_epochs=0, + every_n_epochs=0, every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, @@ -892,6 +884,8 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + limit_train_batches=1, + limit_val_batches=1, ) with caplog.at_level(logging.INFO): trainer.fit(model) @@ -910,6 +904,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): default_root_dir=tmpdir, callbacks=[model_checkpoint], max_epochs=num_epochs, + limit_train_batches=2, + limit_val_batches=2, ) trainer.fit(model) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 69d2a45530..ddb551631c 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -15,6 +15,7 @@ import pytest from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin @@ -303,3 +304,8 @@ def test_v1_6_0_deprecated_disable_validation(): trainer = Trainer() with pytest.deprecated_call(match="disable_validation` is deprecated in v1.4"): _ = trainer.disable_validation + + +def test_v1_6_0_every_n_val_epochs(): + with pytest.deprecated_call(match="use `every_n_epochs` instead"): + _ = ModelCheckpoint(every_n_val_epochs=1) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 27b83b75c2..4956a08c2f 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -213,7 +213,6 @@ def test_wandb_log_model(wandb, tmpdir): 'save_top_k': 1, 'save_weights_only': False, '_every_n_train_steps': 0, - '_every_n_val_epochs': 1 } } )