Fault Tolerant: Add support for fault tolerant dataloader validator (#10465)

thomas chaton 2021-11-26 19:33:47 +00:00 committed by GitHub
parent 88930725dd
commit e94aff1c5b
4 changed files with 208 additions and 6 deletions


@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
- Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/issues/10465))
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))

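For context on the fault-tolerance changelog entry above, here is a minimal editorial sketch (not part of this commit) of how the new validator is gated by the `PL_FAULT_TOLERANT_TRAINING` environment variable; the dataloader and stage below are illustrative:

import os

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

train_loader = DataLoader(range(10))  # uses a SequentialSampler by default

# with the flag unset, the check is a no-op and returns immediately
_validate_fault_tolerant_automatic(train_loader, RunningStage.TRAINING)

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # "automatic" fault-tolerant mode
# now the dataloader/sampler combination is actually inspected (and passes here)
_validate_fault_tolerant_automatic(train_loader, RunningStage.TRAINING)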

@@ -28,7 +28,7 @@ from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic
from pytorch_lightning.utilities.data import (
    _auto_add_worker_init_fn,
    _replace_dataloader_init_method,
@@ -441,6 +441,7 @@ class TrainerDataLoadingMixin(ABC):
        if isinstance(dataloader, tuple):
            dataloader = list(dataloader)
        self.training_type_plugin.barrier("get_dataloaders")
        _validate_fault_tolerant_automatic(dataloader, stage)
        return dataloader

    @staticmethod

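As a hedged illustration of what the `_validate_fault_tolerant_automatic` call added above enforces (assuming automatic fault tolerance is already enabled, e.g. `PL_FAULT_TOLERANT_TRAINING=1`), training is restricted to a single dataloader while evaluation stages may keep several; the loaders below are illustrative:

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

loaders = [DataLoader(range(8)), DataLoader(range(8))]

try:
    # training is restricted to a single dataloader in automatic mode
    _validate_fault_tolerant_automatic(loaders, RunningStage.TRAINING)
except ValueError as err:
    print(err)  # Fault-tolerance supports only a single dataloader.

# several dataloaders remain acceptable for validation/testing
_validate_fault_tolerant_automatic(loaders, RunningStage.VALIDATING)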

@@ -16,11 +16,19 @@ from dataclasses import dataclass, field
from functools import partial, wraps
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data import (
    BatchSampler,
    Dataset,
    DistributedSampler,
    get_worker_info,
    RandomSampler,
    Sampler,
    SequentialSampler,
)
from torch.utils.data.dataloader import (
    _BaseDataLoaderIter,
    _MultiProcessingDataLoaderIter,
@@ -370,7 +378,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
    # get current num workers
    num_workers = getattr(iter_dataloader, "_num_workers", 0)
    # as `state_dict` are workers dependent, Lightning doesn't support changing
    # the `num_workers` for fault tolerant training
    # the `num_workers` for Fault-tolerance
    if state_dict["num_workers"] != num_workers:
        raise MisconfigurationException(
            f"The provided `num_workers` {num_workers} doesn't match the one used "
@@ -734,13 +742,93 @@ def _patch_dataloader_get_iterators() -> None:
def _teardown_dataloader_get_iterators() -> None:
    """This function is used to restore the DataLoader `get_iterator` with its original one."""
    # cleanup the get_iterator replacement in case of Fault Tolerant Training.
    # cleanup the get_iterator replacement in case of Fault-tolerance.
    get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
    if get_iterator:
        DataLoader._get_iterator = get_iterator
        del DataLoader._ori_get_iterator


def _validate_iterable_dataset(dataloader: DataLoader) -> None:
    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

    dataset = dataloader.dataset

    if getattr(dataset, "__next__", None) is None:
        raise AttributeError(
            "Fault-tolerance doesn't support an `IterableDataset` without `__next__` "
            "method implemented. Hint: We recommend you to move your logic from `__iter__`"
            " inside and rely on a sampler to perform the sample sampling."
        )

    samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)}

    if not samplers:
        raise TypeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.")

    sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS]

    if not sampler:
        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

    if len(sampler) > 1:
        raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.")

    if type(sampler[0]) is DistributedSampler and sampler[0].shuffle:
        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
    elif type(sampler[0]) is not SequentialSampler:
        raise TypeError("Only `SequentialSampler` is supported.")


def _validate_map_dataset(dataloader: DataLoader) -> None:
    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

    sampler = getattr(dataloader, "sampler", None)
    if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

    batch_sampler = getattr(dataloader, "batch_sampler", None)
    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
        raise TypeError("Fault-tolerance supports only a `BatchSampler`.")

    if type(sampler) is DistributedSampler and sampler.shuffle:
        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
    elif type(sampler) is RandomSampler:
        raise TypeError("Only `SequentialSampler` is supported.")
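The map-style rules are easiest to see by calling the helper directly; the automatic-mode gate lives in `_validate_fault_tolerant_automatic`, so this sketch (with illustrative data) runs regardless of the environment flag:

from torch.utils.data import DataLoader, RandomSampler

from pytorch_lightning.utilities.auto_restart import _validate_map_dataset

data = list(range(10))

# the default SequentialSampler and BatchSampler pass the checks
_validate_map_dataset(DataLoader(data))

try:
    _validate_map_dataset(DataLoader(data, sampler=RandomSampler(data)))
except TypeError as err:
    print(err)  # Only `SequentialSampler` is supported.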


def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
    """This function is used to validate that Fault-tolerance is possible with the user data."""
    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    if isinstance(dataloader, CombinedLoader):
        dataloaders = dataloader.loaders
    else:
        dataloaders = dataloader

    dl_loaders = []

    def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
        nonlocal dl_loaders
        if isinstance(dataloader, CycleIterator):
            dataloader = dataloader.loader
        dl_loaders.append(dataloader)

    apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)

    if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
        raise ValueError("Fault-tolerance supports only a single dataloader.")

    for dataloader in dl_loaders:
        validator_fn = (
            _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
        )
        validator_fn(dataloader)


def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
    """This utility collects the state across processes for a collection of state."""
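One more hedged sketch (assuming `PL_FAULT_TOLERANT_TRAINING=1`): `CombinedLoader`/`CycleIterator` wrappers are unpacked by the flattening above, so combining two loaders does not bypass the single-dataloader rule during training:

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

combined = CombinedLoader([DataLoader(range(4)), DataLoader(range(4))], mode="max_size_cycle")

try:
    _validate_fault_tolerant_automatic(combined, RunningStage.TRAINING)
except ValueError as err:
    print(err)  # Fault-tolerance supports only a single dataloader.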


@@ -34,10 +34,12 @@ from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, Se
from torch.utils.data._utils.worker import get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.sampler import Sampler
import tests.helpers.utils as tutils
from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    _collect_states_on_rank_zero_over_collection,
@@ -48,6 +50,7 @@ from pytorch_lightning.utilities.auto_restart import (
    _SingleProcessDataLoaderIterStateful,
    _SupportsStateDict,
    _teardown_dataloader_get_iterators,
    _validate_fault_tolerant_automatic,
    CaptureIterableDataset,
    CaptureMapDataset,
    FastForwardSampler,
@@ -665,6 +668,7 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w
    return dataset


@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
    """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled."""
@@ -893,6 +897,10 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_
    return model.seen_batches, model.parameters()


# fault-tolerant training doesn't support multiple datasets, so the validator is patched out
# below. The test still works because the dataset is fully deterministic and therefore
# there is no overlap between the seeds.
@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize(
    "dataset_classes",
@@ -1180,6 +1188,108 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
    assert "dataloader_state_dict" in state_dict


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_validate_fault_tolerant(tmpdir):
    def data():
        return range(10)

    def dataloader():
        return DataLoader(data())

    _validate_fault_tolerant_automatic(dataloader(), RunningStage.TRAINING)

    dataloaders = CombinedLoader([dataloader(), dataloader()])
    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

    dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle")
    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

    dataloaders = [dataloader(), dataloader()]
    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

    _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)

    dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))]
    with pytest.raises(TypeError, match="A `DistributedSampler` sampler shuffle attribute is set to True."):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

    dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=False))]
    _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

    dataset = SequentialGetItemDataset(2)
    dataloaders = [
        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
    ]
    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

    dataloaders = [
        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)),
        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
    ]
    with pytest.raises(ValueError, match="Fault-tolerance supports only a single."):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

    dataloaders = [
        DataLoader(dataset, sampler=RandomSampler(dataset)),
        DataLoader(dataset, sampler=SequentialSampler(dataset)),
    ]
    with pytest.raises(TypeError, match="Only `SequentialSampler` is supported."):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)

    class CustomRandomSampler(RandomSampler):
        pass

    dl = DataLoader(data(), sampler=CustomRandomSampler(data()))
    with pytest.raises(TypeError, match="RandomSampler"):
        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

    class CustomBatchSampler(BatchSampler):
        pass

    sampler = Sampler(data())
    batch_sampler = CustomBatchSampler(sampler, 2, False)
    dl = DataLoader(data(), batch_sampler=batch_sampler)
    with pytest.raises(TypeError, match="BatchSampler"):
        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

    class CustomIterable(IterableDataset):
        pass

    iterable_dataloader = DataLoader(CustomIterable())
    with pytest.raises(AttributeError, match="without `__next__` method"):
        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)

    class CustomIterable(IterableDataset):
        def __next__(self):
            return torch.tensor(0)

    iterable_dataloader = DataLoader(CustomIterable())
    with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"):
        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)

    class CustomIterable(IterableDataset):
        def __init__(self):
            super().__init__()
            self.sampler = CustomRandomSampler(data())

        def __next__(self):
            return torch.tensor(0)

    iterable_dataloader = DataLoader(CustomIterable())
    with pytest.raises(TypeError, match="RandomSampler"):
        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)

    dataloaders = [iterable_dataloader, DataLoader(CustomIterable())]
    with pytest.raises(TypeError, match="RandomSampler"):
        _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)


def test_rotate_worker_indices():
    """This test ensures `worker_id` are rotated properly depending on which one was the latest."""
    state_dict = {0: 0, 1: 1}