fix case where an IterableDataset doesn't produce a batch for an epoch (#7294)

* wip

* fix

* add test

* refactor + test

* rm

* formatting

* update changelog

* doc

* docstring

* remove unused import

* Update CHANGELOG.md

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

Adrian Wälchli 2021-04-30 14:45:55 +02:00 committed by GitHub
parent 969e857690
commit b9b3fa371f
6 changed files with 95 additions and 16 deletions

CHANGELOG.md

@@ -406,6 +406,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed parsing of `fast_dev_run=True` with the built-in `ArgumentParser` ([#7240](https://github.com/PyTorchLightning/pytorch-lightning/pull/7240))
+- Fixed handling an `IterableDataset` that fails to produce a batch at the beginning of an epoch ([#7294](https://github.com/PyTorchLightning/pytorch-lightning/pull/7294))
 - Fixed `LightningModule.save_hyperparameters()` when attempting to save an empty container ([#7268](https://github.com/PyTorchLightning/pytorch-lightning/pull/7268))

pytorch_lightning/trainer/connectors/data_connector.py

@@ -18,6 +18,7 @@ from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.trainer.supporters import prefetch_iterator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -44,22 +45,10 @@ class DataConnector(object):
     def get_profiled_train_dataloader(self, train_dataloader):
         profiled_dl = self.trainer.profiler.profile_iterable(
-            enumerate(self._with_is_last(train_dataloader)), "get_train_batch"
+            enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
         )
         return profiled_dl
 
-    def _with_is_last(self, iterable):
-        """Pass through values from the given iterable with an added boolean indicating if this is the last item.
-        See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
-        it = iter(iterable)
-        last = next(it)
-        for val in it:
-            # yield last and has next
-            yield last, False
-            last = val
-        # yield last, no longer has next
-        yield last, True
-
     def prepare_data(self, model):
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
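For context, the profiled loader now wraps `enumerate` around `prefetch_iterator`, so each training step receives a `(batch_idx, (batch, is_last_batch))` tuple. A minimal illustrative check of that shape (the string batches below are made up; only the tuple structure comes from the diff):

    from pytorch_lightning.trainer.supporters import prefetch_iterator

    # enumerate(prefetch_iterator(...)) pairs a running index with (item, is_last)
    steps = list(enumerate(prefetch_iterator(["batch_0", "batch_1"])))
    assert steps == [(0, ("batch_0", False)), (1, ("batch_1", True))]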

pytorch_lightning/trainer/supporters.py

@@ -14,7 +14,7 @@
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Generator, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -508,3 +508,25 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
             new_data.append(x)
 
     return compute_func(new_data)
+
+
+def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
+    """
+    Returns an iterator that pre-fetches and caches the next item.
+    The values are passed through from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
+    """
+    it = iter(iterable)
+
+    try:
+        # the iterator may be empty from the beginning
+        last = next(it)
+    except StopIteration:
+        return
+
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
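This guarded `next()` is one half of the fix (the early return in the training loop below is the other): the removed `_with_is_last` helper fetched the first item unconditionally, so an iterator that is already exhausted at the start of an epoch raised `StopIteration` inside the generator (which bubbles up as a `RuntimeError` under PEP 479), whereas `prefetch_iterator` simply yields nothing. A small usage sketch with made-up values:

    from pytorch_lightning.trainer.supporters import prefetch_iterator

    # every item is paired with a flag saying whether it is the last one
    assert list(prefetch_iterator([10, 20, 30])) == [(10, False), (20, False), (30, True)]

    # an empty iterable produces no pairs instead of failing on the first next()
    assert list(prefetch_iterator([])) == []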

pytorch_lightning/trainer/training_loop.py

@@ -458,8 +458,10 @@ class TrainLoop:
         dataloader_idx = 0
         val_loop_called = False
 
-        for batch_idx, (batch, is_last_batch) in train_dataloader:
+        batch_idx = None
+        is_last_batch = None
+        for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.trainer.batch_idx = batch_idx
             self.trainer.is_last_batch = is_last_batch
@@ -530,6 +532,10 @@
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()
 
+        if batch_idx is None:
+            # dataloader/iterator did not produce a batch
+            return
+
         # handle epoch_output on epoch end
         self.on_train_epoch_end(epoch_output)
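The new early return relies on the sentinels set before the loop: if the prefetched iterator yields nothing, the `for` body never runs, `batch_idx` stays `None`, and the epoch-end hooks are skipped. A standalone sketch of that pattern with hypothetical names (not Lightning code):

    def run_epoch(batches):
        batch_idx = None
        for batch_idx, batch in enumerate(batches):
            pass  # the training step would run here
        if batch_idx is None:
            # the loop body never executed, so skip all epoch-end work
            return "skipped"
        return f"finished after {batch_idx + 1} batches"

    assert run_epoch([]) == "skipped"
    assert run_epoch(["a", "b"]) == "finished after 2 batches"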

tests/trainer/test_dataloaders.py

@@ -779,6 +779,41 @@ def test_warning_with_iterable_dataset_and_len(tmpdir):
     trainer.predict(model, dataloaders=dataloader)
 
 
+def test_iterable_dataset_stop_iteration_at_epoch_beginning():
+    """ Test that the training loop skips execution if the iterator is empty from the start. """
+
+    class RandomDataset(IterableDataset):
+
+        def __init__(self, gen):
+            self.gen = gen
+
+        def __iter__(self):
+            return iter(self.gen())
+
+    class TestModel(BoringModel):
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(self.gen), batch_size=2)
+
+        def gen(self):
+            # produce data in epoch 0
+            # no data otherwise
+            if self.current_epoch == 0:
+                yield torch.rand(32)
+                yield torch.rand(32)
+                yield torch.rand(32)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        max_epochs=2,  # we expect the second epoch to be skipped
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert trainer.global_step == 2
+    assert trainer.current_epoch == 1
+
+
 @RunIf(min_gpus=2)
 def test_dataloader_reinit_for_subclass(tmpdir):

tests/trainer/test_supporters.py

@@ -18,7 +18,7 @@ from unittest import mock
 import pytest
 import torch
 from torch.utils.data import DataLoader, TensorDataset
-from torch.utils.data.dataset import Dataset
+from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import Sampler
@@ -29,6 +29,7 @@ from pytorch_lightning.trainer.supporters import (
     CombinedLoader,
     CombinedLoaderIterator,
     CycleIterator,
+    prefetch_iterator,
     TensorRunningAccum,
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -78,6 +79,30 @@ def test_none_length_cycle_iterator():
     assert item == 0
 
 
+def test_prefetch_iterator():
+    """ Test the prefetch_iterator with PyTorch IterableDataset. """
+
+    class IterDataset(IterableDataset):
+
+        def __iter__(self):
+            yield 1
+            yield 2
+            yield 3
+
+    dataset = IterDataset()
+    iterator = prefetch_iterator(dataset)
+    assert [item for item in iterator] == [(1, False), (2, False), (3, True)]
+
+    class EmptyIterDataset(IterableDataset):
+
+        def __iter__(self):
+            return iter([])
+
+    dataset = EmptyIterDataset()
+    iterator = prefetch_iterator(dataset)
+    assert [item for item in iterator] == []
+
+
 @pytest.mark.parametrize(
     ["dataset_1", "dataset_2"],
     [