Fault Tolerant Manual: Enable the feature (#10707)
This commit is contained in:
parent
30ec4815cb
commit
0066ff0129
|
@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
|
||||
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
|
||||
* Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
|
||||
* Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))
|
||||
|
||||
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
|
||||
|
||||
|
|
|
@ -251,9 +251,7 @@ class CaptureMapDataset(Dataset):
|
|||
|
||||
def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
|
||||
# as workers aren't available, the ``state_dict``` is cached until workers are made available.
|
||||
state_dict = deepcopy(state_dict)
|
||||
state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers)
|
||||
self._cached_state_dict = state_dict
|
||||
self._cached_state_dict = _rotate_worker_indices(deepcopy(state_dict), latest_worker_id, num_workers)
|
||||
|
||||
def state_dict(self) -> Dict[int, Dict[str, Any]]:
|
||||
return {self.worker_id: {"rng_states": collect_rng_states()}}
|
||||
|
@ -513,14 +511,17 @@ def patch_dataloader_iterator(
|
|||
|
||||
def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
|
||||
"""Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled."""
|
||||
faut_tolerant_mode = _FaultTolerantMode.detect_current_mode()
|
||||
if not faut_tolerant_mode.is_enabled:
|
||||
fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
|
||||
collate_fn = dataloader.collate_fn
|
||||
if not fault_tolerant_mode.is_enabled or (
|
||||
isinstance(collate_fn, partial) and collate_fn.func is _capture_metadata_collate
|
||||
):
|
||||
return
|
||||
dataloader.collate_fn = partial(
|
||||
_capture_metadata_collate,
|
||||
dataset=dataloader.dataset,
|
||||
collate_fn=dataloader.collate_fn,
|
||||
fault_tolerant_mode=faut_tolerant_mode,
|
||||
collate_fn=collate_fn,
|
||||
fault_tolerant_mode=fault_tolerant_mode,
|
||||
)
|
||||
|
||||
|
||||
|
@ -658,8 +659,7 @@ class _StatefulDataLoaderIter:
|
|||
return indexes
|
||||
|
||||
def _prepare_loader(self, loader):
|
||||
if not isinstance(loader.collate_fn, partial):
|
||||
loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn)
|
||||
_add_capture_metadata_collate(loader)
|
||||
self._loader = loader
|
||||
self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
|
||||
self.num_batches_fetched = 0
|
||||
|
@ -723,6 +723,8 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
|
|||
|
||||
def _patch_dataloader_get_iterators() -> None:
|
||||
"""This function is used to replace the DataLoader iterator by their stateful version."""
|
||||
if not _FaultTolerantMode.detect_current_mode().is_manual:
|
||||
return
|
||||
if not hasattr(DataLoader, "_ori_get_iterator"):
|
||||
DataLoader._ori_get_iterator = DataLoader._get_iterator
|
||||
DataLoader._get_iterator = _get_iterator
|
||||
|
|
|
@ -16,7 +16,6 @@ from abc import ABC, abstractmethod
|
|||
from collections.abc import Iterable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Generator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
@ -27,6 +26,8 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
|
|||
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
|
||||
from pytorch_lightning.utilities.auto_restart import (
|
||||
_add_capture_metadata_collate,
|
||||
_patch_dataloader_get_iterators,
|
||||
_teardown_dataloader_get_iterators,
|
||||
IteratorState,
|
||||
MergedIteratorState,
|
||||
patch_dataloader_iterator,
|
||||
|
@ -109,11 +110,7 @@ class AbstractDataFetcher(ABC):
|
|||
if isinstance(dataloader, CombinedLoader):
|
||||
dataloader = dataloader.loaders
|
||||
|
||||
def add_capture_metadata_collate(dataloader: DataLoader):
|
||||
if not isinstance(dataloader.collate_fn, partial):
|
||||
_add_capture_metadata_collate(dataloader)
|
||||
|
||||
apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)
|
||||
apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)
|
||||
|
||||
def append_batch(self, batch) -> None:
|
||||
self.batches.append(batch)
|
||||
|
@ -206,6 +203,8 @@ class AbstractDataFetcher(ABC):
|
|||
if self.dataloader is None:
|
||||
raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
|
||||
self.reset()
|
||||
self._attach_data_fetcher()
|
||||
_patch_dataloader_get_iterators()
|
||||
self.dataloader_iter = iter(self.dataloader)
|
||||
self._apply_patch()
|
||||
self.prefetching(self.prefetch_batches)
|
||||
|
@ -226,6 +225,7 @@ class AbstractDataFetcher(ABC):
|
|||
if isinstance(self.dataloader, DataLoader):
|
||||
CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
|
||||
self.dataloader_iter = None
|
||||
_teardown_dataloader_get_iterators()
|
||||
|
||||
|
||||
class DataFetcher(AbstractDataFetcher):
|
||||
|
|
|
@ -20,7 +20,7 @@ from collections.abc import Iterable
|
|||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from typing import List, Optional
|
||||
from typing import Iterator, List, Optional
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY
|
||||
|
||||
|
@ -1317,3 +1317,162 @@ def test_stateful_workers(num_workers):
|
|||
_reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
|
||||
assert dataloader.sampler.counter == dataloader.dataset.counter == 1
|
||||
data_fetcher.teardown()
|
||||
|
||||
|
||||
class RandomFaultTolerantDataset(RandomGetItemDataset):
|
||||
def __init__(self, *args, seed: int, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.seed = seed
|
||||
self._cache_state_dict = None
|
||||
self.generator = None
|
||||
self.counter_debug = 0
|
||||
|
||||
@property
|
||||
def worker_id(self):
|
||||
info = get_worker_info()
|
||||
return info.id if info else 0
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self._cache_state_dict:
|
||||
state_dict = self._cache_state_dict[self.worker_id]
|
||||
self.generator = random.Random()
|
||||
self.generator.setstate(state_dict["random_state"])
|
||||
self._cache_state_dict = None
|
||||
|
||||
if not self.generator:
|
||||
self.generator = random.Random(self.seed + self.worker_id)
|
||||
return torch.tensor(index + self.generator.random())
|
||||
|
||||
def state_dict(self):
|
||||
return {self.worker_id: {"random_state": self.generator.getstate()}}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._cache_state_dict = state_dict
|
||||
|
||||
|
||||
class RandomFaultTolerantSampler(RandomSampler):
|
||||
def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
super().__init__(*args, generator=generator, **kwargs)
|
||||
self.counter = 0
|
||||
self.restarting = False
|
||||
|
||||
def state_dict(self):
|
||||
return {"random_state": self.state, "counter": self.counter}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.generator.set_state(state_dict.get("random_state"))
|
||||
self.counter = state_dict["counter"]
|
||||
self.restarting = True
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source) - self.counter
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
n = len(self.data_source)
|
||||
|
||||
self.state = self.generator.get_state()
|
||||
indices = torch.randperm(n, generator=self.generator).tolist()
|
||||
|
||||
if not self.restarting:
|
||||
self.counter = 0
|
||||
else:
|
||||
indices = indices[self.counter :]
|
||||
self.restarting = False
|
||||
|
||||
for index in indices:
|
||||
self.counter += 1
|
||||
yield index
|
||||
|
||||
self.counter = 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["train_dataset_cls", "val_dataset_cls"],
|
||||
[
|
||||
([RandomFaultTolerantDataset, RandomFaultTolerantDataset], [RandomFaultTolerantDataset]),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("val_check_interval", [0.5])
|
||||
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
|
||||
def test_fault_tolerant_manual_mode(val_check_interval, train_dataset_cls, val_dataset_cls, tmpdir):
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self, should_fail: bool = False):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.Linear(1, 2)
|
||||
self.should_fail = should_fail
|
||||
self.batches = []
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
if self.should_fail and batch_idx == 7:
|
||||
raise CustomException
|
||||
self.batches.append(batch)
|
||||
losses = []
|
||||
for b in batch:
|
||||
losses.append(super().training_step(b, batch_idx)["loss"])
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
pass
|
||||
|
||||
validation_epoch_end = None
|
||||
|
||||
def _create_dataloader_kwargs(self, dataset_class, dataset_len, seed, num_workers):
|
||||
dl_kwargs = {}
|
||||
dl_kwargs["dataset"] = dataset_class(dataset_len, 1, seed=seed)
|
||||
dl_kwargs["sampler"] = RandomFaultTolerantSampler(dl_kwargs["dataset"], seed=seed)
|
||||
dl_kwargs["num_workers"] = num_workers
|
||||
dl_kwargs["batch_size"] = 1
|
||||
return dl_kwargs
|
||||
|
||||
def train_dataloader(self):
|
||||
return [
|
||||
DataLoader(
|
||||
**self._create_dataloader_kwargs(
|
||||
dataset_class, 10, seed, seed + 1 if val_check_interval == 1.0 else 0
|
||||
)
|
||||
)
|
||||
for seed, dataset_class in enumerate(train_dataset_cls)
|
||||
]
|
||||
|
||||
def val_dataloader(self):
|
||||
return [
|
||||
DataLoader(**self._create_dataloader_kwargs(dataset_class, 1, seed, 0))
|
||||
for seed, dataset_class in enumerate(val_dataset_cls)
|
||||
]
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
seed_everything(42)
|
||||
model = TestModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
|
||||
trainer.fit(model)
|
||||
total_batches = model.batches
|
||||
total_weight = deepcopy(model.layer.weight)
|
||||
trainer.train_dataloader = None
|
||||
|
||||
seed_everything(42)
|
||||
model = TestModel(should_fail=True)
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
|
||||
with suppress(CustomException):
|
||||
trainer.fit(model)
|
||||
trainer.train_dataloader = None
|
||||
failed_batches = model.batches
|
||||
failed_weight = deepcopy(model.layer.weight)
|
||||
|
||||
checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt")
|
||||
assert os.path.exists(checkpoint_path)
|
||||
|
||||
seed_everything(42)
|
||||
model = TestModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
|
||||
trainer.fit(model, ckpt_path=checkpoint_path)
|
||||
trainer.train_dataloader = None
|
||||
restart_batches = model.batches
|
||||
|
||||
torch.testing.assert_allclose(total_batches, failed_batches + restart_batches)
|
||||
assert not torch.equal(total_weight, failed_weight)
|
||||
assert torch.equal(total_weight, model.layer.weight)
|
||||
|
|
Loading…
Reference in New Issue