[bugfix] Prevent on_before_batch_transfer from being called twice (#9715)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
64bbebc869
commit
131176b9f5
|
@ -305,8 +305,6 @@ class DataFetcher(AbstractDataFetcher):
|
|||
def _get_queued_batch(self) -> Tuple[Any, bool]:
|
||||
self.wait()
|
||||
batch = self.batches.pop(0)
|
||||
if not self.store_on_device:
|
||||
batch = self.move_data_to_device(batch)
|
||||
is_last = len(self.batches) == 0
|
||||
return batch, is_last
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import torch
|
|||
from torch import tensor
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
|
||||
from pytorch_lightning import Callback, Trainer
|
||||
from pytorch_lightning import Callback, LightningDataModule, Trainer
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
|
||||
|
@ -392,3 +392,52 @@ def test_tbptt_split_batch_overridden(tmpdir) -> None:
|
|||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_transfer_hooks_with_unpacking(tmpdir):
    """Verify that every batch-transfer hook runs exactly once per batch.

    The datamodule's ``on_before_batch_transfer`` turns the dict batch into an
    ``(x, y_true)`` tuple, so the hook cannot be applied a second time to its
    own output without failing on ``batch["x"]``.
    """

    class RandomDictDataset(RandomDataset):
        """Dataset whose samples are dicts rather than bare tensors."""

        def __getitem__(self, index):
            sample = {
                "x": self.data[index],
                "y_true": torch.ones((2,)),
                "other": torch.ones((1,)),
            }
            return sample

    class BoringDataModule(LightningDataModule):
        """Datamodule that counts how often each transfer hook is invoked."""

        # Call counters, one per hook; incremented on the instance.
        count_called_on_before_batch_transfer = 0
        count_called_transfer_batch_to_device = 0
        count_called_on_after_batch_transfer = 0

        def train_dataloader(self):
            dataset = RandomDictDataset(32, 2)
            return DataLoader(dataset)

        def val_dataloader(self):
            dataset = RandomDictDataset(32, 2)
            return DataLoader(dataset)

        def on_before_batch_transfer(self, batch, dataloader_idx: int):
            self.count_called_on_before_batch_transfer = self.count_called_on_before_batch_transfer + 1
            # Unpack the dict into a tuple, dropping the "other" entry.
            features = batch["x"]
            target = batch["y_true"]
            return features, target

        def transfer_batch_to_device(self, *args, **kwargs):
            self.count_called_transfer_batch_to_device = self.count_called_transfer_batch_to_device + 1
            return super().transfer_batch_to_device(*args, **kwargs)

        def on_after_batch_transfer(self, batch, dataloader_idx: int):
            self.count_called_on_after_batch_transfer = self.count_called_on_after_batch_transfer + 1
            return super().on_after_batch_transfer(batch, dataloader_idx)

    class TestModel(BoringModel):
        """Model that forwards only the first element of the unpacked batch."""

        def training_step(self, batch, batch_idx):
            features, _ = batch
            return super().training_step(features, batch_idx)

        def validation_step(self, batch, batch_idx):
            features, _ = batch
            return super().validation_step(features, batch_idx)

    datamodule = BoringDataModule()
    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, num_sanity_val_steps=0)
    trainer.fit(model, datamodule=datamodule)

    # NOTE(review): 4 is presumably 2 training + 2 validation batches; this
    # depends on RandomDataset(32, 2), which is defined elsewhere -- confirm.
    assert datamodule.count_called_on_before_batch_transfer == 4
    assert datamodule.count_called_transfer_batch_to_device == 4
    assert datamodule.count_called_on_after_batch_transfer == 4
|
||||
|
|
Loading…
Reference in New Issue