Support no pre-fetching (#11606)

Carlos Mocholí 2022-02-05 04:59:46 +01:00 committed by GitHub
parent c71a1d7ea2
commit 7da931d1ca
3 changed files with 45 additions and 30 deletions


@@ -72,6 +72,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))
+- Added support for no pre-fetching to `DataFetcher` ([#11606](https://github.com/PyTorchLightning/pytorch-lightning/pull/11606))
 - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))


@@ -201,19 +201,15 @@ def _no_op_batch_to_device(batch: Any) -> Any:

 class DataFetcher(AbstractDataFetcher):
-    """This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a
-    batch in advance to detect the end of the iteration.
+    """This class is used to control batch fetching flow.

     Args:
-        prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch
-            at least 1 batch for tracking the latest batch.
+        prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
+            whether a batch is the last one (available with :attr:`self.done`).
         store_on_device: Whether to store the pre-fetched batches on device.
     """

     def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
-        if prefetch_batches < 1:
-            raise MisconfigurationException("`prefetch_batches` should at least be 1.")
         super().__init__(prefetch_batches=prefetch_batches)
         self.store_on_device = store_on_device
         self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
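
With the `prefetch_batches < 1` guard removed above, a zero-prefetch fetcher can now be constructed directly. A minimal sketch, assuming the module path used at the time of this commit (`pytorch_lightning.utilities.fetching`):

from pytorch_lightning.utilities.fetching import DataFetcher

# before this commit, any value below 1 raised a MisconfigurationException;
# after it, 0 is accepted and simply disables pre-fetching
fetcher = DataFetcher(prefetch_batches=0, store_on_device=True)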
@@ -240,19 +236,31 @@ class DataFetcher(AbstractDataFetcher):
                 break

     def fetching_function(self) -> Tuple[Any, bool]:
-        assert self.dataloader_iter is not None
         if self.batches:
+            # there are pre-fetched batches already from a previous `prefetching` call.
+            # consume one
             batch = self.batches.pop(0)
-        else:
-            # empty iterator, no prefetching done
-            raise StopIteration
+            if not self.done:
+                assert self.dataloader_iter is not None
+                try:
+                    # refill the consumed batch
+                    self._fetch_next_batch(self.dataloader_iter)
+                except StopIteration:
+                    # no more batches to fetch. we are done only if all pre-fetched batches were returned
+                    self.done = not self.batches
+        elif not self.done:
+            # this will run only when no pre-fetching was done.
+            try:
+                self._fetch_next_batch(self.dataloader_iter)
+                # consume the batch we just fetched
+                batch = self.batches.pop(0)
+            except StopIteration as e:
+                self.done = True
+                raise e
+        else:
+            # the iterator is empty
+            raise StopIteration
         self.wait()
-        return self.move_to_device(batch), len(self.batches) == 0
+        return self.move_to_device(batch), self.done

     def _fetch_next_batch(self, iterator: Iterator) -> None:
         start_output = self.on_fetch_start()
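
To illustrate the reworked control flow, here is a rough usage sketch (not taken from the repository) that iterates a no-prefetch `DataFetcher` over a tiny `DataLoader`. It assumes `DataFetcher` is importable from `pytorch_lightning.utilities.fetching` (its location around this release) and relies only on the `setup`, `fetched`, and `done` attributes exercised by the test below.

from torch.utils.data import DataLoader
from pytorch_lightning.utilities.fetching import DataFetcher

fetcher = DataFetcher(prefetch_batches=0)
fetcher.setup(DataLoader(range(3)))  # three one-element batches

for batch, is_last in fetcher:
    # with prefetch_batches=0 there is no look-ahead, so `is_last` (i.e. `fetcher.done`)
    # stays False for every batch and only flips once StopIteration is hit internally
    print(fetcher.fetched, batch, is_last)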


@@ -40,28 +40,32 @@ def test_prefetch_iterator(use_combined_loader):
             yield 2
             yield 3

-    for prefetch_batches in range(1, 5):
-        if use_combined_loader:
-            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
-            expected = [
-                ([tensor([1]), tensor([1])], False),
-                ([tensor([2]), tensor([2])], False),
-                ([tensor([3]), tensor([3])], True),
-            ]
-        else:
-            loader = DataLoader(IterDataset())
-            expected = [(1, False), (2, False), (3, True)]
+    for prefetch_batches in range(5):
         iterator = DataFetcher(prefetch_batches=prefetch_batches)
         assert iterator.prefetch_batches == prefetch_batches
+        if use_combined_loader:
+            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
+        else:
+            loader = DataLoader(IterDataset())
         iterator.setup(loader)

         def generate():
-            generated = []
-            for idx, data in enumerate(iterator, prefetch_batches + 1):
-                assert iterator.fetched == 3 if iterator.done else idx
-                generated.append(data)
+            generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)]
+            assert iterator.fetched == 3
+            assert iterator.done
             return generated

+        is_last_batch = [False, False, prefetch_batches > 0]
+        fetched = list(range(prefetch_batches + 1, 4))
+        fetched += [3] * (3 - len(fetched))
+        if use_combined_loader:
+            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
+        else:
+            batches = [1, 2, 3]
+        expected = list(zip(fetched, batches, is_last_batch))
+        assert len(expected) == 3
         assert generate() == expected
         # validate reset works properly.
         assert generate() == expected
@@ -71,9 +75,9 @@ def test_prefetch_iterator(use_combined_loader):
         def __iter__(self):
             return iter([])

-    dataloader = DataLoader(EmptyIterDataset())
+    loader = DataLoader(EmptyIterDataset())
     iterator = DataFetcher()
-    iterator.setup(dataloader)
+    iterator.setup(loader)
     assert not list(iterator)
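
The updated expectations encode the relationship between the prefetch depth and the two counters checked above: `fetched` is how many batches have been pulled from the dataloader by the time each of the three batches is yielded, and only a fetcher that pre-fetches at all can flag the final batch before `StopIteration`. A small worked sketch of that arithmetic (the helper name is just for illustration):

def expected_counters(prefetch_batches: int, n: int = 3):
    # after yielding batch i, the fetcher has pulled min(i + prefetch_batches, n) batches,
    # which is what `list(range(prefetch_batches + 1, 4))` padded with 3s computes above
    fetched = [min(i + prefetch_batches, n) for i in range(1, n + 1)]
    # with no pre-fetching there is no look-ahead, so the last batch is never flagged;
    # with any pre-fetching, the final `is_last` flag is True
    is_last_batch = [False] * (n - 1) + [prefetch_batches > 0]
    return fetched, is_last_batch

assert expected_counters(0) == ([1, 2, 3], [False, False, False])
assert expected_counters(2) == ([3, 3, 3], [False, False, True])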