diff --git a/CHANGELOG.md b/CHANGELOG.md
index f75dffd925..36ab9dd9e0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -460,6 +460,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
 
+- Fixed an issue where the validation loop would run at the start of a restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
+
+
 ## [1.5.8] - 2022-01-05
 
 ### Fixed
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 20b8f6ae47..7438c25298 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -479,6 +479,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
 
         # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
+
+        # when restarting without fault-tolerant training, batch_progress.current.ready is -1
+        if batch_idx == -1:
+            return False
+
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
diff --git a/tests/loops/epoch/test_training_epoch_loop.py b/tests/loops/epoch/test_training_epoch_loop.py
index 084031a0fb..6159809bce 100644
--- a/tests/loops/epoch/test_training_epoch_loop.py
+++ b/tests/loops/epoch/test_training_epoch_loop.py
@@ -11,9 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import patch
+
 import pytest
 
 from pytorch_lightning.loops import TrainingEpochLoop
+from pytorch_lightning.trainer.trainer import Trainer
+from tests.helpers.boring_model import BoringModel
 
 _out00 = {"loss": 0.0}
 _out01 = {"loss": 0.1}
@@ -141,3 +145,28 @@ def test_prepare_outputs_training_batch_end_manual(batch_end_outputs, expected):
         num_optimizers=-1,  # does not matter for manual optimization
     )
     assert prepared == expected
+
+
+def test_no_val_on_train_epoch_loop_restart(tmpdir):
+    """Test that the validation loop does not run at the beginning of a restart."""
+    trainer_kwargs = {
+        "max_epochs": 1,
+        "limit_train_batches": 1,
+        "limit_val_batches": 1,
+        "num_sanity_val_steps": 0,
+        "enable_checkpointing": False,
+    }
+    trainer = Trainer(**trainer_kwargs)
+    model = BoringModel()
+    trainer.fit(model)
+    ckpt_path = str(tmpdir / "last.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    trainer_kwargs["max_epochs"] = 2
+    trainer = Trainer(**trainer_kwargs)
+
+    with patch.object(
+        trainer.fit_loop.epoch_loop.val_loop, "advance", wraps=trainer.fit_loop.epoch_loop.val_loop.advance
+    ) as advance_mocked:
+        trainer.fit(model, ckpt_path=ckpt_path)
+        assert advance_mocked.call_count == 1  # only the validation run at the end of epoch 2, not one on restart
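
Reviewer note: a minimal, self-contained sketch of why the `batch_idx == -1` guard is needed. `should_check_val` and its signature here are hypothetical simplifications of `_should_check_val_fx`, not Lightning's API; on a restart without fault-tolerant training, `batch_progress.current.ready` is `-1`, so the modulo check would otherwise fire before any training step.

```python
def should_check_val(batch_idx: int, is_last_batch: bool, val_check_batch: float) -> bool:
    is_val_check_batch = is_last_batch

    # batch_idx == -1 means no batch has completed yet in this epoch, which is
    # the state right after restoring from a checkpoint without fault tolerance;
    # without this guard, (batch_idx + 1) % val_check_batch == 0 holds and a
    # spurious validation run is triggered immediately on restart
    if batch_idx == -1:
        return False

    if val_check_batch != float("inf"):
        is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    return is_val_check_batch


assert should_check_val(batch_idx=-1, is_last_batch=False, val_check_batch=1) is False  # restart: skip val
assert should_check_val(batch_idx=0, is_last_batch=True, val_check_batch=1) is True  # epoch end: run val
```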