3/n integrate new LightningDataFetcher into loop (#8953)

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Adrian Wälchli 2021-08-17 23:42:22 +02:00 committed by GitHub
parent 77bc5d4004
commit 522df2b89b
18 changed files with 85 additions and 119 deletions

View File

@@ -41,10 +41,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault-tolerant training:
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
* Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
* Added `DataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
* Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))

View File

@@ -20,6 +20,7 @@ from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -98,7 +99,9 @@ class EvaluationLoop(DataLoaderLoop):
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader_iter = enumerate(dataloader)
data_fetcher = DataFetcher()
data_fetcher.setup(dataloader)
dataloader_iter = enumerate(data_fetcher)
dl_max_batches = self._max_batches[self.current_dataloader_idx]
dl_outputs = self.epoch_loop.run(

View File

@@ -86,7 +86,7 @@ class EvaluationEpochLoop(Loop):
"""
void(dl_max_batches, num_dataloaders)
batch_idx, batch = next(dataloader_iter)
batch_idx, (batch, _) = next(dataloader_iter)
if batch is None:
raise StopIteration
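
A quick illustration of how the `EvaluationLoop` and `EvaluationEpochLoop` hunks above fit together: the `DataFetcher` yields `(batch, is_last)` pairs, so `enumerate(data_fetcher)` produces `(batch_idx, (batch, is_last))` and the epoch loop unpacks the extra flag. A minimal sketch of that pattern outside the Trainer, assuming a plain map-style `DataLoader`; the dataset below is illustrative, not from this diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.fetching import DataFetcher

# Illustrative data, not taken from the diff.
dataset = TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=2)

data_fetcher = DataFetcher()
data_fetcher.setup(dataloader)

for batch_idx, (batch, is_last) in enumerate(data_fetcher):
    # `is_last` flags the final batch, so a loop can stop cleanly
    # without peeking ahead at the dataloader itself.
    print(batch_idx, batch, is_last)
```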

View File

@@ -17,7 +17,7 @@ import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TextIO, Union
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -78,7 +78,7 @@ class BaseProfiler(AbstractProfiler):
self._stage: Optional[str] = None
@contextmanager
def profile(self, action_name: str) -> None:
def profile(self, action_name: str) -> Generator:
"""
Yields a context manager to encapsulate the scope of a profiled action.
@@ -96,7 +96,7 @@ class BaseProfiler(AbstractProfiler):
finally:
self.stop(action_name)
def profile_iterable(self, iterable, action_name: str) -> None:
def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
iterator = iter(iterable)
while True:
try:

View File

@@ -24,7 +24,7 @@ import pytorch_lightning as pl
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
if _OMEGACONF_AVAILABLE:
@@ -348,7 +348,7 @@ class CheckpointConnector:
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
}
if _fault_tolerant_enabled():
if _fault_tolerant_training():
checkpoint["loops"] = self._get_loops_state_dict()
if not weights_only:
@@ -451,7 +451,7 @@ class CheckpointConnector:
def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metrics = (
[m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
if _fault_tolerant_enabled()
if _fault_tolerant_training()
else []
)

View File

@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
from typing import Callable, Iterable, Optional, Union
import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -26,6 +26,7 @@ class DataConnector:
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
self.data_fetcher: Optional[DataFetcher] = None
def on_trainer_init(
self,
@@ -59,10 +60,11 @@ class DataConnector:
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False
def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
)
def get_profiled_train_dataloader(self, train_dataloader) -> Iterable:
self.data_fetcher = DataFetcher()
self.data_fetcher.setup(train_dataloader)
prefetcher_iter = iter(self.data_fetcher)
profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch")
return profiled_dl
def prepare_data(self) -> None:

View File

@@ -34,12 +34,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
_capture_metadata_collate,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
)
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function
@@ -168,7 +169,7 @@ class TrainerDataLoadingMixin(ABC):
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
if _fault_tolerant_enabled():
if _fault_tolerant_training():
fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
fast_forward_sampler.setup(dataloader_batch_size=1)
@@ -180,7 +181,7 @@ class TrainerDataLoadingMixin(ABC):
"drop_last": False,
}
if _fault_tolerant_enabled():
if _fault_tolerant_training():
fast_forward_sampler = sampler = FastForwardSampler(sampler)
fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)
@@ -246,14 +247,20 @@ class TrainerDataLoadingMixin(ABC):
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
if isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None
if isinstance(dl_kwargs["dataset"], IterableDataset):
del dl_kwargs["sampler"]
del dl_kwargs["batch_sampler"]
if _fault_tolerant_training():
if isinstance(dl_kwargs["dataset"], IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
elif len(dl_kwargs["dataset"]):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
else:
raise MisconfigurationException(
"This shouldn't happen, please open an issue on Lightning Github repository."
)
return dl_kwargs
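
The rewritten branch above only wraps datasets when fault-tolerant training is enabled: iterable-style datasets become `CaptureIterableDataset` and map-style datasets become `CaptureMapDataset`, both of which record the state needed to fast-forward on restart. A simplified standalone sketch of that decision, where `dataset` stands in for `dl_kwargs["dataset"]`:

```python
from torch.utils.data import Dataset, IterableDataset

from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


def _wrap_for_fault_tolerance(dataset: Dataset) -> Dataset:
    # Simplified sketch of the branch above; `dataset` plays the role of dl_kwargs["dataset"].
    if not _fault_tolerant_training():
        return dataset
    if isinstance(dataset, IterableDataset):
        # wrap to record sampler states of iterable-style datasets
        return CaptureIterableDataset(dataset=dataset)
    if len(dataset):
        # map-style datasets get a state-tracking wrapper instead
        return CaptureMapDataset(dataset=dataset)
    raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
```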
@@ -308,7 +315,7 @@
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_enabled():
if _fault_tolerant_training():
apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)
# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches

View File

@@ -15,7 +15,7 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.utils.data import Dataset
@@ -30,7 +30,7 @@ from pytorch_lightning.utilities.auto_restart import (
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
class TensorRunningAccum:
@@ -375,7 +375,7 @@ class CombinedLoader:
num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
may have already prefetched more batches by the time a state dict is requested.
"""
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return DataLoaderDict()
state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
@@ -541,7 +541,7 @@ class CombinedLoaderIterator:
def next_fn(iterator: Iterator):
batch = next(iterator)
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return batch
# when fault tolerant is enabled, the iterator will return
# `FastForwardSampler` state_dict metadata
@@ -592,25 +592,3 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
new_data.append(x)
return compute_func(new_data)
def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
"""
Returns an iterator that pre-fetches and caches the next item.
The values are passed through from the given iterable with an added boolean indicating if this is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
"""
it = iter(iterable)
try:
# the iterator may be empty from the beginning
last = next(it)
except StopIteration:
return
for val in it:
# yield last and has next
yield last, False
last = val
# yield last, no longer has next
yield last, True

View File

@@ -77,7 +77,7 @@ from pytorch_lightning.utilities import (
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
@@ -1344,7 +1344,7 @@ class Trainer(
)
def _on_exception(self):
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
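
With fault-tolerant training enabled, an unhandled exception therefore leaves a `.pl_auto_save.ckpt` file directly under `default_root_dir` (deliberately not under `log_dir`). A hedged sketch of inspecting that file after a crash, mirroring what the loop-state test further down does; `my_run_dir` is an illustrative path, not from this diff:

```python
import os

import torch

# Illustrative root dir; the Trainer writes the auto-save file into its `default_root_dir`.
root_dir = "my_run_dir"
ckpt_path = os.path.join(root_dir, ".pl_auto_save.ckpt")

if os.path.exists(ckpt_path):
    checkpoint = torch.load(ckpt_path)
    # With fault tolerance enabled, the checkpoint also carries loop state under the
    # "loops" key (see the CheckpointConnector hunk above).
    print(sorted(checkpoint.keys()))
```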

View File

@@ -21,6 +21,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tup
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -515,7 +516,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:
def patch_dataloader_iterator(
dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0
dataloader: DataLoader,
iterator: Iterator,
data_fetcher: "pl.utilities.fetching.DataFetcher",
num_batches_fetched: int = 0,
) -> None:
assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))
@@ -554,7 +558,7 @@ def patch_dataloader_iterator(
num_batches_fetched=num_batches_fetched,
)
]
prefetcher._store_dataloader_iter_state(it, state)
data_fetcher._store_dataloader_iter_state(it, state)
return batch
return wrapper

View File

@@ -29,10 +29,10 @@ from pytorch_lightning.utilities.auto_restart import (
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
class AbstractFetcher(ABC):
class AbstractDataFetcher(ABC):
"""
This class is used to control batch fetching flow.
@@ -61,13 +61,22 @@ class AbstractFetcher(ABC):
self.reset()
def setup(self, dataloader: DataLoader, **kwargs) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
raise MisconfigurationException(
"The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
)
self._add_capture_metadata_collate(dataloader)
self.dataloader = dataloader
if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial):
_add_capture_metadata_collate(dataloader)
@staticmethod
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
return
if isinstance(dataloader, CombinedLoader):
dataloader = dataloader.loaders
def add_capture_metadata_collate(dataloader: DataLoader):
if not isinstance(dataloader.collate_fn, partial):
_add_capture_metadata_collate(dataloader)
apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)
def add_batch(self, batch) -> None:
self.batches.append(batch)
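
The new `_add_capture_metadata_collate` static method above relies on `apply_to_collection` so the collate patch reaches every `DataLoader` inside a `CombinedLoader`'s (possibly nested) `loaders` collection. An illustrative sketch of that traversal with a hypothetical callback, not code from this diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection


def visit(loader: DataLoader) -> None:
    # Hypothetical stand-in for `add_capture_metadata_collate`: called once per DataLoader found.
    print(type(loader.dataset).__name__, loader.batch_size)


dataset = TensorDataset(torch.arange(4, dtype=torch.float32))
combined = CombinedLoader({"a": DataLoader(dataset, batch_size=1), "b": DataLoader(dataset, batch_size=2)})

# Same traversal the fetcher uses: unwrap the CombinedLoader, then apply the
# function to every DataLoader in the collection.
apply_to_collection(combined.loaders, DataLoader, visit)
```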
@@ -82,7 +91,7 @@ class AbstractFetcher(ABC):
# cycle_iterator = iterator
iterator = iterator._loader_iter
if isinstance(loader, DataLoader) and _fault_tolerant_enabled():
if isinstance(loader, DataLoader) and _fault_tolerant_training():
loader._lightning_fetcher = self
patch_dataloader_iterator(loader, iterator, self)
@@ -161,7 +170,7 @@ class AbstractFetcher(ABC):
self.done: bool = False
class LightningDataFetcher(AbstractFetcher):
class DataFetcher(AbstractDataFetcher):
"""
This class is used to control batch fetching flow.

View File

@@ -103,9 +103,6 @@ else:
_IPU_AVAILABLE = False
def _fault_tolerant_enabled() -> bool:
"""
EXPERIMENTAL
the `reset` function from `_MultiProcessingDataLoaderIter` was introduced in PyTorch 1.7 but we need to mock it.
"""
# experimental feature within PyTorch Lightning.
def _fault_tolerant_training() -> bool:
return _TORCH_GREATER_EQUAL_1_7 and int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0))
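
`_fault_tolerant_training()` (previously `_fault_tolerant_enabled()`) gates every fault-tolerance code path above on PyTorch >= 1.7 plus the `PL_FAULT_TOLERANT_TRAINING` environment variable. A minimal sketch of opting in for a block of code, mirroring how the tests patch the environment:

```python
import os
from unittest import mock

from pytorch_lightning.utilities.imports import _fault_tolerant_training

# Off unless the env var is set to a truthy integer.
print(_fault_tolerant_training())

# Opt in temporarily, the same way the tests do.
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
    print(_fault_tolerant_training())  # truthy on PyTorch >= 1.7
```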

View File

@@ -1111,6 +1111,7 @@ def test_current_score_when_nan(tmpdir, mode: str):
model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo", mode=mode)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=1,
limit_val_batches=1,
callbacks=[model_checkpoint],
@@ -1133,6 +1134,7 @@ def test_hparams_type(tmpdir, hparams_type):
model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo")
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,

View File

@@ -28,7 +28,7 @@ import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled, _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@@ -384,7 +384,7 @@ def result_collection_reload(**kwargs):
and final accumulation with Fault Tolerant Training is correct.
"""
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
pytest.skip("Fault tolerant not available")
num_processes = kwargs.get("gpus", 1)

View File

@@ -379,6 +379,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
pass
ckpt_path = str(tmpdir / ".pl_auto_save.ckpt")
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress

View File

@@ -29,7 +29,6 @@ from pytorch_lightning.trainer.supporters import (
CombinedLoader,
CombinedLoaderIterator,
CycleIterator,
prefetch_iterator,
TensorRunningAccum,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -80,28 +79,6 @@ def test_none_length_cycle_iterator():
assert item == 0
def test_prefetch_iterator():
"""Test the prefetch_iterator with PyTorch IterableDataset."""
class IterDataset(IterableDataset):
def __iter__(self):
yield 1
yield 2
yield 3
dataset = IterDataset()
iterator = prefetch_iterator(dataset)
assert list(iterator) == [(1, False), (2, False), (3, True)]
class EmptyIterDataset(IterableDataset):
def __iter__(self):
return iter([])
dataset = EmptyIterDataset()
iterator = prefetch_iterator(dataset)
assert list(iterator) == []
@pytest.mark.parametrize(
["dataset_1", "dataset_2"],
[

View File

@@ -39,7 +39,7 @@ from pytorch_lightning.utilities.auto_restart import (
)
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -100,13 +100,6 @@ def _generate_state(base_seed, worker_id):
return state
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize("env_setting,expected", [("0", False), ("1", True)])
def test_fault_tolerant_enabled(env_setting, expected):
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": env_setting}):
assert _fault_tolerant_enabled() == expected
def test_fast_forward_getattr():
dataset = range(15)
sampler = SequentialSampler(dataset)
@@ -647,10 +640,9 @@ def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset():
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(max_torch="1.6")
@RunIf(max_torch="1.7")
def test_fault_tolerant_not_supported():
with pytest.raises(MisconfigurationException, match="Restart is only supported with torch >= 1.7.0."):
_fault_tolerant_enabled()
assert not _fault_tolerant_training()
def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", wrap: bool = True):

View File

@@ -17,12 +17,12 @@ from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import LightningDataFetcher
from pytorch_lightning.utilities.fetching import DataFetcher
@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_prefetch_iterator(use_combined_loader):
"""Test the LightningDataFetcher with PyTorch IterableDataset."""
"""Test the DataFetcher with PyTorch IterableDataset."""
class IterDataset(IterableDataset):
def __iter__(self):
@@ -41,7 +41,7 @@ def test_prefetch_iterator(use_combined_loader):
else:
loader = DataLoader(IterDataset())
expected = [(1, False), (2, False), (3, True)]
iterator = LightningDataFetcher(prefetch_batches=prefetch_batches)
iterator = DataFetcher(prefetch_batches=prefetch_batches)
prefetch_batches += 1
assert iterator.prefetch_batches == prefetch_batches
iterator.setup(loader)
@@ -66,21 +66,14 @@ def test_prefetch_iterator(use_combined_loader):
return iter([])
dataloader = DataLoader(EmptyIterDataset())
iterator = LightningDataFetcher()
iterator = DataFetcher()
iterator.setup(dataloader)
assert list(iterator) == []
def test_misconfiguration_error():
fetcher = LightningDataFetcher()
with pytest.raises(
MisconfigurationException,
match="The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``.",
):
fetcher.setup(range(10))
fetcher = LightningDataFetcher()
fetcher = DataFetcher()
with pytest.raises(
MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
):