Fix val_loop run on restart (#11552)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
a44881cd90
commit
76175217e4
|
@ -460,6 +460,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
|
||||
|
||||
|
||||
- Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
|
||||
|
||||
|
||||
## [1.5.8] - 2022-01-05
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -479,6 +479,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
# TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
|
||||
is_val_check_batch = is_last_batch
|
||||
|
||||
# while restarting with no fault-tolerant, batch_progress.current.ready is -1
|
||||
if batch_idx == -1:
|
||||
return False
|
||||
|
||||
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
|
||||
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
|
||||
elif self.trainer.val_check_batch != float("inf"):
|
||||
|
|
|
@ -11,9 +11,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning.loops import TrainingEpochLoop
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
|
||||
_out00 = {"loss": 0.0}
|
||||
_out01 = {"loss": 0.1}
|
||||
|
@ -141,3 +145,28 @@ def test_prepare_outputs_training_batch_end_manual(batch_end_outputs, expected):
|
|||
num_optimizers=-1, # does not matter for manual optimization
|
||||
)
|
||||
assert prepared == expected
|
||||
|
||||
|
||||
def test_no_val_on_train_epoch_loop_restart(tmpdir):
    """Ensure the validation loop is not re-triggered at the start of a resumed fit.

    Trains for one epoch, checkpoints, then resumes for a second epoch while
    spying on the validation loop's ``advance``: it must run exactly once
    (for the newly trained epoch), not an extra time on restart.
    """
    # One quick batch per loop, no sanity check, no automatic checkpointing.
    common_kwargs = dict(
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        enable_checkpointing=False,
    )
    model = BoringModel()
    first_trainer = Trainer(**common_kwargs)
    first_trainer.fit(model)
    checkpoint_file = str(tmpdir / "last.ckpt")
    first_trainer.save_checkpoint(checkpoint_file)

    # Resume for one additional epoch from the saved checkpoint.
    common_kwargs["max_epochs"] = 2
    resumed_trainer = Trainer(**common_kwargs)
    val_loop = resumed_trainer.fit_loop.epoch_loop.val_loop
    with patch.object(val_loop, "advance", wraps=val_loop.advance) as mocked_advance:
        resumed_trainer.fit(model, ckpt_path=checkpoint_file)
        # Exactly one validation pass: the one belonging to epoch 2.
        assert mocked_advance.call_count == 1
|
||||
|
|
Loading…
Reference in New Issue