Do not prefetch when possible (#12101)

Carlos Mocholí, 2022-02-28 19:31:18 +01:00, committed by GitHub
parent ed7ccca5df
commit 6309a59c3c
12 changed files with 134 additions and 97 deletions

View File

@ -393,6 +393,9 @@ option when using sequential data.
to ``limit_{mode}_batches``; if it is set to 1.0, it will run for the whole dataset, otherwise it will throw an exception.
Here ``mode`` can be train/val/test/predict.
When iterable datasets are used, Lightning will pre-fetch 1 batch (in addition to the current batch) so it can detect
when the training will stop and run validation if necessary.
.. testcode::
# IterableDataset
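
Separate from the docs snippet above, a minimal standalone sketch (not part of the commit; `StreamDataset` is a made-up example) of why an unsized iterable dataset forces one batch of prefetching: the dataloader has no usable ``len()``, so the only way to know that a batch was the last one is to try fetching the next one.

import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamDataset(IterableDataset):
    """A hypothetical unsized stream: it defines no ``__len__``."""

    def __iter__(self):
        for i in range(3):
            yield torch.tensor([float(i)])


loader = DataLoader(StreamDataset(), batch_size=None)
try:
    len(loader)  # raises TypeError because the underlying dataset is unsized
except TypeError:
    # the end of the epoch can only be detected by looking one batch ahead
    print("unsized dataloader: prefetch 1 batch to detect the last one")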

View File

@ -81,6 +81,13 @@ class EvaluationLoop(DataLoaderLoop):
raise RuntimeError("Dataloaders should be available.")
return dataloaders
@property
def prefetch_batches(self) -> int:
batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches
is_unsized = batches[self.current_dataloader_idx] == float("inf")
inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
return 1 if is_unsized or inter_batch_parallelism else 0
def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override]
"""Connect the evaluation epoch loop with this loop."""
self.epoch_loop = epoch_loop
@ -121,7 +128,7 @@ class EvaluationLoop(DataLoaderLoop):
void(*args, **kwargs)
data_fetcher_cls = _select_data_fetcher_type(self.trainer)
self._data_fetcher = data_fetcher_cls()
self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
# hook
self._on_evaluation_model_eval()
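
A standalone sketch of the decision this new property encodes (the fit loop below adds the same property for the training dataloaders; the helper name here is made up): prefetch a single batch only when the number of batches is unknown (``float("inf")``) or when the ``PL_INTER_BATCH_PARALLELISM`` toggle is enabled, otherwise prefetch nothing.

import os


def decide_prefetch_batches(num_batches: float) -> int:
    # mirrors the property above: 1 only for unsized loaders or inter-batch parallelism
    is_unsized = num_batches == float("inf")
    inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
    return 1 if is_unsized or inter_batch_parallelism else 0


# assuming PL_INTER_BATCH_PARALLELISM is not set in the environment
assert decide_prefetch_batches(10) == 0            # sized loader: no look-ahead needed
assert decide_prefetch_batches(float("inf")) == 1  # unsized loader: look one batch ahead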

View File

@ -85,6 +85,8 @@ class EvaluationEpochLoop(Loop):
self._reload_dataloader_state_dict(data_fetcher)
# creates the iterator inside the fetcher but returns `self`
self._data_fetcher = cast(AbstractDataFetcher, iter(data_fetcher))
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
def advance( # type: ignore[override]
self,

View File

@ -142,7 +142,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override]
self._reload_dataloader_state_dict(data_fetcher)
iter(data_fetcher) # creates the iterator inside the fetcher
_ = iter(data_fetcher) # creates the iterator inside the fetcher
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override]
"""Runs a single training batch.

View File

@ -149,6 +149,12 @@ class FitLoop(Loop[None]):
restarting &= finished_before_on_train_end
Loop.restarting.fset(self, restarting) # call the parent setter
@property
def prefetch_batches(self) -> int:
is_unsized = self.trainer.num_training_batches == float("inf")
inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
return 1 if is_unsized or inter_batch_parallelism else 0
@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
@ -213,8 +219,9 @@ class FitLoop(Loop[None]):
"""Calls the ``on_train_start`` hook."""
# reset train dataloader and val dataloader
self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
data_fetcher_cls = _select_data_fetcher(self.trainer)
self._data_fetcher = data_fetcher_cls()
self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)

View File

@ -89,17 +89,13 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader.
Raises:
ValueError:
If the length of Dataloader is 0, as it requires at least one batch
"""
infinite dataloader."""
try:
# try getting the length
if len(dataloader) == 0:
raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch")
rank_zero_warn(
f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
)
has_len = True
except TypeError:
has_len = False
@ -122,30 +118,27 @@ def has_len_all_ranks(
model: Union["pl.LightningModule", "pl.LightningDataModule"],
) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader.
Raises:
ValueError:
If the length of Dataloader is 0, as it requires at least one batch
"""
infinite dataloader."""
try:
total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
local_length = len(dataloader)
total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum")
if total_length == 0:
raise MisconfigurationException(
"Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch."
rank_zero_warn(
f"Total length of `{dataloader.__class__.__name__}` across ranks is zero."
" Please make sure this was your intention."
)
if total_length > 0 and local_length == 0:
if model.allow_zero_length_dataloader_with_multiple_devices:
rank_zero_warn(
"Total length of `Dataloader` across ranks is zero, but local rank has zero length."
" Please be cautious of uneven batch length."
f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero"
" length. Please be cautious of uneven batch length."
)
has_len = False
else:
raise MisconfigurationException(
"`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch."
f"`{dataloader.__class__.__name__}` within local rank has zero length."
" Please make sure that it returns at least 1 batch."
)
else:
has_len = True
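
The net effect of this change: a zero-length dataloader no longer raises, it only warns. A small sketch (assuming the ``pytorch_lightning.utilities.data`` import path shown above and that ``rank_zero_warn`` surfaces as a ``UserWarning``, as the updated tests below expect):

import pytest
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.utilities.data import has_len


class EmptyDataset(Dataset):
    def __len__(self) -> int:
        return 0

    def __getitem__(self, idx):
        raise IndexError(idx)


# previously a ValueError; now only a warning, and the dataloader is still reported as sized
with pytest.warns(UserWarning, match="returned 0 length"):
    assert has_len(DataLoader(EmptyDataset()))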

View File

@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from copy import deepcopy
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Sized, Tuple
import torch
from torch.utils.data.dataloader import DataLoader
@ -30,6 +30,7 @@ from pytorch_lightning.utilities.auto_restart import (
MergedIteratorState,
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
@ -79,6 +80,8 @@ class AbstractDataFetcher(ABC):
def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
self._add_capture_metadata_collate(dataloader)
self._dataloader = dataloader
_patch_dataloader_get_iterators()
self._attach_data_fetcher()
@property
def dataloader(self) -> Iterable:
@ -172,8 +175,6 @@ class AbstractDataFetcher(ABC):
def __iter__(self) -> "AbstractDataFetcher":
self.reset()
self._attach_data_fetcher()
_patch_dataloader_get_iterators()
self.dataloader_iter = iter(self.dataloader)
self._apply_patch()
self.prefetching()
@ -205,7 +206,7 @@ class DataFetcher(AbstractDataFetcher):
Args:
prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
whether a batch is the last one (available with :attr:`self.done`).
whether a batch is the last one (available with :attr:`self.done`) under any training setup.
store_on_device: Whether to store the pre-fetched batches on device.
"""
@ -214,11 +215,13 @@ class DataFetcher(AbstractDataFetcher):
self.store_on_device = store_on_device
self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
self.batches: List[Any] = []
self._has_len = False
def setup( # type: ignore[override]
self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
) -> None:
super().setup(dataloader)
self._has_len = has_len(dataloader)
if batch_to_device is not None:
self.batch_to_device = batch_to_device
@ -233,6 +236,9 @@ class DataFetcher(AbstractDataFetcher):
try:
self._fetch_next_batch(iterator)
except StopIteration:
# this would only happen when prefetch_batches > the number of batches available and makes
# `fetching_function` jump directly to the empty iterator case without trying to fetch again
self.done = True
break
def fetching_function(self) -> Any:
@ -266,6 +272,11 @@ class DataFetcher(AbstractDataFetcher):
start_output = self.on_fetch_start()
batch = next(iterator)
self.fetched += 1
if not self.prefetch_batches and self._has_len:
# when we don't prefetch but the dataloader is sized, we use the length for `done`
dataloader = self.dataloader
assert isinstance(dataloader, Sized) # `_has_len` is True
self.done = self.fetched >= len(dataloader)
self.on_fetch_end(batch, start_output)
def move_to_device(self, batch: Any) -> Any:
@ -360,7 +371,8 @@ class DataLoaderIterDataFetcher(AbstractDataFetcher):
...
"""
def __init__(self) -> None:
def __init__(self, prefetch_batches: int = 0) -> None:
# prefetch batches is not used for this class
super().__init__()
self.store_on_device = False
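
Putting the ``DataFetcher`` changes together, a standalone sketch (assuming the ``pytorch_lightning.utilities.fetching`` import path): with a sized dataloader and ``prefetch_batches=0``, ``done`` is now derived from ``len(dataloader)``, so the last batch is still flagged without any look-ahead fetch.

from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.utilities.fetching import DataFetcher


class ThreeItems(Dataset):
    def __len__(self) -> int:
        return 3

    def __getitem__(self, idx: int) -> int:
        return idx + 1


fetcher = DataFetcher(prefetch_batches=0)
fetcher.setup(DataLoader(ThreeItems()))
# (fetched, done) observed at each yielded batch: the third one is already known to be the last
assert [(fetcher.fetched, fetcher.done) for _ in fetcher] == [(1, False), (2, False), (3, True)]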

View File

@ -648,16 +648,12 @@ def test_loop_state_on_complete_run(n_optimizers, tmpdir):
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
"current": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
},
@ -956,8 +952,6 @@ def test_fit_can_fail_during_validation(train_datasets, val_datasets, val_check_
# totals are increased by 1 (the failed batch which never completed)
expected = state_dict.copy()
# TODO: `is_last_batch` is not correct on reload, the next line should not be necessary
expected["epoch_loop.batch_progress"]["is_last_batch"] = val_check_interval == 1.0
assert state_dict_after_restart["epoch_loop.batch_progress"] == expected["epoch_loop.batch_progress"]
val_dl_progress = "epoch_loop.val_loop.dataloader_progress"

View File

@ -516,20 +516,16 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
assert len(trainer.test_dataloaders) == 1
def test_error_on_zero_len_dataloader(tmpdir):
"""Test that error is raised if a zero-length dataloader is defined."""
class CustomBoringModel(BoringModel):
def train_dataloader(self):
return DataLoader(RandomDataset(32, 0))
model = CustomBoringModel()
def test_warning_on_zero_len_dataloader(tmpdir):
"""Test that a warning is raised if a zero-length dataloader is defined."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=1,
)
with pytest.raises(ValueError, match="returned 0 length. .* at least 1 batch"):
trainer.fit(model)
dataloader = DataLoader(RandomDataset(32, 0))
with pytest.warns(UserWarning, match="returned 0 length"):
trainer.fit(model, dataloader)
@RunIf(skip_windows=True)

View File

@ -1452,7 +1452,7 @@ class RandomFaultTolerantDataset(RandomGetItemDataset):
class RandomFaultTolerantSampler(RandomSampler):
def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
def __init__(self, *args, seed: int = 0, **kwargs):
generator = torch.Generator().manual_seed(seed)
super().__init__(*args, generator=generator, **kwargs)
self.counter = 0
@ -1558,7 +1558,7 @@ def test_fault_tolerant_manual_mode(val_check_interval, train_dataset_cls, val_d
seed_everything(42)
model = TestModel(should_fail=True)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
with suppress(CustomException):
with pytest.raises(CustomException):
trainer.fit(model)
trainer.train_dataloader = None
failed_batches = model.batches
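
Aside on the test change above (not part of the library code; the helper name is made up): ``contextlib.suppress`` passes silently even when nothing raises, whereas ``pytest.raises`` makes the test fail if the expected ``CustomException`` never happens. A minimal contrast:

from contextlib import suppress

import pytest


class CustomException(Exception):
    ...


def run_that_fails() -> None:
    raise CustomException("simulated mid-training failure")


with suppress(CustomException):
    run_that_fails()       # passes, but would also "pass" if nothing raised

with pytest.raises(CustomException):
    run_that_fails()       # fails loudly if the exception does not occur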

View File

@ -93,7 +93,7 @@ def test_has_iterable_dataset():
def test_has_len():
assert has_len(DataLoader(RandomDataset(1, 1)))
with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
assert has_len(DataLoader(RandomDataset(0, 0)))
assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
@ -112,8 +112,8 @@ def test_has_len_all_rank():
trainer = Trainer(fast_dev_run=True)
model = BoringModel()
with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)
with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."):
assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)
assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy, model)

View File

@ -18,7 +18,6 @@ from unittest import mock
import pytest
import torch
from torch import tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset
from pytorch_lightning import Callback, LightningDataModule, Trainer
@ -30,57 +29,74 @@ from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
class IterDataset(IterableDataset):
def __iter__(self):
yield 1
yield 2
yield 3
class SizedDataset(Dataset):
def __len__(self):
return 3
def __getitem__(self, idx):
return idx + 1
@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_prefetch_iterator(use_combined_loader):
"""Test the DataFetcher with PyTorch IterableDataset."""
@pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset])
@pytest.mark.parametrize("prefetch_batches", list(range(5)))
def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
fetcher = DataFetcher(prefetch_batches=prefetch_batches)
assert fetcher.prefetch_batches == prefetch_batches
class IterDataset(IterableDataset):
def __iter__(self):
yield 1
yield 2
yield 3
if use_combined_loader:
loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())])
else:
loader = DataLoader(dataset_cls())
fetcher.setup(loader)
for prefetch_batches in range(5):
iterator = DataFetcher(prefetch_batches=prefetch_batches)
assert iterator.prefetch_batches == prefetch_batches
def generate():
generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher]
assert fetcher.fetched == 3
assert fetcher.done
return generated
if use_combined_loader:
loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
else:
loader = DataLoader(IterDataset())
iterator.setup(loader)
# we can only know the last batch with sized iterables or when we prefetch
is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset]
fetched = list(range(prefetch_batches + 1, 4))
fetched += [3] * (3 - len(fetched))
batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
expected = list(zip(fetched, batches, is_last_batch))
assert len(expected) == 3
def generate():
generated = [
(iterator.fetched, data, iterator.done) for i, data in enumerate(iterator, prefetch_batches + 1)
]
assert iterator.fetched == 3
assert iterator.done
return generated
assert generate() == expected
# validate reset works properly.
assert generate() == expected
assert fetcher.fetched == 3
is_last_batch = [False, False, prefetch_batches > 0]
fetched = list(range(prefetch_batches + 1, 4))
fetched += [3] * (3 - len(fetched))
if use_combined_loader:
batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
else:
batches = [1, 2, 3]
expected = list(zip(fetched, batches, is_last_batch))
assert len(expected) == 3
assert generate() == expected
# validate reset works properly.
assert generate() == expected
assert iterator.fetched == 3
class EmptyIterDataset(IterableDataset):
def __iter__(self):
return iter([])
class EmptyIterDataset(IterableDataset):
def __iter__(self):
return iter([])
loader = DataLoader(EmptyIterDataset())
iterator = DataFetcher()
iterator.setup(loader)
assert not list(iterator)
class EmptySizedDataset(Dataset):
def __len__(self):
return 0
@pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset])
@pytest.mark.parametrize("prefetch_batches", list(range(2)))
def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
loader = DataLoader(dataset_cls())
fetcher = DataFetcher(prefetch_batches=prefetch_batches)
fetcher.setup(loader)
assert not fetcher.done
assert not list(fetcher)
assert fetcher.done
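
For reference, the ``fetched`` expectations built in the test above work out as follows for three batches (a sketch restating the test's arithmetic, not additional test code):

for prefetch_batches in range(5):
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    print(prefetch_batches, fetched)

# prefetch_batches=0  -> [1, 2, 3]: each batch is counted only when it is yielded
# prefetch_batches=1  -> [2, 3, 3]: one batch is always read ahead
# prefetch_batches>=2 -> [3, 3, 3]: everything is consumed during the initial prefetch
# `is_last_batch` is only True on the final batch; for an unsized IterDataset it requires
# prefetching, while for SizedDataset it is known from `len()` even with no prefetch.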
def test_misconfiguration_error():
@ -188,7 +204,7 @@ def test_trainer_num_prefetch_batches(tmpdir):
def on_train_epoch_end(self, trainer, lightning_module):
fetcher = trainer.fit_loop._data_fetcher
assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher)
assert fetcher.prefetch_batches == 1
assert fetcher.prefetch_batches == int(self._check_inter_batch)
trainer_kwargs = dict(
default_root_dir=tmpdir,
@ -269,14 +285,19 @@ def test_fetching_dataloader_iter_opt(automatic_optimization, tmpdir):
@RunIf(min_torch="1.8.0")
def test_fetching_dataloader_iter_running_stages(fn, tmpdir):
class TestModel(BoringModel):
def validation_step(self, dataloader_iter, batch_idx):
assert isinstance(self.trainer.validate_loop._data_fetcher, DataLoaderIterDataFetcher)
def fetch(self, data_fetcher, dataloader_iter, batch_idx):
assert isinstance(data_fetcher, DataLoaderIterDataFetcher)
assert data_fetcher.fetched == batch_idx
batch = next(dataloader_iter)
assert data_fetcher.fetched == batch_idx + 1
return batch
def validation_step(self, dataloader_iter, batch_idx):
batch = self.fetch(self.trainer.validate_loop._data_fetcher, dataloader_iter, batch_idx)
return super().validation_step(batch, batch_idx)
def test_step(self, dataloader_iter, batch_idx):
assert isinstance(self.trainer.test_loop._data_fetcher, DataLoaderIterDataFetcher)
batch = next(dataloader_iter)
batch = self.fetch(self.trainer.test_loop._data_fetcher, dataloader_iter, batch_idx)
return super().test_step(batch, batch_idx)
model = TestModel()