[fix] Attach train+val dataloaders to trainer in trainer loop (#7207)
* Update training_loop.py
* Update test_dataloaders.py
* changelog
* delay reload
* go back
* comments
* Update training_loop.py
* Update test_dataloaders.py
* Update tests/trainer/test_dataloaders.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 80b9ca0e38
commit e407edba36
CHANGELOG.md
@@ -290,6 +290,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207))
+
 - Added a barrier in the accelerator `teardown` to synchronize processes before execution finishes ([#6814](https://github.com/PyTorchLightning/pytorch-lightning/pull/6814))
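For context, the configuration the changelog entry describes looks roughly like this (a minimal sketch; `MyModel` is a hypothetical placeholder `LightningModule`, not part of this commit):

    from pytorch_lightning import Trainer

    # With num_sanity_val_steps=0, nothing attached the val dataloader before
    # the first epoch, and the old guard in reset_train_val_dataloaders also
    # skipped the initial attach whenever reload_dataloaders_every_epoch=True.
    trainer = Trainer(
        num_sanity_val_steps=0,
        reload_dataloaders_every_epoch=True,
        max_epochs=3,
    )
    trainer.fit(MyModel())  # MyModel: hypothetical LightningModule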
pytorch_lightning/trainer/training_loop.py
@@ -174,11 +174,17 @@ class TrainLoop:
         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
 
-    def reset_train_val_dataloaders(self, model):
-        if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch:
+    def reset_train_val_dataloaders(self, model) -> None:
+        """
+        Resets train and val dataloaders if none are attached to the trainer.
+
+        The val dataloader must be initialized before training loop starts, as the training loop
+        inspects the val dataloader to determine whether to run the evaluation loop.
+        """
+        if self.trainer.train_dataloader is None:
             self.trainer.reset_train_dataloader(model)
 
-        if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
+        if self.trainer.val_dataloaders is None:
             self.trainer.reset_val_dataloader(model)
 
     def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
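The guard change is easiest to see in isolation. Below is a simplified sketch of the two conditions for the val dataloader; `ToyTrainer` is a stand-in modeling only the attributes the guard reads, not the real `Trainer`:

    class ToyTrainer:
        """Assumption-level stand-in: the real Trainer carries far more state."""

        def __init__(self, reload_dataloaders_every_epoch: bool):
            self.val_dataloaders = None  # nothing attached yet (no sanity check ran)
            self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

    trainer = ToyTrainer(reload_dataloaders_every_epoch=True)

    # Old guard: with the reload flag set, the conjunction is False, so the
    # val dataloader was never attached and the evaluation loop never ran.
    assert not (trainer.val_dataloaders is None and not trainer.reload_dataloaders_every_epoch)

    # New guard: attach whenever nothing is attached yet; the per-epoch
    # reloading still happens later in the loop (see the test sequence below).
    assert trainer.val_dataloaders is None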
tests/trainer/test_dataloaders.py
@@ -25,6 +25,7 @@ from torch.utils.data.sampler import SequentialSampler
 
 import tests.helpers.pipelines as tpipes
 from pytorch_lightning import Callback, seed_everything, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
@@ -1234,7 +1235,16 @@ def test_dataloaders_load_every_epoch(tmpdir):
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):
 
-    model = EvalModelTemplate()
+    class TestModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx):
+            self.log("dummy_val", 5.0)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+
+    # This callback tests that the evaluation metrics are available by the time we run checkpointing
+    checkpoint_callback = ModelCheckpoint(monitor="dummy_val", save_top_k=1)
 
     # logger file to get meta
     trainer = Trainer(
@@ -1244,21 +1254,32 @@
         num_sanity_val_steps=0,
         reload_dataloaders_every_epoch=True,
         max_epochs=3,
+        callbacks=[checkpoint_callback],
     )
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
     trainer.test()
 
-    assert len(trainer.dev_debugger.val_dataloader_calls) == 3
+    assert len(trainer.dev_debugger.val_dataloader_calls) == 4
     assert len(trainer.dev_debugger.train_dataloader_calls) == 3
     assert len(trainer.dev_debugger.test_dataloader_calls) == 1
 
     # verify the sequence
     calls = trainer.dev_debugger.dataloader_sequence_calls
 
     expected_sequence = [
         'train_dataloader',
+        'val_dataloader',
+        # This has subsequent calls to val_dataloader
+        # because the training loop runs the evaluation loop,
+        # which reloads the val dataloader again.
+        # We cannot yet rely on trainer.current_epoch=0 to skip reloading
+        # the val dataloader on the first epoch, because this only tracks the training epoch,
+        # meaning multiple passes through the validation data within a single training epoch
+        # would not have the dataloader reloaded.
+        # This breaks the assumption behind reload_dataloaders_every_epoch=True
         'val_dataloader',
         'train_dataloader',
         'val_dataloader',
         'train_dataloader',
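The asserted counts follow from the sequence (an informal tally; the epoch boundaries are inferred from the comments in the diff, and the final val reload is implied by the `== 4` assert rather than shown in the visible hunk):

    # max_epochs=3, num_sanity_val_steps=0, reload_dataloaders_every_epoch=True
    #
    # epoch 0: train_dataloader (initial attach)
    #          val_dataloader   (initial attach)
    #          val_dataloader   (evaluation loop reloads it)
    # epoch 1: train_dataloader (reload), val_dataloader (reload)
    # epoch 2: train_dataloader (reload), val_dataloader (reload, implied by the == 4 assert)
    #
    # totals: train_dataloader = 3, val_dataloader = 4, test_dataloader = 1 (from trainer.test())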