[bugfix] Prevent on_before_batch_transfer to be called twice (#9715)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
thomas chaton 2021-09-27 20:36:56 +01:00 committed by GitHub
parent 64bbebc869
commit 131176b9f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 3 deletions

View File

@ -305,8 +305,6 @@ class DataFetcher(AbstractDataFetcher):
def _get_queued_batch(self) -> Tuple[Any, bool]:
self.wait()
batch = self.batches.pop(0)
if not self.store_on_device:
batch = self.move_data_to_device(batch)
is_last = len(self.batches) == 0
return batch, is_last

View File

@ -21,7 +21,7 @@ import torch
from torch import tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset
from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
@ -392,3 +392,52 @@ def test_tbptt_split_batch_overridden(tmpdir) -> None:
m = InvalidModel()
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
trainer.fit(m)
def test_transfer_hooks_with_unpacking(tmpdir):
"""This test asserts the `transfer_batch` hooks are called only once per batch."""
class RandomDictDataset(RandomDataset):
def __getitem__(self, index):
return {"x": self.data[index], "y_true": torch.ones((2,)), "other": torch.ones((1,))}
class BoringDataModule(LightningDataModule):
count_called_on_before_batch_transfer = 0
count_called_transfer_batch_to_device = 0
count_called_on_after_batch_transfer = 0
def train_dataloader(self):
return DataLoader(RandomDictDataset(32, 2))
def val_dataloader(self):
return DataLoader(RandomDictDataset(32, 2))
def on_before_batch_transfer(self, batch, dataloader_idx: int):
self.count_called_on_before_batch_transfer += 1
return batch["x"], batch["y_true"]
def transfer_batch_to_device(self, *args, **kwargs):
self.count_called_transfer_batch_to_device += 1
return super().transfer_batch_to_device(*args, **kwargs)
def on_after_batch_transfer(self, batch, dataloader_idx: int):
self.count_called_on_after_batch_transfer += 1
return super().on_after_batch_transfer(batch, dataloader_idx)
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
x, _ = batch
return super().training_step(x, batch_idx)
def validation_step(self, batch, batch_idx):
x, _ = batch
return super().validation_step(x, batch_idx)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, num_sanity_val_steps=0)
dm = BoringDataModule()
trainer.fit(TestModel(), datamodule=dm)
assert dm.count_called_on_before_batch_transfer == 4
assert dm.count_called_transfer_batch_to_device == 4
assert dm.count_called_on_after_batch_transfer == 4