fix case where an IterableDataset doesn't produce a batch for an epoch (#7294)
* wip
* fix
* add test
* refactor + test
* rm
* formatting
* update changelog
* doc
* docstring
* remove unused import
* Update CHANGELOG.md

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
parent 969e857690
commit b9b3fa371f
@@ -406,6 +406,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed parsing of `fast_dev_run=True` with the built-in `ArgumentParser` ([#7240](https://github.com/PyTorchLightning/pytorch-lightning/pull/7240))

+- Fixed handling an `IterableDataset` that fails to produce a batch at the beginning of an epoch ([#7294](https://github.com/PyTorchLightning/pytorch-lightning/pull/7294))
+
 - Fixed `LightningModule.save_hyperparameters()` when attempting to save an empty container ([#7268](https://github.com/PyTorchLightning/pytorch-lightning/pull/7268))
@@ -18,6 +18,7 @@ from torch.utils.data import DataLoader

 import pytorch_lightning as pl
 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.trainer.supporters import prefetch_iterator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -44,22 +45,10 @@ class DataConnector(object):

     def get_profiled_train_dataloader(self, train_dataloader):
         profiled_dl = self.trainer.profiler.profile_iterable(
-            enumerate(self._with_is_last(train_dataloader)), "get_train_batch"
+            enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
         )
         return profiled_dl

-    def _with_is_last(self, iterable):
-        """Pass through values from the given iterable with an added boolean indicating if this is the last item.
-        See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
-        it = iter(iterable)
-        last = next(it)
-        for val in it:
-            # yield last and has next
-            yield last, False
-            last = val
-        # yield last, no longer has next
-        yield last, True
-
     def prepare_data(self, model):
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
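For context on why this change is needed: the removed `_with_is_last` helper calls `next(it)` unguarded inside a generator, so a dataloader that yields nothing at the start of an epoch raises `StopIteration` from within the generator body, and since Python 3.7 (PEP 479) that surfaces as a `RuntimeError` rather than as a cleanly exhausted iterator. A minimal standalone sketch of that failure mode (plain Python mirroring the removed helper, not a call into Lightning's code paths):

def with_is_last(iterable):
    # same shape as the removed helper: peek one item ahead, flag the last one
    it = iter(iterable)
    last = next(it)  # raises StopIteration when the iterable is empty
    for val in it:
        yield last, False
        last = val
    yield last, True

# An empty iterable does not simply produce nothing: the StopIteration escaping
# the generator body is re-raised as RuntimeError (PEP 479, Python 3.7+).
try:
    list(with_is_last([]))
except RuntimeError as err:
    print(err)  # "generator raised StopIteration"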
@@ -14,7 +14,7 @@

 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Generator, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -508,3 +508,25 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):
         new_data.append(x)

     return compute_func(new_data)
+
+
+def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
+    """
+    Returns an iterator that pre-fetches and caches the next item.
+    The values are passed through from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
+    """
+    it = iter(iterable)
+
+    try:
+        # the iterator may be empty from the beginning
+        last = next(it)
+    except StopIteration:
+        return
+
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
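Usage note for the new helper (this mirrors the unit test added further down rather than any additional behavior): `prefetch_iterator` wraps an arbitrary iterable and yields `(item, is_last)` pairs, and because the first `next()` call is now guarded by `try/except StopIteration`, an empty iterable simply yields nothing instead of raising:

from pytorch_lightning.trainer.supporters import prefetch_iterator

assert list(prefetch_iterator([1, 2, 3])) == [(1, False), (2, False), (3, True)]
assert list(prefetch_iterator([])) == []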
@@ -458,8 +458,10 @@ class TrainLoop:
         dataloader_idx = 0
         val_loop_called = False

-        for batch_idx, (batch, is_last_batch) in train_dataloader:
+        batch_idx = None
+        is_last_batch = None
+
+        for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.trainer.batch_idx = batch_idx
             self.trainer.is_last_batch = is_last_batch

@@ -530,6 +532,10 @@ class TrainLoop:
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()

+        if batch_idx is None:
+            # dataloader/iterator did not produce a batch
+            return
+
         # handle epoch_output on epoch end
         self.on_train_epoch_end(epoch_output)
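The training-loop change above relies on a small sentinel pattern: `batch_idx` is initialized to `None` before the loop, and if the dataloader yields nothing the loop body never runs, so `batch_idx` is still `None` afterwards and the epoch-end handling is skipped. A reduced sketch of just that control flow (hypothetical function, not the actual `TrainLoop` internals):

def run_epoch(batches):
    batch_idx = None
    for batch_idx, batch in enumerate(batches):
        pass  # training step, logging, etc. would happen here

    if batch_idx is None:
        # dataloader/iterator did not produce a batch: skip epoch-end handling
        return "skipped"
    return "completed"

assert run_epoch([]) == "skipped"
assert run_epoch(["batch0", "batch1"]) == "completed"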
@@ -779,6 +779,41 @@ def test_warning_with_iterable_dataset_and_len(tmpdir):
     trainer.predict(model, dataloaders=dataloader)


+def test_iterable_dataset_stop_iteration_at_epoch_beginning():
+    """ Test that the training loop skips execution if the iterator is empty from the start. """
+
+    class RandomDataset(IterableDataset):
+
+        def __init__(self, gen):
+            self.gen = gen
+
+        def __iter__(self):
+            return iter(self.gen())
+
+    class TestModel(BoringModel):
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(self.gen), batch_size=2)
+
+        def gen(self):
+            # produce data in epoch 0
+            # no data otherwise
+            if self.current_epoch == 0:
+                yield torch.rand(32)
+                yield torch.rand(32)
+                yield torch.rand(32)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        max_epochs=2,  # we expect the second epoch to be skipped
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert trainer.global_step == 2
+    assert trainer.current_epoch == 1
+
+
 @RunIf(min_gpus=2)
 def test_dataloader_reinit_for_subclass(tmpdir):
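For reference, the numbers asserted in `test_iterable_dataset_stop_iteration_at_epoch_beginning` above follow from the generator: epoch 0 yields 3 samples, which at `batch_size=2` gives 2 batches and hence `trainer.global_step == 2`; epoch 1 yields nothing, so the new early return skips its epoch-end handling while the epoch counter still ends at 1, the last of `max_epochs=2` (zero-indexed).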
@@ -18,7 +18,7 @@ from unittest import mock
 import pytest
 import torch
 from torch.utils.data import DataLoader, TensorDataset
-from torch.utils.data.dataset import Dataset
+from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import Sampler

@@ -29,6 +29,7 @@ from pytorch_lightning.trainer.supporters import (
     CombinedLoader,
     CombinedLoaderIterator,
     CycleIterator,
+    prefetch_iterator,
     TensorRunningAccum,
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -78,6 +79,30 @@ def test_none_length_cycle_iterator():
     assert item == 0


+def test_prefetch_iterator():
+    """ Test the prefetch_iterator with PyTorch IterableDataset. """
+
+    class IterDataset(IterableDataset):
+
+        def __iter__(self):
+            yield 1
+            yield 2
+            yield 3
+
+    dataset = IterDataset()
+    iterator = prefetch_iterator(dataset)
+    assert [item for item in iterator] == [(1, False), (2, False), (3, True)]
+
+    class EmptyIterDataset(IterableDataset):
+
+        def __iter__(self):
+            return iter([])
+
+    dataset = EmptyIterDataset()
+    iterator = prefetch_iterator(dataset)
+    assert [item for item in iterator] == []
+
+
 @pytest.mark.parametrize(
     ["dataset_1", "dataset_2"],
     [