diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index f0f09401ab..689c2bff8e 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -305,8 +305,6 @@ class DataFetcher(AbstractDataFetcher): def _get_queued_batch(self) -> Tuple[Any, bool]: self.wait() batch = self.batches.pop(0) - if not self.store_on_device: - batch = self.move_data_to_device(batch) is_last = len(self.batches) == 0 return batch, is_last diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 420f580260..fd3df1df0c 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -21,7 +21,7 @@ import torch from torch import tensor from torch.utils.data import DataLoader, Dataset, IterableDataset -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher @@ -392,3 +392,52 @@ def test_tbptt_split_batch_overridden(tmpdir) -> None: m = InvalidModel() with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."): trainer.fit(m) + + +def test_transfer_hooks_with_unpacking(tmpdir): + + """This test asserts the `transfer_batch` hooks are called only once per batch.""" + + class RandomDictDataset(RandomDataset): + def __getitem__(self, index): + return {"x": self.data[index], "y_true": torch.ones((2,)), "other": torch.ones((1,))} + + class BoringDataModule(LightningDataModule): + + count_called_on_before_batch_transfer = 0 + count_called_transfer_batch_to_device = 0 + count_called_on_after_batch_transfer = 0 + + def train_dataloader(self): + return DataLoader(RandomDictDataset(32, 2)) + + def val_dataloader(self): + return DataLoader(RandomDictDataset(32, 2)) + + def on_before_batch_transfer(self, batch, dataloader_idx: int): + self.count_called_on_before_batch_transfer += 1 + return batch["x"], batch["y_true"] + + def transfer_batch_to_device(self, *args, **kwargs): + self.count_called_transfer_batch_to_device += 1 + return super().transfer_batch_to_device(*args, **kwargs) + + def on_after_batch_transfer(self, batch, dataloader_idx: int): + self.count_called_on_after_batch_transfer += 1 + return super().on_after_batch_transfer(batch, dataloader_idx) + + class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + x, _ = batch + return super().training_step(x, batch_idx) + + def validation_step(self, batch, batch_idx): + x, _ = batch + return super().validation_step(x, batch_idx) + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, num_sanity_val_steps=0) + dm = BoringDataModule() + trainer.fit(TestModel(), datamodule=dm) + assert dm.count_called_on_before_batch_transfer == 4 + assert dm.count_called_transfer_batch_to_device == 4 + assert dm.count_called_on_after_batch_transfer == 4