Enable DataLoader state restoration for the evaluation loop (#9563)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Author: thomas chaton, 2021-09-24 17:21:00 +01:00, committed by GitHub
parent ce00053002
commit 9148a13de0
6 changed files with 156 additions and 38 deletions
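
Summary: with fault-tolerant training enabled, the evaluation loop now captures the state of the in-flight validation/test DataLoader iterator in checkpoints and restores it on restart, so a run that fails mid-evaluation resumes from the failing batch instead of replaying the whole evaluation epoch. The restore logic previously inlined in CombinedLoader is extracted into a shared reload_dataloader_state_dict utility.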

File: CHANGELOG.md

@@ -72,6 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
* Added support for restarting within Evaluation Loop ([#9563](https://github.com/PyTorchLightning/pytorch-lightning/pull/9563))
* Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))
* Support skipping to validation during fitting ([#9681](https://github.com/PyTorchLightning/pytorch-lightning/pull/9681))

File: pytorch_lightning/loops/dataloader/evaluation_loop.py

@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Union
from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader
@@ -19,6 +20,8 @@ from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.utilities.auto_restart import reload_dataloader_state_dict
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -34,6 +37,8 @@ class EvaluationLoop(DataLoaderLoop):
self._results = ResultCollection(training=False)
self._max_batches: Optional[Union[int, Sequence[int]]] = None
self._has_run: bool = False
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._dataloader_state_dict: Optional[Dict[str, Any]] = None
@property
def num_dataloaders(self) -> int:
@@ -101,7 +106,9 @@ class EvaluationLoop(DataLoaderLoop):
dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
self._data_fetcher = dataloader = self.trainer.data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=dataloader_idx
)
dl_max_batches = self._max_batches[dataloader_idx]
@@ -121,6 +128,9 @@
# free memory
self.outputs = []
# drop reference to iterator.
self._data_fetcher = None
# with a single dataloader don't pass a 2D list
if len(outputs) > 0 and self.num_dataloaders == 1:
outputs = outputs[0]
@@ -167,6 +177,10 @@
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
self.trainer.reset_val_dataloader()
if not self.trainer.sanity_checking and self._dataloader_state_dict:
reload_dataloader_state_dict(self.dataloaders[self.current_dataloader_idx], self._dataloader_state_dict)
self._dataloader_state_dict = None
def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks."""
assert self._results is not None
@@ -239,3 +253,13 @@
self.trainer.call_hook(hook_name)
self.trainer.call_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()
def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
if self._data_fetcher is not None and self._data_fetcher.dataloader_iter is not None:
state_dict["dataloader_state_dict"] = asdict(self._data_fetcher.dataloader_iter.previous_state)
return state_dict
def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
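
Taken together, the additions above form a lazy save/cache/apply cycle: on_save_checkpoint serializes the fetcher's previous_state dataclass with asdict, on_load_checkpoint merely caches the dict (the dataloaders may not exist yet at restore time), and the cached state is applied once, when the evaluation dataloaders are next reset outside of sanity checking. A minimal standalone sketch of the pattern, with illustrative names rather than Lightning's real classes:

from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

@dataclass
class _IterState:
    # illustrative stand-in for the fetcher's `previous_state` dataclass
    current_iteration: int = 0

class _TinyEvalLoop:
    def __init__(self) -> None:
        self.state: Optional[_IterState] = None         # live iterator state
        self._cached: Optional[Dict[str, Any]] = None   # restored but not yet applied

    def on_save_checkpoint(self) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.state is not None:
            # dataclass -> plain dict, so the checkpoint stays serializable
            out["dataloader_state_dict"] = asdict(self.state)
        return out

    def on_load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
        # cache only: the dataloader objects are recreated later in the restart
        self._cached = state_dict.get("dataloader_state_dict", {})

    def reset_dataloaders(self) -> None:
        if self._cached:  # apply exactly once, then drop the cache
            self.state = _IterState(**self._cached)
            self._cached = None

loop = _TinyEvalLoop()
loop.state = _IterState(current_iteration=3)
ckpt = loop.on_save_checkpoint()
resumed = _TinyEvalLoop()
resumed.on_load_checkpoint(ckpt)
resumed.reset_dataloaders()
assert resumed.state == _IterState(current_iteration=3)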

File: pytorch_lightning/loops/fit_loop.py

@@ -249,7 +249,8 @@ class FitLoop(Loop):
def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
# TODO: update has_completed to its proper value
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
if self.trainer.train_dataloader is not None:
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
return state_dict
def on_load_checkpoint(self, state_dict: Dict) -> None:
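
The only change to the fit loop is the new None guard, presumably so that a checkpoint written before trainer.train_dataloader has been created simply omits the training dataloader state instead of raising.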

File: pytorch_lightning/trainer/supporters.py

@@ -24,12 +24,9 @@ from torch.utils.data.dataset import IterableDataset
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_find_fast_forward_samplers,
CaptureIterableDataset,
CaptureMapDataset,
IteratorState,
MergedIteratorState,
patch_dataloader_iterator,
reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -400,37 +397,7 @@ class CombinedLoader:
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader
dataset = dataloader.dataset
# We reload the states before creating the workers
# The specific type of dataset will then decide if the state should be applied before or after
# spawning the workers
if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]
if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)
# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)
elif isinstance(dataset, CaptureIterableDataset):
dataset_dict = {
sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
}
dataset.load_state_dict(dataset_dict)
else:
raise MisconfigurationException(
"This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
)
reload_dataloader_state_dict(dataloader, state_dict)
# Finally, spawn the workers (if any).
it = iter(dataloader_to_iter_on)
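
The dataset-specific restore logic deleted above is not gone: it moves essentially verbatim into the new reload_dataloader_state_dict utility (next file), giving CombinedLoader and the evaluation loop a single shared code path for reloading iterator state.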

File: pytorch_lightning/utilities/auto_restart.py

@@ -27,6 +27,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
class FastForwardSampler(Sampler):
@@ -545,3 +546,37 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
dataloader.collate_fn = partial(
_capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
)
def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
"""Utility to reload state_dict within dataloader for fault tolerance."""
if not _fault_tolerant_training():
return
dataset = dataloader.dataset
if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]
if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)
# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)
elif isinstance(dataset, CaptureIterableDataset):
dataset.load_state_dict(
{sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
)
else:
raise MisconfigurationException("This shouldn't happen. Please open an issue on the PyTorch Lightning GitHub.")
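
For reference, a toy illustration of the two state layouts this function dispatches on; the key names mirror the code above, while the payload values are invented:

# Map-style layout: "state" is keyed by worker id, alongside the id of the
# worker that produced the most recent batch.
map_style_state = {
    "latest_worker_id": 0,
    "state": {
        0: {
            "sampler_state": {"current_iteration": 3},  # -> FastForwardSampler.load_state_dict
            "dataset_state": {0: {"rng": None}},        # -> CaptureMapDataset.load_state_dict
            "num_workers": 0,
        }
    },
}
assert map_style_state["state"][0]["sampler_state"]["current_iteration"] == 3

# Iterable layout: "state" is keyed by sampler name; each value is a list of
# per-worker states, and only the first entry's sampler_state is extracted.
iterable_state = {"state": {"sampler0": [{"sampler_state": {"current_iteration": 7}}]}}
extracted = {name: states[0]["sampler_state"] for name, states in iterable_state["state"].items()}
assert extracted == {"sampler0": {"current_iteration": 7}}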

File: tests/utilities/test_auto_restart.py

@@ -15,6 +15,7 @@ import math
import os
import random
import random as python_random
from collections import defaultdict
from collections.abc import Iterable
from contextlib import suppress
from copy import deepcopy
@@ -970,3 +971,92 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult
for w0, w1 in zip(weights0, weights1):
assert w0 is not w1
assert torch.allclose(w0, w1)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize(
["train_datasets", "val_datasets"],
[
([RandomGetItemDataset], [RandomGetItemDataset]),
([RandomGetItemDataset], [RandomGetItemDataset, RandomGetItemDataset]),
],
)
@pytest.mark.parametrize(
"val_check_interval",
[
pytest.param(
0.5,
marks=pytest.mark.xfail(
reason=(
"TODO: the `train_dataloader` random state overrides the validation state when restarting training"
)
),
),
1.0,
],
)
def test_auto_restart_within_validation_loop(train_datasets, val_datasets, val_check_interval, tmpdir):
n_val_dataloaders = len(val_datasets)
stop_dataloader = n_val_dataloaders - 1
stop_batch = 1
class ValidationLoopTestModel(LightningModule):
def __init__(self, should_fail):
super().__init__()
self.layer = torch.nn.Linear(1, 2)
self.should_fail = should_fail
self.training_batches = []
self.validation_batches = defaultdict(list)
def step(self, batch):
return sum(self.layer(b).sum() for b in batch)
def training_step(self, batch, batch_idx):
self.training_batches.append(batch)
return self.step(batch)
def validation_step(self, batch, batch_idx, dataloader_idx=0):
if self.should_fail and stop_dataloader == dataloader_idx and batch_idx == stop_batch:
raise CustomException
self.validation_batches[dataloader_idx].append(batch)
return self.step(batch)
def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def train_dataloader(self):
return [DataLoader(cls(4, 1)) for cls in train_datasets]
def val_dataloader(self):
return [DataLoader(cls(4, 1)) for cls in val_datasets]
def run(should_fail, resume):
if not resume:
seed_everything(42)
model = ValidationLoopTestModel(should_fail)
resume_from_checkpoint = str(tmpdir / ".pl_auto_save.ckpt") if resume else None
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=val_check_interval,
num_sanity_val_steps=0,
resume_from_checkpoint=resume_from_checkpoint,
)
if should_fail:
with pytest.raises(CustomException):
trainer.fit(model)
else:
trainer.fit(model)
return model.training_batches, model.validation_batches
total_train_batches, total_val_batches = run(should_fail=False, resume=False)
pre_fail_train_batches, pre_fail_val_batches = run(should_fail=True, resume=False)
post_fail_train_batches, post_fail_val_batches = run(should_fail=False, resume=True)
torch.testing.assert_allclose(total_train_batches, pre_fail_train_batches + post_fail_train_batches)
for k in total_val_batches:
torch.testing.assert_allclose(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k])
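
The invariant being checked here: for both training and every validation dataloader, the exact batches seen by an uninterrupted run must equal the batches seen before the simulated failure plus those seen after resuming from the auto-saved checkpoint, i.e. restoration is batch-exact rather than merely epoch-exact.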