[bugfix] Prevent on_before_batch_transfer to be called twice (#9715)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
64bbebc869
commit
131176b9f5
|
@ -305,8 +305,6 @@ class DataFetcher(AbstractDataFetcher):
|
||||||
def _get_queued_batch(self) -> Tuple[Any, bool]:
|
def _get_queued_batch(self) -> Tuple[Any, bool]:
|
||||||
self.wait()
|
self.wait()
|
||||||
batch = self.batches.pop(0)
|
batch = self.batches.pop(0)
|
||||||
if not self.store_on_device:
|
|
||||||
batch = self.move_data_to_device(batch)
|
|
||||||
is_last = len(self.batches) == 0
|
is_last = len(self.batches) == 0
|
||||||
return batch, is_last
|
return batch, is_last
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ import torch
|
||||||
from torch import tensor
|
from torch import tensor
|
||||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||||
|
|
||||||
from pytorch_lightning import Callback, Trainer
|
from pytorch_lightning import Callback, LightningDataModule, Trainer
|
||||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
|
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
|
||||||
|
@ -392,3 +392,52 @@ def test_tbptt_split_batch_overridden(tmpdir) -> None:
|
||||||
m = InvalidModel()
|
m = InvalidModel()
|
||||||
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
|
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
|
||||||
trainer.fit(m)
|
trainer.fit(m)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transfer_hooks_with_unpacking(tmpdir):
|
||||||
|
|
||||||
|
"""This test asserts the `transfer_batch` hooks are called only once per batch."""
|
||||||
|
|
||||||
|
class RandomDictDataset(RandomDataset):
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return {"x": self.data[index], "y_true": torch.ones((2,)), "other": torch.ones((1,))}
|
||||||
|
|
||||||
|
class BoringDataModule(LightningDataModule):
|
||||||
|
|
||||||
|
count_called_on_before_batch_transfer = 0
|
||||||
|
count_called_transfer_batch_to_device = 0
|
||||||
|
count_called_on_after_batch_transfer = 0
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
return DataLoader(RandomDictDataset(32, 2))
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
return DataLoader(RandomDictDataset(32, 2))
|
||||||
|
|
||||||
|
def on_before_batch_transfer(self, batch, dataloader_idx: int):
|
||||||
|
self.count_called_on_before_batch_transfer += 1
|
||||||
|
return batch["x"], batch["y_true"]
|
||||||
|
|
||||||
|
def transfer_batch_to_device(self, *args, **kwargs):
|
||||||
|
self.count_called_transfer_batch_to_device += 1
|
||||||
|
return super().transfer_batch_to_device(*args, **kwargs)
|
||||||
|
|
||||||
|
def on_after_batch_transfer(self, batch, dataloader_idx: int):
|
||||||
|
self.count_called_on_after_batch_transfer += 1
|
||||||
|
return super().on_after_batch_transfer(batch, dataloader_idx)
|
||||||
|
|
||||||
|
class TestModel(BoringModel):
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
x, _ = batch
|
||||||
|
return super().training_step(x, batch_idx)
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
x, _ = batch
|
||||||
|
return super().validation_step(x, batch_idx)
|
||||||
|
|
||||||
|
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, num_sanity_val_steps=0)
|
||||||
|
dm = BoringDataModule()
|
||||||
|
trainer.fit(TestModel(), datamodule=dm)
|
||||||
|
assert dm.count_called_on_before_batch_transfer == 4
|
||||||
|
assert dm.count_called_transfer_batch_to_device == 4
|
||||||
|
assert dm.count_called_on_after_batch_transfer == 4
|
||||||
|
|
Loading…
Reference in New Issue