Enable saving and loading stateful DataLoaders in Trainer (#19361)

This commit is contained in:
awaelchli 2024-02-01 03:11:19 +01:00 committed by GitHub
parent 5d178d07b7
commit 34a34a0754
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 192 additions and 1 deletions

View File

@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The TQDM progress bar now respects the env variable `TQDM_MINITERS` for setting the refresh rate ([#19381](https://github.com/Lightning-AI/lightning/pull/19381))
- Added support for saving and loading stateful training DataLoaders ([#19361](https://github.com/Lightning-AI/lightning/pull/19361))
### Changed
- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
from typing_extensions import override
@ -94,6 +94,7 @@ class _FitLoop(_Loop):
self._data_source = _DataLoaderSource(None, "train_dataloader")
self._combined_loader: Optional[CombinedLoader] = None
self._combined_loader_states_to_load: List[Dict[str, Any]] = []
self._data_fetcher: Optional[_DataFetcher] = None
self._last_train_dl_reload_epoch = float("-inf")
@ -255,6 +256,8 @@ class _FitLoop(_Loop):
combined_loader.limits = limits
self._load_combined_loader_states()
self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
self._data_fetcher.setup(combined_loader)
iter(self._data_fetcher) # creates the iterator inside the fetcher
@ -409,9 +412,27 @@ class _FitLoop(_Loop):
self._data_fetcher = None
self.epoch_loop.teardown()
@override
def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
state_dict["combined_loader"] = loader_states
return state_dict
@override
def on_load_checkpoint(self, state_dict: Dict) -> None:
self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
super().on_load_checkpoint(state_dict)
def _should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated."""
return self.epoch_loop._should_accumulate()
def _iteration_based_training(self) -> bool:
return self.trainer.max_steps != -1
def _load_combined_loader_states(self) -> None:
if not self.restarting or not self._combined_loader_states_to_load or self._combined_loader is None:
return
self._combined_loader._load_state_dicts(self._combined_loader_states_to_load)
self._combined_loader_states_to_load = [] # release memory

View File

@ -19,6 +19,7 @@ from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDat
from typing_extensions import Self, TypedDict, override
from lightning.fabric.utilities.data import sized_len
from lightning.fabric.utilities.types import _Stateful
from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten
_ITERATOR_RETURN = Tuple[Any, int, int] # batch, batch_idx, dataloader_idx
@ -374,6 +375,24 @@ class CombinedLoader(Iterable):
fn = _SUPPORTED_MODES[self._mode]["fn"]
return fn(lengths)
def _state_dicts(self) -> List[Dict[str, Any]]:
"""Returns the list of state dicts for iterables in `self.flattened` that are stateful."""
return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)]
def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None:
"""Loads the state dicts for iterables in `self.flattened` that are stateful."""
if not states:
return
stateful_loaders = [loader for loader in self.flattened if isinstance(loader, _Stateful)]
if len(stateful_loaders) != len(states):
raise RuntimeError(
f"The CombinedLoader has {len(stateful_loaders)} stateful loaders, but found {len(states)} states"
" in the checkpoint. Please make sure you define the same dataloaders that were used when saving"
" the checkpoint."
)
for loader, state_dict in zip(stateful_loaders, states):
loader.load_state_dict(state_dict)
def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
if hasattr(dataloader, "_iterator"):

View File

@ -24,6 +24,7 @@ from lightning.pytorch.callbacks import Callback, ModelCheckpoint, OnExceptionCh
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.progress import _BaseProgress
from lightning.pytorch.utilities import CombinedLoader
from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter
from tests_pytorch.helpers.runif import RunIf
@ -882,3 +883,94 @@ def test_validation_during_gradient_accumulation_window(tmp_path):
)
trainer.fit(model)
assert model.ran_assert
class NotStatefulIterable:
def __init__(self, start=0):
self.index = start
def __iter__(self):
for i in range(self.index, len(self)):
self.index = i
yield self.index
def __len__(self):
return 10
class StatefulIterable(NotStatefulIterable):
def state_dict(self):
return {"index": self.index}
def load_state_dict(self, state_dict):
self.index = state_dict["index"] + 1
@pytest.mark.parametrize(
("train_dataloader_factory", "has_state", "batches_before", "batches_after"),
[
# No dataloader
(lambda: [], False, [], []),
# Single stateful DataLoader
(lambda: StatefulIterable(), True, [0, 1], [2, 3]),
# Single, not stateful DataLoader
(lambda: CombinedLoader(NotStatefulIterable()), False, [0, 1], [0, 1]),
# Single stateful DataLoader
(lambda: CombinedLoader(StatefulIterable()), True, [0, 1], [2, 3]),
# Multiple stateful DataLoaders
(lambda: CombinedLoader([StatefulIterable(3), StatefulIterable(1)]), True, [[3, 1], [4, 2]], [[5, 3], [6, 4]]),
# Mix of stateful and not stateful DataLoaders
(
lambda: CombinedLoader([NotStatefulIterable(3), StatefulIterable(1), NotStatefulIterable(2)]),
True,
[[3, 1, 2], [4, 2, 3]],
[[3, 3, 2], [4, 4, 3]],
),
],
)
def test_fit_loop_save_and_restore_dataloaders(
train_dataloader_factory, has_state, batches_before, batches_after, tmp_path
):
"""Test that the CheckpointConnector saves the state of stateful dataloaders."""
class DummyModel(BoringModel):
def __init__(self):
super().__init__()
self.seen_data = []
def training_step(self, batch, batch_idx):
self.seen_data.append(batch)
print(batch)
def train_dataloader(self):
return train_dataloader_factory()
trainer_kwargs = {
"default_root_dir": tmp_path,
"accelerator": "cpu",
"enable_checkpointing": False,
"enable_model_summary": False,
"enable_progress_bar": False,
"logger": False,
"num_sanity_val_steps": 0,
}
# Train for 2 steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=2)
trainer.fit(model)
assert model.seen_data == batches_before
# Save a checkpoint
trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
if has_state:
assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
else:
assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]
# Restore training from step 2 and continue 2 more steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=4)
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
assert model.seen_data == batches_after

View File

@ -14,9 +14,11 @@
import math
import pickle
from typing import Any, NamedTuple, Sequence, get_args
from unittest.mock import Mock
import pytest
import torch
from lightning.fabric.utilities.types import _Stateful
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.utilities.combined_loader import (
@ -602,3 +604,57 @@ def test_combined_loader_can_be_pickled():
# no error
pickle.dumps(cl)
def test_state_dicts():
state1, state2, state3 = Mock(), Mock(), Mock()
stateful1 = Mock(spec=_Stateful, state_dict=Mock(return_value=state1))
stateful2 = Mock(spec=_Stateful, state_dict=Mock(return_value=state2))
stateful3 = Mock(spec=_Stateful, state_dict=Mock(return_value=state3))
cl = CombinedLoader([])
assert cl._state_dicts() == []
cl = CombinedLoader([range(2)])
assert cl._state_dicts() == []
cl = CombinedLoader([stateful1])
assert cl._state_dicts() == [state1]
cl = CombinedLoader([range(2), stateful1])
assert cl._state_dicts() == [state1]
cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
assert cl._state_dicts() == [state1, state2]
cl = CombinedLoader({"a": [range(2), stateful1], "b": [stateful2], "c": stateful3})
assert cl._state_dicts() == [state1, state2, state3]
def test_load_state_dicts():
stateful1 = Mock(spec=_Stateful)
stateful2 = Mock(spec=_Stateful)
state1 = Mock()
state2 = Mock()
# 0 stateful loaders, 1 state to load
cl = CombinedLoader([range(2), range(3)])
with pytest.raises(RuntimeError, match="has 0 stateful loaders, but found 1 states"):
cl._load_state_dicts([{"state": 0}])
# 1 stateful loader, 0 states to load
cl = CombinedLoader([stateful1, range(3)])
cl._load_state_dicts([])
stateful1.load_state_dict.assert_not_called()
# 1 stateful loader, 1 state to load
cl = CombinedLoader([range(2), stateful1, range(3)])
cl._load_state_dicts([state1])
stateful1.load_state_dict.assert_called_with(state1)
stateful1.reset_mock()
# 1 stateful loader, 2 states to load
cl = CombinedLoader([range(2), stateful1, range(3)])
with pytest.raises(RuntimeError, match="has 1 stateful loaders, but found 2 states"):
cl._load_state_dicts([state1, state2])
# 2 stateful loaders, 2 states to load
cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
cl._load_state_dicts([state1, state2])
stateful1.load_state_dict.assert_called_with(state1)
stateful2.load_state_dict.assert_called_with(state2)