diff --git a/CHANGELOG.md b/CHANGELOG.md
index e7b4480c59..0b024ab46c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -406,6 +406,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed parsing of `fast_dev_run=True` with the built-in `ArgumentParser` ([#7240](https://github.com/PyTorchLightning/pytorch-lightning/pull/7240))
 
+- Fixed handling an `IterableDataset` that fails to produce a batch at the beginning of an epoch ([#7294](https://github.com/PyTorchLightning/pytorch-lightning/pull/7294))
+
 - Fixed `LightningModule.save_hyperparameters()` when attempting to save an empty container ([#7268](https://github.com/PyTorchLightning/pytorch-lightning/pull/7268))
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index fd6c9ea328..476bd3fc14 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -18,6 +18,7 @@ from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.trainer.supporters import prefetch_iterator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -44,22 +45,10 @@ class DataConnector(object):
 
     def get_profiled_train_dataloader(self, train_dataloader):
         profiled_dl = self.trainer.profiler.profile_iterable(
-            enumerate(self._with_is_last(train_dataloader)), "get_train_batch"
+            enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
         )
         return profiled_dl
 
-    def _with_is_last(self, iterable):
-        """Pass through values from the given iterable with an added boolean indicating if this is the last item.
-        See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
-        it = iter(iterable)
-        last = next(it)
-        for val in it:
-            # yield last and has next
-            yield last, False
-            last = val
-        # yield last, no longer has next
-        yield last, True
-
     def prepare_data(self, model):
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 3cb0b0cb1f..18a012da54 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -14,7 +14,7 @@
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Generator, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -508,3 +508,25 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
             new_data.append(x)
 
     return compute_func(new_data)
+
+
+def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
+    """
+    Returns an iterator that pre-fetches and caches the next item.
+    The values are passed through from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
+    """
+    it = iter(iterable)
+
+    try:
+        # the iterator may be empty from the beginning
+        last = next(it)
+    except StopIteration:
+        return
+
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 8a007086fb..397f5cc5cf 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -458,8 +458,10 @@ class TrainLoop:
         dataloader_idx = 0
         val_loop_called = False
 
-        for batch_idx, (batch, is_last_batch) in train_dataloader:
+        batch_idx = None
+        is_last_batch = None
+
+        for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.trainer.batch_idx = batch_idx
             self.trainer.is_last_batch = is_last_batch
@@ -530,6 +532,10 @@
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()
 
+        if batch_idx is None:
+            # dataloader/iterator did not produce a batch
+            return
+
         # handle epoch_output on epoch end
         self.on_train_epoch_end(epoch_output)
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 2a744c9c05..58fb157291 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -779,6 +779,41 @@ def test_warning_with_iterable_dataset_and_len(tmpdir):
         trainer.predict(model, dataloaders=dataloader)
 
 
+def test_iterable_dataset_stop_iteration_at_epoch_beginning():
+    """ Test that the training loop skips execution if the iterator is empty from the start. """
+
+    class RandomDataset(IterableDataset):
+
+        def __init__(self, gen):
+            self.gen = gen
+
+        def __iter__(self):
+            return iter(self.gen())
+
+    class TestModel(BoringModel):
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(self.gen), batch_size=2)
+
+        def gen(self):
+            # produce data in epoch 0
+            # no data otherwise
+            if self.current_epoch == 0:
+                yield torch.rand(32)
+                yield torch.rand(32)
+                yield torch.rand(32)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        max_epochs=2,  # we expect the second epoch to be skipped
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert trainer.global_step == 2
+    assert trainer.current_epoch == 1
+
+
 @RunIf(min_gpus=2)
 def test_dataloader_reinit_for_subclass(tmpdir):
diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py
index 6da2436b5e..6d6b1e9ad1 100644
--- a/tests/trainer/test_supporters.py
+++ b/tests/trainer/test_supporters.py
@@ -18,7 +18,7 @@ from unittest import mock
 import pytest
 import torch
 from torch.utils.data import DataLoader, TensorDataset
-from torch.utils.data.dataset import Dataset
+from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import Sampler
 
@@ -29,6 +29,7 @@ from pytorch_lightning.trainer.supporters import (
     CombinedLoader,
     CombinedLoaderIterator,
     CycleIterator,
+    prefetch_iterator,
     TensorRunningAccum,
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -78,6 +79,30 @@ def test_none_length_cycle_iterator():
         assert item == 0
 
 
+def test_prefetch_iterator():
+    """ Test the prefetch_iterator with PyTorch IterableDataset. """
""" + + class IterDataset(IterableDataset): + + def __iter__(self): + yield 1 + yield 2 + yield 3 + + dataset = IterDataset() + iterator = prefetch_iterator(dataset) + assert [item for item in iterator] == [(1, False), (2, False), (3, True)] + + class EmptyIterDataset(IterableDataset): + + def __iter__(self): + return iter([]) + + dataset = EmptyIterDataset() + iterator = prefetch_iterator(dataset) + assert [item for item in iterator] == [] + + @pytest.mark.parametrize( ["dataset_1", "dataset_2"], [