Support no pre-fetching (#11606)
parent c71a1d7ea2
commit 7da931d1ca
In CHANGELOG.md:

```diff
@@ -72,6 +72,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))
 
+- Added support for no pre-fetching to `DataFetcher` ([#11606](https://github.com/PyTorchLightning/pytorch-lightning/pull/11606))
+
 - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))
```
In the `DataFetcher` class, the docstring is updated and the `prefetch_batches < 1` guard is dropped:

```diff
@@ -201,19 +201,15 @@ def _no_op_batch_to_device(batch: Any) -> Any:
 
 
 class DataFetcher(AbstractDataFetcher):
-    """This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a
-    batch in advance to detect the end of the iteration.
+    """This class is used to control batch fetching flow.
 
     Args:
-        prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch
-            at least 1 batch for tracking the latest batch.
+        prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
+            whether a batch is the last one (available with :attr:`self.done`).
         store_on_device: Whether to store the pre-fetched batches on device.
     """
 
     def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
-        if prefetch_batches < 1:
-            raise MisconfigurationException("`prefetch_batches` should at least be 1.")
         super().__init__(prefetch_batches=prefetch_batches)
         self.store_on_device = store_on_device
         self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
```
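Removing the `prefetch_batches < 1` guard is what makes `prefetch_batches=0` a valid configuration. A minimal usage sketch, assuming the import path at this revision and a made-up `RangeDataset` for illustration:

```python
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.utilities.fetching import DataFetcher


class RangeDataset(Dataset):
    """Tiny map-style dataset, made up for this example."""

    def __len__(self) -> int:
        return 3

    def __getitem__(self, idx: int) -> int:
        return idx + 1


# Before this commit, prefetch_batches=0 raised
# MisconfigurationException("`prefetch_batches` should at least be 1.");
# with it, the fetcher falls back to fetching lazily, one batch per step.
fetcher = DataFetcher(prefetch_batches=0)
fetcher.setup(DataLoader(RangeDataset()))
for batch, is_last in fetcher:
    # without lookahead, `is_last` stays False even on the final batch
    # (see `is_last_batch` in the updated test below)
    print(batch, is_last)
```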
The core of the change is `fetching_function`, which now handles the no-pre-fetching case instead of assuming a non-empty buffer:

```diff
@@ -240,19 +236,31 @@ class DataFetcher(AbstractDataFetcher):
             break
 
     def fetching_function(self) -> Tuple[Any, bool]:
+        assert self.dataloader_iter is not None
         if self.batches:
+            # there are pre-fetched batches already from a previous `prefetching` call.
+            # consume one
             batch = self.batches.pop(0)
-        else:
-            # empty iterator, no prefetching done
-            raise StopIteration
-        if not self.done:
-            assert self.dataloader_iter is not None
             try:
+                # refill the consumed batch
                 self._fetch_next_batch(self.dataloader_iter)
             except StopIteration:
+                # no more batches to fetch. we are done only if all pre-fetched batches were returned
+                self.done = not self.batches
+        elif not self.done:
+            # this will run only when no pre-fetching was done.
+            try:
+                self._fetch_next_batch(self.dataloader_iter)
+                # consume the batch we just fetched
+                batch = self.batches.pop(0)
+            except StopIteration as e:
                 self.done = True
+                raise e
+        else:
+            # the iterator is empty
+            raise StopIteration
         self.wait()
-        return self.move_to_device(batch), len(self.batches) == 0
+        return self.move_to_device(batch), self.done
 
     def _fetch_next_batch(self, iterator: Iterator) -> None:
         start_output = self.on_fetch_start()
```
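The rewritten function has three branches: consume a pre-fetched batch and refill the buffer; fetch and consume a batch in the same step when nothing was pre-fetched; or stop on an exhausted iterator. As a minimal, self-contained sketch of the same control flow, with Lightning internals stripped out (the `fetch` helper is made up for illustration):

```python
from collections import deque
from typing import Any, Iterable, Iterator, Tuple


def fetch(iterable: Iterable, prefetch: int) -> Iterator[Tuple[Any, bool]]:
    """Yield ``(batch, done)`` pairs with the same control flow as the new
    ``fetching_function``; ``prefetch=0`` means no pre-fetching at all."""
    it = iter(iterable)
    batches: deque = deque()
    done = False
    # pre-fetching phase: fill the buffer up front (mirrors `prefetching`)
    for _ in range(prefetch):
        try:
            batches.append(next(it))
        except StopIteration:
            break
    while True:
        if batches:
            # consume one pre-fetched batch, then refill the slot
            batch = batches.popleft()
            try:
                batches.append(next(it))
            except StopIteration:
                # done only once every buffered batch has been returned
                done = not batches
        elif not done:
            # no pre-fetching was done: fetch and consume in the same step,
            # so the final batch cannot be flagged before iteration ends
            try:
                batch = next(it)
            except StopIteration:
                return
        else:
            # the iterator is empty
            return
        yield batch, done
```

For example, `list(fetch([1, 2, 3], prefetch=1))` gives `[(1, False), (2, False), (3, True)]`, while with `prefetch=0` the final pair is `(3, False)`: without lookahead the fetcher only learns the iterator is exhausted on the next call, which is exactly what `is_last_batch = [False, False, prefetch_batches > 0]` encodes in the updated test below.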
In `test_prefetch_iterator`, the sweep now starts at `prefetch_batches=0` and the expected values are derived rather than hard-coded:

```diff
@@ -40,28 +40,32 @@ def test_prefetch_iterator(use_combined_loader):
             yield 2
             yield 3
 
-    for prefetch_batches in range(1, 5):
-        if use_combined_loader:
-            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
-            expected = [
-                ([tensor([1]), tensor([1])], False),
-                ([tensor([2]), tensor([2])], False),
-                ([tensor([3]), tensor([3])], True),
-            ]
-        else:
-            loader = DataLoader(IterDataset())
-            expected = [(1, False), (2, False), (3, True)]
+    for prefetch_batches in range(5):
         iterator = DataFetcher(prefetch_batches=prefetch_batches)
         assert iterator.prefetch_batches == prefetch_batches
+
+        if use_combined_loader:
+            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
+        else:
+            loader = DataLoader(IterDataset())
         iterator.setup(loader)
 
         def generate():
-            generated = []
-            for idx, data in enumerate(iterator, prefetch_batches + 1):
-                assert iterator.fetched == 3 if iterator.done else idx
-                generated.append(data)
+            generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)]
+            assert iterator.fetched == 3
+            assert iterator.done
             return generated
 
+        is_last_batch = [False, False, prefetch_batches > 0]
+        fetched = list(range(prefetch_batches + 1, 4))
+        fetched += [3] * (3 - len(fetched))
+        if use_combined_loader:
+            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
+        else:
+            batches = [1, 2, 3]
+        expected = list(zip(fetched, batches, is_last_batch))
+        assert len(expected) == 3
+
         assert generate() == expected
         # validate reset works properly.
         assert generate() == expected
```
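The `expected` bookkeeping in the test is compact; evaluated by hand for every swept `prefetch_batches` value, it produces the following combinations (a standalone sketch, with the outputs computed from the test's own expressions):

```python
# standalone evaluation of the test's bookkeeping, for all swept values
for prefetch_batches in range(5):
    is_last_batch = [False, False, prefetch_batches > 0]
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    print(prefetch_batches, fetched, is_last_batch)
# 0 [1, 2, 3] [False, False, False]  <- no pre-fetching: fetched tracks the loop
# 1 [2, 3, 3] [False, False, True]
# 2 [3, 3, 3] [False, False, True]
# 3 [3, 3, 3] [False, False, True]
# 4 [3, 3, 3] [False, False, True]
```

With any pre-fetching, the lookahead lets the fetcher flag the last of the dataset's 3 batches; with `prefetch_batches=0` it cannot.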
The empty-dataset case in the same test is only touched to rename `dataloader` to `loader` for consistency:

```diff
@@ -71,9 +75,9 @@ def test_prefetch_iterator(use_combined_loader):
         def __iter__(self):
             return iter([])
 
-    dataloader = DataLoader(EmptyIterDataset())
+    loader = DataLoader(EmptyIterDataset())
     iterator = DataFetcher()
-    iterator.setup(dataloader)
+    iterator.setup(loader)
     assert not list(iterator)
```