Enable DataLoader state restoration for the evaluation loop (#9563)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent ce00053002
commit 9148a13de0
@@ -72,6 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
* Added support for restarting within Evaluation Loop ([#9563](https://github.com/PyTorchLightning/pytorch-lightning/pull/9563))
* Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))
* Support skipping to validation during fitting ([#9681](https://github.com/PyTorchLightning/pytorch-lightning/pull/9681))
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Union

from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader
@@ -19,6 +20,8 @@ from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.utilities.auto_restart import reload_dataloader_state_dict
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -34,6 +37,8 @@ class EvaluationLoop(DataLoaderLoop):
        self._results = ResultCollection(training=False)
        self._max_batches: Optional[Union[int, Sequence[int]]] = None
        self._has_run: bool = False
        self._data_fetcher: Optional[AbstractDataFetcher] = None
        self._dataloader_state_dict: Dict[str, Any] = None

    @property
    def num_dataloaders(self) -> int:
@@ -101,7 +106,9 @@ class EvaluationLoop(DataLoaderLoop):

        dataloader_idx: int = self.current_dataloader_idx
        dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
        dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
        self._data_fetcher = dataloader = self.trainer.data_connector.get_profiled_dataloader(
            dataloader, dataloader_idx=dataloader_idx
        )

        dl_max_batches = self._max_batches[dataloader_idx]
@@ -121,6 +128,9 @@ class EvaluationLoop(DataLoaderLoop):
        # free memory
        self.outputs = []

        # drop reference to iterator.
        self._data_fetcher = None

        # with a single dataloader don't pass a 2D list
        if len(outputs) > 0 and self.num_dataloaders == 1:
            outputs = outputs[0]
@@ -167,6 +177,10 @@ class EvaluationLoop(DataLoaderLoop):
        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
            self.trainer.reset_val_dataloader()

        if not self.trainer.sanity_checking and self._dataloader_state_dict:
            reload_dataloader_state_dict(self.dataloaders[self.current_dataloader_idx], self._dataloader_state_dict)
            self._dataloader_state_dict = None

    def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_start`` hooks."""
        assert self._results is not None
@@ -239,3 +253,13 @@ class EvaluationLoop(DataLoaderLoop):
        self.trainer.call_hook(hook_name)
        self.trainer.call_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()

    def on_save_checkpoint(self) -> Dict:
        state_dict = super().on_save_checkpoint()
        if self._data_fetcher is not None and self._data_fetcher.dataloader_iter is not None:
            state_dict["dataloader_state_dict"] = asdict(self._data_fetcher.dataloader_iter.previous_state)
        return state_dict

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # cache the dataloader state dict until the dataloader objects are available
        self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
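A minimal, self-contained sketch of the round trip these two hooks introduce, assuming nothing about the real Trainer plumbing: the fetcher's iterator state is serialized with `asdict` at save time, cached on load, and only applied once the validation dataloaders have been recreated. All names below (`FakeIteratorState`, `loop_state`) are illustrative stand-ins, not Lightning API.

from dataclasses import asdict, dataclass


@dataclass
class FakeIteratorState:  # stand-in for the iterator state object captured by the data fetcher
    latest_worker_id: int
    state: dict


# ~ on_save_checkpoint(): serialize whatever the fetcher's iterator last recorded
loop_state = {"dataloader_state_dict": asdict(FakeIteratorState(latest_worker_id=0, state={}))}

# ~ on_load_checkpoint() in the restarted process: cache the dict for later
cached = loop_state.get("dataloader_state_dict", {})

# later, once trainer.reset_val_dataloader() has run and sanity checking is off,
# the cached dict is handed to reload_dataloader_state_dict(...) and then cleared
print(cached)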
@@ -249,7 +249,8 @@ class FitLoop(Loop):
    def on_save_checkpoint(self) -> Dict:
        state_dict = super().on_save_checkpoint()
        # TODO: update has_completed to its proper value
        state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
        if self.trainer.train_dataloader is not None:
            state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
        return state_dict

    def on_load_checkpoint(self, state_dict: Dict) -> None:
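A short, speculative illustration of why the new guard is needed (not code from this commit): if a checkpoint is written while `trainer.train_dataloader` has not been created yet, the unguarded call would raise. `FakeTrainer` is a hypothetical stand-in that only models the attribute under discussion.

from typing import Any, Dict, Optional


class FakeTrainer:
    train_dataloader: Optional[Any] = None  # not yet created on this code path


trainer = FakeTrainer()
state_dict: Dict[str, Any] = {}
if trainer.train_dataloader is not None:  # the guard added above
    state_dict["dataloader_state_dict"] = trainer.train_dataloader.state_dict(has_completed=False)
# without the guard: AttributeError: 'NoneType' object has no attribute 'state_dict'
print(state_dict)  # -> {} when no train dataloader exists yet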
@@ -24,12 +24,9 @@ from torch.utils.data.dataset import IterableDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
    _find_fast_forward_samplers,
    CaptureIterableDataset,
    CaptureMapDataset,
    IteratorState,
    MergedIteratorState,
    patch_dataloader_iterator,
    reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -400,37 +397,7 @@ class CombinedLoader:
            if isinstance(dataloader, CycleIterator):
                dataloader = dataloader_to_iter_on.loader

            dataset = dataloader.dataset

            # We reload the states before creating the workers
            # The specific type of dataset will then decide if the state should be applied before or after
            # spawning the workers
            if isinstance(dataset, CaptureMapDataset):
                iterator_state = state_dict["state"][0]

                if not isinstance(iterator_state, IteratorState):
                    iterator_state = IteratorState.from_state_dict(iterator_state)

                # reload sampler state
                ff_sampler = _find_fast_forward_samplers(dataloader)
                ff_sampler.load_state_dict(iterator_state.sampler_state)
                # reload dataset state
                dataset.load_state_dict(
                    iterator_state.dataset_state,
                    latest_worker_id=state_dict["latest_worker_id"],
                    num_workers=iterator_state.num_workers,
                )

            elif isinstance(dataset, CaptureIterableDataset):
                dataset_dict = {
                    sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
                }
                dataset.load_state_dict(dataset_dict)

            else:
                raise MisconfigurationException(
                    "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
                )
            reload_dataloader_state_dict(dataloader, state_dict)

            # We finally spawned the workers if any.
            it = iter(dataloader_to_iter_on)
@@ -27,6 +27,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class FastForwardSampler(Sampler):
@@ -545,3 +546,37 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
    dataloader.collate_fn = partial(
        _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
    )


def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    """Utility to reload state_dict within dataloader for fault tolerance."""

    if not _fault_tolerant_training():
        return

    dataset = dataloader.dataset

    if isinstance(dataset, CaptureMapDataset):
        iterator_state = state_dict["state"][0]

        if not isinstance(iterator_state, IteratorState):
            iterator_state = IteratorState.from_state_dict(iterator_state)

        # reload sampler state
        ff_sampler = _find_fast_forward_samplers(dataloader)
        ff_sampler.load_state_dict(iterator_state.sampler_state)

        # reload dataset state
        dataset.load_state_dict(
            iterator_state.dataset_state,
            latest_worker_id=state_dict["latest_worker_id"],
            num_workers=iterator_state.num_workers,
        )

    elif isinstance(dataset, CaptureIterableDataset):
        dataset.load_state_dict(
            {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
        )

    else:
        raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
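A rough illustration of the `state_dict` shape this helper expects for a map-style (`CaptureMapDataset`) dataloader, inferred only from the keys read above; the concrete values, and any fields not read here, are placeholders rather than a documented schema.

example_state_dict = {
    "latest_worker_id": 0,       # forwarded to the dataset when reloading its state
    "state": {
        0: {                     # first worker entry; converted via IteratorState.from_state_dict if it is a plain dict
            "sampler_state": {},  # handed to the FastForwardSampler found on the dataloader
            "dataset_state": {},  # handed to CaptureMapDataset.load_state_dict
            "num_workers": 2,
        },
    },
}
# reload_dataloader_state_dict(dataloader, example_state_dict)
# where `dataloader` wraps a CaptureMapDataset and PL_FAULT_TOLERANT_TRAINING=1 is set;
# otherwise the helper returns early without touching the dataloader.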
@@ -15,6 +15,7 @@ import math
import os
import random
import random as python_random
from collections import defaultdict
from collections.abc import Iterable
from contextlib import suppress
from copy import deepcopy
@@ -970,3 +971,92 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult
    for w0, w1 in zip(weights0, weights1):
        assert w0 is not w1
        assert torch.allclose(w0, w1)


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize(
    ["train_datasets", "val_datasets"],
    [
        ([RandomGetItemDataset], [RandomGetItemDataset]),
        ([RandomGetItemDataset], [RandomGetItemDataset, RandomGetItemDataset]),
    ],
)
@pytest.mark.parametrize(
    "val_check_interval",
    [
        pytest.param(
            0.5,
            marks=pytest.mark.xfail(
                reason=(
                    "TODO: the `train_dataloader` random state overrides the validation state when restarting training"
                )
            ),
        ),
        1.0,
    ],
)
def test_auto_restart_within_validation_loop(train_datasets, val_datasets, val_check_interval, tmpdir):
    n_val_dataloaders = len(val_datasets)
    stop_dataloader = n_val_dataloaders - 1
    stop_batch = 1

    class ValidationLoopTestModel(LightningModule):
        def __init__(self, should_fail):
            super().__init__()
            self.layer = torch.nn.Linear(1, 2)
            self.should_fail = should_fail
            self.training_batches = []
            self.validation_batches = defaultdict(list)

        def step(self, batch):
            return sum(self.layer(b).sum() for b in batch)

        def training_step(self, batch, batch_idx):
            self.training_batches.append(batch)
            return self.step(batch)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            if self.should_fail and stop_dataloader == dataloader_idx and batch_idx == stop_batch:
                raise CustomException
            self.validation_batches[dataloader_idx].append(batch)
            return self.step(batch)

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

        def train_dataloader(self):
            return [DataLoader(cls(4, 1)) for cls in train_datasets]

        def val_dataloader(self):
            return [DataLoader(cls(4, 1)) for cls in val_datasets]

    def run(should_fail, resume):
        if not resume:
            seed_everything(42)

        model = ValidationLoopTestModel(should_fail)

        resume_from_checkpoint = str(tmpdir / ".pl_auto_save.ckpt") if resume else None
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            val_check_interval=val_check_interval,
            num_sanity_val_steps=0,
            resume_from_checkpoint=resume_from_checkpoint,
        )
        if should_fail:
            with pytest.raises(CustomException):
                trainer.fit(model)
        else:
            trainer.fit(model)

        return model.training_batches, model.validation_batches

    total_train_batches, total_val_batches = run(should_fail=False, resume=False)
    pre_fail_train_batches, pre_fail_val_batches = run(should_fail=True, resume=False)
    post_fail_train_batches, post_fail_val_batches = run(should_fail=False, resume=True)

    torch.testing.assert_allclose(total_train_batches, pre_fail_train_batches + post_fail_train_batches)
    for k in total_val_batches:
        torch.testing.assert_allclose(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k])
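For context, a condensed sketch of the user-facing flow the test above exercises; it mirrors the test's own setup (the opt-in environment variable, the auto-saved checkpoint, `resume_from_checkpoint`) rather than introducing new API, and `MyModel` is a placeholder model.

import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # opt in before the Trainer is created

from pytorch_lightning import Trainer

# first run: if it fails (here, even inside validation), ".pl_auto_save.ckpt" is written
trainer = Trainer(default_root_dir="some_dir", max_epochs=1)
# trainer.fit(MyModel())

# restart: with this commit, the validation dataloader state is restored as well as the train state
trainer = Trainer(
    default_root_dir="some_dir",
    max_epochs=1,
    resume_from_checkpoint=os.path.join("some_dir", ".pl_auto_save.ckpt"),
)
# trainer.fit(MyModel())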