Fix val_loop run on restart (#11552)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Rohit Gupta 2022-02-03 01:49:34 +05:30 committed by GitHub
parent a44881cd90
commit 76175217e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 0 deletions

View File

@ -460,6 +460,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
- Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
## [1.5.8] - 2022-01-05
### Fixed

View File

@ -479,6 +479,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
# TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = is_last_batch
# when restarting without fault tolerance, batch_progress.current.ready is -1
if batch_idx == -1:
return False
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):

View File

@ -11,9 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import pytest
from pytorch_lightning.loops import TrainingEpochLoop
from pytorch_lightning.trainer.trainer import Trainer
from tests.helpers.boring_model import BoringModel
# Canned training-step outputs used as fixtures by the output-preparation tests below.
_out00 = {"loss": 0.0}
_out01 = {"loss": 0.1}
@ -141,3 +145,28 @@ def test_prepare_outputs_training_batch_end_manual(batch_end_outputs, expected):
num_optimizers=-1, # does not matter for manual optimization
)
assert prepared == expected
def test_no_val_on_train_epoch_loop_restart(tmpdir):
    """Test that training validation loop doesn't get triggered at the beginning of a restart."""
    # Shared Trainer configuration for both the initial run and the restart.
    common_kwargs = {
        "max_epochs": 1,
        "limit_train_batches": 1,
        "limit_val_batches": 1,
        "num_sanity_val_steps": 0,
        "enable_checkpointing": False,
    }
    first_trainer = Trainer(**common_kwargs)
    model = BoringModel()
    first_trainer.fit(model)

    # Persist a checkpoint at the end of the first (1-epoch) run.
    checkpoint_file = str(tmpdir / "last.ckpt")
    first_trainer.save_checkpoint(checkpoint_file)

    # Restart from the checkpoint with one extra epoch; the validation loop
    # should run exactly once (for the new epoch), not again at resume time.
    restart_trainer = Trainer(**{**common_kwargs, "max_epochs": 2})
    val_loop = restart_trainer.fit_loop.epoch_loop.val_loop
    with patch.object(val_loop, "advance", wraps=val_loop.advance) as advance_spy:
        restart_trainer.fit(model, ckpt_path=checkpoint_file)
    assert advance_spy.call_count == 1