From 0066ff012962ffa920e9bbd750f4edbea05d4505 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 24 Nov 2021 17:36:08 +0000
Subject: [PATCH] Fault Tolerant Manual: Enable the feature (#10707)

---
 CHANGELOG.md                                |   1 +
 pytorch_lightning/utilities/auto_restart.py |  20 +--
 pytorch_lightning/utilities/fetching.py     |  12 +-
 tests/utilities/test_auto_restart.py        | 161 +++++++++++++++++++-
 4 files changed, 178 insertions(+), 16 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 460834dba7..59d29e1836 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
     * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
     * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
+    * Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))
 
 - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
 
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 074090f10e..9f99634bd1 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -251,9 +251,7 @@ class CaptureMapDataset(Dataset):
 
     def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
         # as workers aren't available, the ``state_dict``` is cached until workers are made available.
-        state_dict = deepcopy(state_dict)
-        state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers)
-        self._cached_state_dict = state_dict
+        self._cached_state_dict = _rotate_worker_indices(deepcopy(state_dict), latest_worker_id, num_workers)
 
     def state_dict(self) -> Dict[int, Dict[str, Any]]:
         return {self.worker_id: {"rng_states": collect_rng_states()}}
@@ -513,14 +511,17 @@ def patch_dataloader_iterator(
 
 def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
     """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled."""
-    faut_tolerant_mode = _FaultTolerantMode.detect_current_mode()
-    if not faut_tolerant_mode.is_enabled:
+    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
+    collate_fn = dataloader.collate_fn
+    if not fault_tolerant_mode.is_enabled or (
+        isinstance(collate_fn, partial) and collate_fn.func is _capture_metadata_collate
+    ):
         return
     dataloader.collate_fn = partial(
         _capture_metadata_collate,
         dataset=dataloader.dataset,
-        collate_fn=dataloader.collate_fn,
-        fault_tolerant_mode=faut_tolerant_mode,
+        collate_fn=collate_fn,
+        fault_tolerant_mode=fault_tolerant_mode,
     )
@@ -658,8 +659,7 @@ class _StatefulDataLoaderIter:
         return indexes
 
     def _prepare_loader(self, loader):
-        if not isinstance(loader.collate_fn, partial):
-            loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn)
+        _add_capture_metadata_collate(loader)
         self._loader = loader
         self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
         self.num_batches_fetched = 0
@@ -723,6 +723,8 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
 
 def _patch_dataloader_get_iterators() -> None:
     """This function is used to replace the DataLoader iterator by their stateful version."""
+    if not _FaultTolerantMode.detect_current_mode().is_manual:
+        return
     if not hasattr(DataLoader, "_ori_get_iterator"):
         DataLoader._ori_get_iterator = DataLoader._get_iterator
     DataLoader._get_iterator = _get_iterator
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index f5bb4be032..7ac0bfa00c 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -16,7 +16,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from copy import deepcopy
-from functools import partial
 from typing import Any, Callable, Generator, List, Optional, Tuple
 
 import torch
@@ -27,6 +26,8 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.auto_restart import (
     _add_capture_metadata_collate,
+    _patch_dataloader_get_iterators,
+    _teardown_dataloader_get_iterators,
     IteratorState,
     MergedIteratorState,
     patch_dataloader_iterator,
@@ -109,11 +110,7 @@ class AbstractDataFetcher(ABC):
         if isinstance(dataloader, CombinedLoader):
             dataloader = dataloader.loaders
 
-        def add_capture_metadata_collate(dataloader: DataLoader):
-            if not isinstance(dataloader.collate_fn, partial):
-                _add_capture_metadata_collate(dataloader)
-
-        apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)
+        apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)
 
     def append_batch(self, batch) -> None:
         self.batches.append(batch)
@@ -206,6 +203,8 @@ class AbstractDataFetcher(ABC):
         if self.dataloader is None:
             raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
         self.reset()
+        self._attach_data_fetcher()
+        _patch_dataloader_get_iterators()
         self.dataloader_iter = iter(self.dataloader)
         self._apply_patch()
         self.prefetching(self.prefetch_batches)
@@ -226,6 +225,7 @@ class AbstractDataFetcher(ABC):
         if isinstance(self.dataloader, DataLoader):
             CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
         self.dataloader_iter = None
+        _teardown_dataloader_get_iterators()
 
 
 class DataFetcher(AbstractDataFetcher):
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index 47f5deb344..c69b70b65b 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -20,7 +20,7 @@ from collections.abc import Iterable
 from contextlib import suppress
 from copy import deepcopy
 from dataclasses import asdict
-from typing import List, Optional
+from typing import Iterator, List, Optional
 from unittest import mock
 from unittest.mock import ANY
@@ -1317,3 +1317,162 @@ def test_stateful_workers(num_workers):
     _reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
     assert dataloader.sampler.counter == dataloader.dataset.counter == 1
     data_fetcher.teardown()
+
+
+class RandomFaultTolerantDataset(RandomGetItemDataset):
+    def __init__(self, *args, seed: int, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.seed = seed
+        self._cache_state_dict = None
+        self.generator = None
+        self.counter_debug = 0
+
+    @property
+    def worker_id(self):
+        info = get_worker_info()
+        return info.id if info else 0
+
+    def __getitem__(self, index):
+        if self._cache_state_dict:
+            state_dict = self._cache_state_dict[self.worker_id]
+            self.generator = random.Random()
+            self.generator.setstate(state_dict["random_state"])
+            self._cache_state_dict = None
+
+        if not self.generator:
+            self.generator = random.Random(self.seed + self.worker_id)
+        return torch.tensor(index + self.generator.random())
+
+    def state_dict(self):
+        return {self.worker_id: {"random_state": self.generator.getstate()}}
+
+    def load_state_dict(self, state_dict):
+        self._cache_state_dict = state_dict
+
+
+class RandomFaultTolerantSampler(RandomSampler):
+    def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
+        generator = torch.Generator().manual_seed(seed)
+        super().__init__(*args, generator=generator, **kwargs)
+        self.counter = 0
+        self.restarting = False
+
+    def state_dict(self):
+        return {"random_state": self.state, "counter": self.counter}
+
+    def load_state_dict(self, state_dict):
+        self.generator.set_state(state_dict.get("random_state"))
+        self.counter = state_dict["counter"]
+        self.restarting = True
+
+    def __len__(self):
+        return len(self.data_source) - self.counter
+
+    def __iter__(self) -> Iterator[int]:
+        n = len(self.data_source)
+
+        self.state = self.generator.get_state()
+        indices = torch.randperm(n, generator=self.generator).tolist()
+
+        if not self.restarting:
+            self.counter = 0
+        else:
+            indices = indices[self.counter :]
+            self.restarting = False
+
+        for index in indices:
+            self.counter += 1
+            yield index
+
+        self.counter = 0
+
+
+@pytest.mark.parametrize(
+    ["train_dataset_cls", "val_dataset_cls"],
+    [
+        ([RandomFaultTolerantDataset, RandomFaultTolerantDataset], [RandomFaultTolerantDataset]),
+    ],
+)
+@pytest.mark.parametrize("val_check_interval", [0.5])
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
+def test_fault_tolerant_manual_mode(val_check_interval, train_dataset_cls, val_dataset_cls, tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self, should_fail: bool = False):
+            super().__init__()
+            self.layer = torch.nn.Linear(1, 2)
+            self.should_fail = should_fail
+            self.batches = []
+
+        def training_step(self, batch, batch_idx):
+            if self.should_fail and batch_idx == 7:
+                raise CustomException
+            self.batches.append(batch)
+            losses = []
+            for b in batch:
+                losses.append(super().training_step(b, batch_idx)["loss"])
+            return torch.stack(losses).mean()
+
+        def validation_step(self, batch, batch_idx, dataloader_idx=0):
+            pass
+
+        validation_epoch_end = None
+
+        def _create_dataloader_kwargs(self, dataset_class, dataset_len, seed, num_workers):
+            dl_kwargs = {}
+            dl_kwargs["dataset"] = dataset_class(dataset_len, 1, seed=seed)
+            dl_kwargs["sampler"] = RandomFaultTolerantSampler(dl_kwargs["dataset"], seed=seed)
+            dl_kwargs["num_workers"] = num_workers
+            dl_kwargs["batch_size"] = 1
+            return dl_kwargs
+
+        def train_dataloader(self):
+            return [
+                DataLoader(
+                    **self._create_dataloader_kwargs(
+                        dataset_class, 10, seed, seed + 1 if val_check_interval == 1.0 else 0
+                    )
+                )
+                for seed, dataset_class in enumerate(train_dataset_cls)
+            ]
+
+        def val_dataloader(self):
+            return [
+                DataLoader(**self._create_dataloader_kwargs(dataset_class, 1, seed, 0))
+                for seed, dataset_class in enumerate(val_dataset_cls)
+            ]
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    seed_everything(42)
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
+    trainer.fit(model)
+    total_batches = model.batches
+    total_weight = deepcopy(model.layer.weight)
+    trainer.train_dataloader = None
+
+    seed_everything(42)
+    model = TestModel(should_fail=True)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
+    with suppress(CustomException):
+        trainer.fit(model)
+    trainer.train_dataloader = None
+    failed_batches = model.batches
+    failed_weight = deepcopy(model.layer.weight)
+
+    checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt")
+    assert os.path.exists(checkpoint_path)
+
+    seed_everything(42)
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
+    trainer.fit(model, ckpt_path=checkpoint_path)
+    trainer.train_dataloader = None
+    restart_batches = model.batches
+
+    torch.testing.assert_allclose(total_batches, failed_batches + restart_batches)
+    assert not torch.equal(total_weight, failed_weight)
+    assert torch.equal(total_weight, model.layer.weight)
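
For readers of the test above, a minimal sketch (not part of the diff) of the protocol this patch exercises: with PL_FAULT_TOLERANT_TRAINING=2 (manual mode, as mocked in the test), Lightning captures user-defined `state_dict()`/`load_state_dict()` from samplers and datasets so a run can resume mid-epoch. The sampler below is illustrative only; the name `SequentialFaultTolerantSampler` and its sequential ordering are assumptions, a simplified cousin of `RandomFaultTolerantSampler` from the test.

# Hypothetical example, not part of the patch: a sampler implementing the
# state_dict/load_state_dict protocol that manual fault-tolerant mode relies on.
from typing import Iterator

from torch.utils.data import Sampler


class SequentialFaultTolerantSampler(Sampler):
    def __init__(self, data_source) -> None:
        super().__init__(data_source)
        self.data_source = data_source
        self.counter = 0  # number of indices already yielded in the current epoch
        self.restarting = False

    def state_dict(self) -> dict:
        # captured by Lightning alongside each batch when manual mode is enabled
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: dict) -> None:
        # called on restart, before iteration resumes
        self.counter = state_dict["counter"]
        self.restarting = True

    def __len__(self) -> int:
        return len(self.data_source) - self.counter

    def __iter__(self) -> Iterator[int]:
        start = self.counter if self.restarting else 0
        self.restarting = False
        self.counter = start
        for index in range(start, len(self.data_source)):
            self.counter += 1
            yield index
        self.counter = 0

Such a sampler would be passed to a DataLoader the same way `_create_dataloader_kwargs` does in the test; the `.pl_auto_save.ckpt` checkpoint written on failure then carries the sampler state that `load_state_dict` restores on the next fit.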