Fault Tolerant: Add support for fault tolerant dataloader validator (#10465)
commit e94aff1c5b (parent 88930725dd)
CHANGELOG.md
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10680))
 
+- Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/issues/10465))
+
 - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
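For context, the new validator can be exercised directly. A minimal sketch, mirroring the tests added at the bottom of this diff and assuming the automatic fault-tolerance mode is enabled via the `PL_FAULT_TOLERANT_TRAINING=1` environment variable (the validator is a no-op otherwise):

```python
import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # the validator is a no-op without this

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

# a plain map-style dataset with the default SequentialSampler passes validation
_validate_fault_tolerant_automatic(DataLoader(range(10)), RunningStage.TRAINING)

# more than one training dataloader is rejected
loaders = [DataLoader(range(10)), DataLoader(range(10))]
try:
    _validate_fault_tolerant_automatic(loaders, RunningStage.TRAINING)
except ValueError as err:
    print(err)  # Fault-tolerance supports only a single dataloader.
```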
pytorch_lightning/trainer/data_loading.py
@@ -28,7 +28,7 @@ from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
+from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
     _replace_dataloader_init_method,
@@ -441,6 +441,7 @@ class TrainerDataLoadingMixin(ABC):
         if isinstance(dataloader, tuple):
             dataloader = list(dataloader)
         self.training_type_plugin.barrier("get_dataloaders")
+        _validate_fault_tolerant_automatic(dataloader, stage)
         return dataloader

     @staticmethod
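The call site above hands the validator whatever collection of dataloaders the user returned. A small illustrative sketch of how nested collections are flattened with `apply_to_collection`, mirroring the `flatten_dataloader` helper added in `auto_restart.py` below (the dict layout here is just an example, not from the commit):

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.apply_func import apply_to_collection

loaders = {"a": DataLoader(range(4)), "b": [DataLoader(range(4)), DataLoader(range(4))]}

flat = []
# apply_to_collection walks the nested collection and calls the given
# function on every element matching the given dtype
apply_to_collection(loaders, DataLoader, flat.append)
print(len(flat))  # 3; more than one loader is only allowed outside TRAINING
```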
pytorch_lightning/utilities/auto_restart.py
@@ -16,11 +16,19 @@ from dataclasses import dataclass, field
 from functools import partial, wraps
 from random import getstate as python_get_rng_state
 from random import setstate as python_set_rng_state
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from torch.utils.data import Dataset, get_worker_info, Sampler
+from torch.utils.data import (
+    BatchSampler,
+    Dataset,
+    DistributedSampler,
+    get_worker_info,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)
 from torch.utils.data.dataloader import (
     _BaseDataLoaderIter,
     _MultiProcessingDataLoaderIter,
@@ -370,7 +378,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
     # get current num workers
     num_workers = getattr(iter_dataloader, "_num_workers", 0)
     # as `state_dict` are workers dependent, Lightning doesn't support changing
-    # the `num_workers` for fault tolerant training
+    # the `num_workers` for Fault-tolerance
     if state_dict["num_workers"] != num_workers:
         raise MisconfigurationException(
             f"The provided `num_workers` {num_workers} doesn't match the one used "
@@ -734,13 +742,93 @@ def _patch_dataloader_get_iterators() -> None:
 
 def _teardown_dataloader_get_iterators() -> None:
     """This function is used to restore the DataLoader `get_iterator` with its original one."""
-    # cleanup the get_iterator replacement in case of Fault Tolerant Training.
+    # cleanup the get_iterator replacement in case of Fault-tolerance.
     get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
     if get_iterator:
         DataLoader._get_iterator = get_iterator
         del DataLoader._ori_get_iterator
 
 
+def _validate_iterable_dataset(dataloader: DataLoader) -> None:
+    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)
+
+    dataset = dataloader.dataset
+
+    if getattr(dataset, "__next__", None) is None:
+        raise AttributeError(
+            "Fault-tolerance doesn't support an `IterableDataset` without `__next__` "
+            "method implemented. Hint: we recommend you move the logic from `__iter__` "
+            "into `__next__` and rely on a sampler to perform the sampling."
+        )
+
+    samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)}
+
+    if not samplers:
+        raise TypeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.")
+
+    sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS]
+
+    if not sampler:
+        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
+
+    if len(sampler) > 1:
+        raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.")
+
+    if type(sampler[0]) is DistributedSampler and sampler[0].shuffle:
+        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
+    elif type(sampler[0]) is not SequentialSampler:
+        raise TypeError("Only `SequentialSampler` is supported.")
+
+
+def _validate_map_dataset(dataloader: DataLoader) -> None:
+    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)
+
+    sampler = getattr(dataloader, "sampler", None)
+    if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
+        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
+
+    batch_sampler = getattr(dataloader, "batch_sampler", None)
+    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
+        raise TypeError("Fault-tolerance supports only a `BatchSampler`.")
+
+    if type(sampler) is DistributedSampler and sampler.shuffle:
+        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
+    elif type(sampler) is RandomSampler:
+        raise TypeError("Only `SequentialSampler` is supported.")
+
+
+def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
+    """This function is used to validate that Fault-tolerance is possible with the user data."""
+    if not _FaultTolerantMode.detect_current_mode().is_automatic:
+        return
+
+    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
+
+    if isinstance(dataloader, CombinedLoader):
+        dataloaders = dataloader.loaders
+    else:
+        dataloaders = dataloader
+
+    dl_loaders = []
+
+    def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
+        nonlocal dl_loaders
+        if isinstance(dataloader, CycleIterator):
+            dataloader = dataloader.loader
+        dl_loaders.append(dataloader)
+
+    apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)
+
+    if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
+        raise ValueError("Fault-tolerance supports only a single dataloader.")
+
+    for dataloader in dl_loaders:
+        validator_fn = (
+            _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
+        )
+        validator_fn(dataloader)
+
+
 def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
     """This utility collects the state across processes for a collection of state."""
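To make the `IterableDataset` contract enforced above concrete: a dataset passes `_validate_iterable_dataset` when it implements `__next__` and holds exactly one supported sampler as an instance attribute. A minimal sketch; the class name and layout are illustrative only, not from the commit:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


class NextStyleIterable(IterableDataset):  # hypothetical example class
    def __init__(self, data):
        super().__init__()
        self.data = data
        # discovered through `dataset.__dict__`; must be one of the supported samplers
        self.sampler = SequentialSampler(data)

    def __iter__(self):
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        index = next(self.sampler_iter)  # StopIteration ends the epoch
        return torch.tensor(self.data[index])


# _validate_iterable_dataset(DataLoader(NextStyleIterable(list(range(10)))))  # passes
```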
tests/utilities/test_auto_restart.py
@@ -34,10 +34,12 @@ from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, Se
 from torch.utils.data._utils.worker import get_worker_info
 from torch.utils.data.dataloader import DataLoader, default_collate
 from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data.sampler import Sampler
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import RunningStage, TrainerState
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.auto_restart import (
     _add_capture_metadata_collate,
     _collect_states_on_rank_zero_over_collection,
@@ -48,6 +50,7 @@ from pytorch_lightning.utilities.auto_restart import (
     _SingleProcessDataLoaderIterStateful,
     _SupportsStateDict,
     _teardown_dataloader_get_iterators,
+    _validate_fault_tolerant_automatic,
     CaptureIterableDataset,
     CaptureMapDataset,
     FastForwardSampler,
@@ -665,6 +668,7 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w
     return dataset
 
 
+@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
 @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
 def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
     """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled."""
@@ -893,6 +897,10 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_
     return model.seen_batches, model.parameters()
 
 
+# this test would fail because `fault_tolerant` doesn't support multiple datasets,
+# which is why the validator is mocked out below. The test works as the dataset is
+# fully deterministic and therefore there is no overlap between the seeds.
+@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
 @pytest.mark.parametrize(
     "dataset_classes",
@@ -1180,6 +1188,108 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
     assert "dataloader_state_dict" in state_dict
 
 
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+def test_validate_fault_tolerant(tmpdir):
+    def data():
+        return range(10)
+
+    def dataloader():
+        return DataLoader(data())
+
+    _validate_fault_tolerant_automatic(dataloader(), RunningStage.TRAINING)
+
+    dataloaders = CombinedLoader([dataloader(), dataloader()])
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle")
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [dataloader(), dataloader()]
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)
+
+    dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))]
+    with pytest.raises(TypeError, match="A `DistributedSampler` sampler shuffle attribute is set to True."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=False))]
+    _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataset = SequentialGetItemDataset(2)
+    dataloaders = [
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
+    ]
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)),
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
+    ]
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [
+        DataLoader(dataset, sampler=RandomSampler(dataset)),
+        DataLoader(dataset, sampler=SequentialSampler(dataset)),
+    ]
+
+    with pytest.raises(TypeError, match="Only `SequentialSampler` is supported."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)
+
+    class CustomRandomSampler(RandomSampler):
+        pass
+
+    dl = DataLoader(data(), sampler=CustomRandomSampler(data()))
+    with pytest.raises(TypeError, match="RandomSampler"):
+        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
+
+    class CustomBatchSampler(BatchSampler):
+        pass
+
+    sampler = Sampler(data())
+    batch_sampler = CustomBatchSampler(sampler, 2, False)
+    dl = DataLoader(data(), batch_sampler=batch_sampler)
+    with pytest.raises(TypeError, match="BatchSampler"):
+        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
+
+    class CustomIterable(IterableDataset):
+        pass
+
+    iterable_dataloader = DataLoader(CustomIterable())
+    with pytest.raises(AttributeError, match="without `__next__` method"):
+        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)
+
+    class CustomIterable(IterableDataset):
+        def __next__(self):
+            return torch.tensor(0)
+
+    iterable_dataloader = DataLoader(CustomIterable())
+    with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"):
+        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)
+
+    class CustomIterable(IterableDataset):
+        def __init__(self):
+            super().__init__()
+            self.sampler = CustomRandomSampler(data())
+
+        def __next__(self):
+            return torch.tensor(0)
+
+    iterable_dataloader = DataLoader(CustomIterable())
+    with pytest.raises(TypeError, match="RandomSampler"):
+        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)
+
+    dataloaders = [iterable_dataloader, DataLoader(CustomIterable())]
+    with pytest.raises(TypeError, match="RandomSampler"):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)
+
+
 def test_rotate_worker_indices():
     """This test ensures `worker_id` are rotated properly depending on which one was the latest."""
     state_dict = {0: 0, 1: 1}
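Finally, a usage note distilled from the test above: with automatic fault tolerance enabled, a `DistributedSampler` is accepted only with `shuffle=False`. A short sketch, assuming `PL_FAULT_TOLERANT_TRAINING=1` as in the test:

```python
import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # validator only runs in automatic mode

from torch.utils.data import DataLoader, DistributedSampler

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

data = list(range(10))

# accepted by the validator: a non-shuffling DistributedSampler
ok = DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=False))
_validate_fault_tolerant_automatic([ok], RunningStage.TRAINING)

# rejected: the validator raises a TypeError when `shuffle=True`
bad = DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=True))
# _validate_fault_tolerant_automatic([bad], RunningStage.TRAINING)  # raises TypeError
```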