Fault Tolerant Manual: Enable the feature (#10707)

thomas chaton 2021-11-24 17:36:08 +00:00 committed by GitHub
parent 30ec4815cb
commit 0066ff0129
4 changed files with 178 additions and 16 deletions

CHANGELOG.md

@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Add a utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
 * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
 * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
+* Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))
 - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10680))

pytorch_lightning/utilities/auto_restart.py

@@ -251,9 +251,7 @@ class CaptureMapDataset(Dataset):
     def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
         # as workers aren't available, the ``state_dict`` is cached until workers are made available.
-        state_dict = deepcopy(state_dict)
-        state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers)
-        self._cached_state_dict = state_dict
+        self._cached_state_dict = _rotate_worker_indices(deepcopy(state_dict), latest_worker_id, num_workers)

     def state_dict(self) -> Dict[int, Dict[str, Any]]:
         return {self.worker_id: {"rng_states": collect_rng_states()}}
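The condensed one-liner keeps the behavior of the removed lines: deep-copy the incoming state, then remap the per-worker entries so that worker 0 of the fresh iterator resumes from the worker after the one that produced the last fetched batch. A minimal sketch of what such a rotation can look like (a hypothetical re-implementation for illustration, not the library's exact `_rotate_worker_indices`):

```python
from copy import deepcopy
from typing import Any, Dict

def rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
    # Workers hand out batches round-robin, so after restoring, the fresh
    # iterator's worker 0 must resume from the worker *after* the one that
    # produced the last fetched batch.
    next_worker_id = (latest_worker_id + 1) % num_workers
    return {new_id: deepcopy(state[(next_worker_id + new_id) % num_workers]) for new_id in range(num_workers)}

# Example: with 3 workers, worker 1 produced the last batch, so old worker 2
# becomes new worker 0.
rotated = rotate_worker_indices({0: "s0", 1: "s1", 2: "s2"}, latest_worker_id=1, num_workers=3)
assert rotated == {0: "s2", 1: "s0", 2: "s1"}
```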
@@ -513,14 +511,17 @@ def patch_dataloader_iterator(
 def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
     """Wrap default collate function to retrieve captured dataset state dict when fault tolerant is enabled."""
-    faut_tolerant_mode = _FaultTolerantMode.detect_current_mode()
-    if not faut_tolerant_mode.is_enabled:
+    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
+    collate_fn = dataloader.collate_fn
+    if not fault_tolerant_mode.is_enabled or (
+        isinstance(collate_fn, partial) and collate_fn.func is _capture_metadata_collate
+    ):
         return
     dataloader.collate_fn = partial(
         _capture_metadata_collate,
         dataset=dataloader.dataset,
-        collate_fn=dataloader.collate_fn,
-        fault_tolerant_mode=faut_tolerant_mode,
+        collate_fn=collate_fn,
+        fault_tolerant_mode=fault_tolerant_mode,
     )
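Besides fixing the `faut_tolerant_mode` variable-name typo, the new guard makes `_add_capture_metadata_collate` idempotent: if `collate_fn` is already a `partial` around `_capture_metadata_collate`, the call returns early instead of wrapping the wrapper again. A standalone sketch of the pattern, with illustrative stand-in names:

```python
from functools import partial

def capture_metadata_collate(samples, collate_fn=None):  # stand-in wrapper
    return collate_fn(samples)

def wrap_once(collate_fn):
    # Skip re-wrapping if the function is already our partial wrapper.
    if isinstance(collate_fn, partial) and collate_fn.func is capture_metadata_collate:
        return collate_fn
    return partial(capture_metadata_collate, collate_fn=collate_fn)

wrapped = wrap_once(list)             # wraps the plain collate function
assert wrap_once(wrapped) is wrapped  # second call leaves it untouched
assert wrapped([1, 2]) == [1, 2]      # wrapper still delegates to the original
```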
@@ -658,8 +659,7 @@ class _StatefulDataLoaderIter:
         return indexes

     def _prepare_loader(self, loader):
-        if not isinstance(loader.collate_fn, partial):
-            loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn)
+        _add_capture_metadata_collate(loader)
         self._loader = loader
         self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
         self.num_batches_fetched = 0
@@ -723,6 +723,8 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
 def _patch_dataloader_get_iterators() -> None:
     """This function is used to replace the DataLoader iterators with their stateful version."""
+    if not _FaultTolerantMode.detect_current_mode().is_manual:
+        return
     if not hasattr(DataLoader, "_ori_get_iterator"):
         DataLoader._ori_get_iterator = DataLoader._get_iterator
     DataLoader._get_iterator = _get_iterator
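The `_ori_get_iterator` attribute doubles as guard and undo record: the original method is stashed only on the first patch, so a later teardown can restore it. A self-contained sketch of this reversible monkeypatch pattern, assuming teardown simply swaps the saved method back (the stand-in `Loader` class here is hypothetical):

```python
class Loader:
    def _get_iterator(self):
        return iter([])

def _stateful_get_iterator(self):  # stand-in for the stateful replacement
    return iter([])

def patch() -> None:
    # Stash the original only once, so repeated patches stay reversible.
    if not hasattr(Loader, "_ori_get_iterator"):
        Loader._ori_get_iterator = Loader._get_iterator
    Loader._get_iterator = _stateful_get_iterator

def teardown() -> None:
    # Restore the original method and drop the stash.
    if hasattr(Loader, "_ori_get_iterator"):
        Loader._get_iterator = Loader._ori_get_iterator
        del Loader._ori_get_iterator

patch()
patch()  # idempotent: the original is still recoverable
teardown()
assert not hasattr(Loader, "_ori_get_iterator")
```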

pytorch_lightning/utilities/fetching.py

@@ -16,7 +16,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from copy import deepcopy
-from functools import partial
 from typing import Any, Callable, Generator, List, Optional, Tuple

 import torch
@@ -27,6 +26,8 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.auto_restart import (
     _add_capture_metadata_collate,
+    _patch_dataloader_get_iterators,
+    _teardown_dataloader_get_iterators,
     IteratorState,
     MergedIteratorState,
     patch_dataloader_iterator,
@@ -109,11 +110,7 @@ class AbstractDataFetcher(ABC):
         if isinstance(dataloader, CombinedLoader):
             dataloader = dataloader.loaders

-        def add_capture_metadata_collate(dataloader: DataLoader):
-            if not isinstance(dataloader.collate_fn, partial):
-                _add_capture_metadata_collate(dataloader)
-
-        apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)
+        apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)

     def append_batch(self, batch) -> None:
         self.batches.append(batch)
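With the `partial` check moved into `_add_capture_metadata_collate` itself (see the auto_restart.py hunk above), the nested wrapper becomes redundant and the function can be handed directly to `apply_to_collection`, which recurses through nested containers and applies a function to every element of the given type. A quick illustration on toy data (not from the commit):

```python
from pytorch_lightning.utilities.apply_func import apply_to_collection

# apply_to_collection walks dicts/lists/tuples and applies the function to
# every value matching the dtype, preserving the container structure.
data = {"train": [1, 2], "val": 3}
doubled = apply_to_collection(data, dtype=int, function=lambda x: x * 2)
assert doubled == {"train": [2, 4], "val": 6}
```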
@@ -206,6 +203,8 @@ class AbstractDataFetcher(ABC):
         if self.dataloader is None:
             raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
         self.reset()
+        self._attach_data_fetcher()
+        _patch_dataloader_get_iterators()
         self.dataloader_iter = iter(self.dataloader)
         self._apply_patch()
         self.prefetching(self.prefetch_batches)
@@ -226,6 +225,7 @@ class AbstractDataFetcher(ABC):
         if isinstance(self.dataloader, DataLoader):
             CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
         self.dataloader_iter = None
+        _teardown_dataloader_get_iterators()


 class DataFetcher(AbstractDataFetcher):
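Here `__iter__` patches the `DataLoader` class right before creating the iterator, and `teardown` restores it. Anyone calling the pair directly would want the restore to run even when iteration raises; a hedged sketch of such a guard (not part of this commit, which relies on `__iter__`/`teardown` instead):

```python
from contextlib import contextmanager

from pytorch_lightning.utilities.auto_restart import (
    _patch_dataloader_get_iterators,
    _teardown_dataloader_get_iterators,
)

@contextmanager
def stateful_dataloader_iterators():
    # Patch on entry, and always restore DataLoader._get_iterator on exit,
    # even if the body raises mid-iteration.
    _patch_dataloader_get_iterators()
    try:
        yield
    finally:
        _teardown_dataloader_get_iterators()
```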

tests/utilities/test_auto_restart.py

@@ -20,7 +20,7 @@ from collections.abc import Iterable
 from contextlib import suppress
 from copy import deepcopy
 from dataclasses import asdict
-from typing import List, Optional
+from typing import Iterator, List, Optional
 from unittest import mock
 from unittest.mock import ANY
@@ -1317,3 +1317,162 @@ def test_stateful_workers(num_workers):
     _reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
     assert dataloader.sampler.counter == dataloader.dataset.counter == 1
     data_fetcher.teardown()
+
+
+class RandomFaultTolerantDataset(RandomGetItemDataset):
+    def __init__(self, *args, seed: int, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.seed = seed
+        self._cache_state_dict = None
+        self.generator = None
+        self.counter_debug = 0
+
+    @property
+    def worker_id(self):
+        info = get_worker_info()
+        return info.id if info else 0
+
+    def __getitem__(self, index):
+        if self._cache_state_dict:
+            state_dict = self._cache_state_dict[self.worker_id]
+            self.generator = random.Random()
+            self.generator.setstate(state_dict["random_state"])
+            self._cache_state_dict = None
+
+        if not self.generator:
+            self.generator = random.Random(self.seed + self.worker_id)
+        return torch.tensor(index + self.generator.random())
+
+    def state_dict(self):
+        return {self.worker_id: {"random_state": self.generator.getstate()}}
+
+    def load_state_dict(self, state_dict):
+        self._cache_state_dict = state_dict
+
+
+class RandomFaultTolerantSampler(RandomSampler):
+    def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
+        generator = torch.Generator().manual_seed(seed)
+        super().__init__(*args, generator=generator, **kwargs)
+        self.counter = 0
+        self.restarting = False
+
+    def state_dict(self):
+        return {"random_state": self.state, "counter": self.counter}
+
+    def load_state_dict(self, state_dict):
+        self.generator.set_state(state_dict.get("random_state"))
+        self.counter = state_dict["counter"]
+        self.restarting = True
+
+    def __len__(self):
+        return len(self.data_source) - self.counter
+
+    def __iter__(self) -> Iterator[int]:
+        n = len(self.data_source)
+
+        self.state = self.generator.get_state()
+        indices = torch.randperm(n, generator=self.generator).tolist()
+
+        if not self.restarting:
+            self.counter = 0
+        else:
+            indices = indices[self.counter :]
+            self.restarting = False
+
+        for index in indices:
+            self.counter += 1
+            yield index
+
+        self.counter = 0
+
+
+@pytest.mark.parametrize(
+    ["train_dataset_cls", "val_dataset_cls"],
+    [
+        ([RandomFaultTolerantDataset, RandomFaultTolerantDataset], [RandomFaultTolerantDataset]),
+    ],
+)
+@pytest.mark.parametrize("val_check_interval", [0.5])
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
+def test_fault_tolerant_manual_mode(val_check_interval, train_dataset_cls, val_dataset_cls, tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self, should_fail: bool = False):
+            super().__init__()
+            self.layer = torch.nn.Linear(1, 2)
+            self.should_fail = should_fail
+            self.batches = []
+
+        def training_step(self, batch, batch_idx):
+            if self.should_fail and batch_idx == 7:
+                raise CustomException
+            self.batches.append(batch)
+            losses = []
+            for b in batch:
+                losses.append(super().training_step(b, batch_idx)["loss"])
+            return torch.stack(losses).mean()
+
+        def validation_step(self, batch, batch_idx, dataloader_idx=0):
+            pass
+
+        validation_epoch_end = None
+
+        def _create_dataloader_kwargs(self, dataset_class, dataset_len, seed, num_workers):
+            dl_kwargs = {}
+            dl_kwargs["dataset"] = dataset_class(dataset_len, 1, seed=seed)
+            dl_kwargs["sampler"] = RandomFaultTolerantSampler(dl_kwargs["dataset"], seed=seed)
+            dl_kwargs["num_workers"] = num_workers
+            dl_kwargs["batch_size"] = 1
+            return dl_kwargs
+
+        def train_dataloader(self):
+            return [
+                DataLoader(
+                    **self._create_dataloader_kwargs(
+                        dataset_class, 10, seed, seed + 1 if val_check_interval == 1.0 else 0
+                    )
+                )
+                for seed, dataset_class in enumerate(train_dataset_cls)
+            ]
+
+        def val_dataloader(self):
+            return [
+                DataLoader(**self._create_dataloader_kwargs(dataset_class, 1, seed, 0))
+                for seed, dataset_class in enumerate(val_dataset_cls)
+            ]
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    seed_everything(42)
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
+    trainer.fit(model)
+    total_batches = model.batches
+    total_weight = deepcopy(model.layer.weight)
+    trainer.train_dataloader = None
+
+    seed_everything(42)
+    model = TestModel(should_fail=True)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
+    with suppress(CustomException):
+        trainer.fit(model)
+    trainer.train_dataloader = None
+    failed_batches = model.batches
+    failed_weight = deepcopy(model.layer.weight)
+
+    checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt")
+    assert os.path.exists(checkpoint_path)
+
+    seed_everything(42)
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
+    trainer.fit(model, ckpt_path=checkpoint_path)
+    trainer.train_dataloader = None
+    restart_batches = model.batches
+
+    torch.testing.assert_allclose(total_batches, failed_batches + restart_batches)
+    assert not torch.equal(total_weight, failed_weight)
+    assert torch.equal(total_weight, model.layer.weight)
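The test drives these classes through full `Trainer` runs, but the resume contract can also be checked in isolation: consume part of an epoch, capture `state_dict()`, and verify that a reloaded sampler yields exactly the remaining indices of the same permutation. A minimal standalone check using the `RandomFaultTolerantSampler` defined above (assumed importable from this test module):

```python
sampler = RandomFaultTolerantSampler(range(10), seed=42)
it = iter(sampler)
first = [next(it) for _ in range(3)]  # partial epoch: 3 of 10 indices
state = sampler.state_dict()          # {"random_state": ..., "counter": 3}

resumed = RandomFaultTolerantSampler(range(10), seed=42)
resumed.load_state_dict(state)
rest = list(resumed)                  # the 7 indices that were never yielded

assert sorted(first + rest) == list(range(10))  # no repeats, no gaps
```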