Enable saving and loading stateful DataLoaders in Trainer (#19361)
parent 5d178d07b7
commit 34a34a0754
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The TQDM progress bar now respects the env variable `TQDM_MINITERS` for setting the refresh rate ([#19381](https://github.com/Lightning-AI/lightning/pull/19381))
 
 
+- Added support for saving and loading stateful training DataLoaders ([#19361](https://github.com/Lightning-AI/lightning/pull/19361))
+
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
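For orientation, here is a minimal sketch (not part of the diff) of the kind of loader this changelog entry refers to: any iterable returned from `train_dataloader` that exposes `state_dict`/`load_state_dict` can now have its progress captured in the Trainer checkpoint. The class name is illustrative; it mirrors the `StatefulIterable` helper used in the new test further down.

class MyStatefulIterable:
    """Hypothetical user-defined loader; mirrors the StatefulIterable test helper below."""

    def __init__(self, start=0):
        self.index = start

    def __iter__(self):
        # yield items and remember how far we got
        for i in range(self.index, 10):
            self.index = i
            yield i

    def state_dict(self):
        # picked up when the Trainer saves a checkpoint
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        # resume from the item after the last one seen
        self.index = state_dict["index"] + 1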
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from typing_extensions import override
@@ -94,6 +94,7 @@ class _FitLoop(_Loop):
 
         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
+        self._combined_loader_states_to_load: List[Dict[str, Any]] = []
         self._data_fetcher: Optional[_DataFetcher] = None
         self._last_train_dl_reload_epoch = float("-inf")
 
@@ -255,6 +256,8 @@ class _FitLoop(_Loop):
 
         combined_loader.limits = limits
 
+        self._load_combined_loader_states()
+
         self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
         self._data_fetcher.setup(combined_loader)
         iter(self._data_fetcher)  # creates the iterator inside the fetcher
@@ -409,9 +412,27 @@ class _FitLoop(_Loop):
         self._data_fetcher = None
         self.epoch_loop.teardown()
 
+    @override
+    def on_save_checkpoint(self) -> Dict:
+        state_dict = super().on_save_checkpoint()
+        if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
+            state_dict["combined_loader"] = loader_states
+        return state_dict
+
+    @override
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
+        super().on_load_checkpoint(state_dict)
+
     def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
 
     def _iteration_based_training(self) -> bool:
         return self.trainer.max_steps != -1
+
+    def _load_combined_loader_states(self) -> None:
+        if not self.restarting or not self._combined_loader_states_to_load or self._combined_loader is None:
+            return
+        self._combined_loader._load_state_dicts(self._combined_loader_states_to_load)
+        self._combined_loader_states_to_load = []  # release memory
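Illustrative only (not part of the diff): with the hooks above, the loader states end up nested under the fit loop's entry in the checkpoint, as asserted by the new fit-loop test below. Assuming a checkpoint saved via `trainer.save_checkpoint("checkpoint.ckpt")`, they can be inspected like this; the path is hypothetical.

import torch

checkpoint = torch.load("checkpoint.ckpt")  # hypothetical path
# One state dict per stateful iterable, in the order they appear in CombinedLoader.flattened.
loader_states = checkpoint["loops"]["fit_loop"]["state_dict"].get("combined_loader", [])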
@@ -19,6 +19,7 @@ from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter
 from typing_extensions import Self, TypedDict, override
 
 from lightning.fabric.utilities.data import sized_len
+from lightning.fabric.utilities.types import _Stateful
 from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten
 
 _ITERATOR_RETURN = Tuple[Any, int, int]  # batch, batch_idx, dataloader_idx
@@ -374,6 +375,24 @@ class CombinedLoader(Iterable):
         fn = _SUPPORTED_MODES[self._mode]["fn"]
         return fn(lengths)
 
+    def _state_dicts(self) -> List[Dict[str, Any]]:
+        """Returns the list of state dicts for iterables in `self.flattened` that are stateful."""
+        return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)]
+
+    def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None:
+        """Loads the state dicts for iterables in `self.flattened` that are stateful."""
+        if not states:
+            return
+        stateful_loaders = [loader for loader in self.flattened if isinstance(loader, _Stateful)]
+        if len(stateful_loaders) != len(states):
+            raise RuntimeError(
+                f"The CombinedLoader has {len(stateful_loaders)} stateful loaders, but found {len(states)} states"
+                " in the checkpoint. Please make sure you define the same dataloaders that were used when saving"
+                " the checkpoint."
+            )
+        for loader, state_dict in zip(stateful_loaders, states):
+            loader.load_state_dict(state_dict)
+
 
 def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
     if hasattr(dataloader, "_iterator"):
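Illustrative only (not part of the diff): a minimal round trip through the two new private helpers. `_TinyStateful` is a made-up object that satisfies the `_Stateful` protocol (`state_dict`/`load_state_dict`); non-stateful members such as `range(5)` contribute no state.

from lightning.pytorch.utilities import CombinedLoader


class _TinyStateful:
    def __init__(self):
        self.index = 3

    def __iter__(self):
        return iter(range(self.index, 10))

    def state_dict(self):
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        self.index = state_dict["index"]


combined = CombinedLoader([_TinyStateful(), range(5)])
states = combined._state_dicts()    # -> [{"index": 3}]; only the stateful member contributes
combined._load_state_dicts(states)  # applied back to stateful members in the same order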
@@ -24,6 +24,7 @@ from lightning.pytorch.callbacks import Callback, ModelCheckpoint, OnExceptionCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.progress import _BaseProgress
+from lightning.pytorch.utilities import CombinedLoader
 from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter
 
 from tests_pytorch.helpers.runif import RunIf
@@ -882,3 +883,94 @@ def test_validation_during_gradient_accumulation_window(tmp_path):
     )
     trainer.fit(model)
     assert model.ran_assert
+
+
+class NotStatefulIterable:
+    def __init__(self, start=0):
+        self.index = start
+
+    def __iter__(self):
+        for i in range(self.index, len(self)):
+            self.index = i
+            yield self.index
+
+    def __len__(self):
+        return 10
+
+
+class StatefulIterable(NotStatefulIterable):
+    def state_dict(self):
+        return {"index": self.index}
+
+    def load_state_dict(self, state_dict):
+        self.index = state_dict["index"] + 1
+
+
+@pytest.mark.parametrize(
+    ("train_dataloader_factory", "has_state", "batches_before", "batches_after"),
+    [
+        # No dataloader
+        (lambda: [], False, [], []),
+        # Single stateful DataLoader
+        (lambda: StatefulIterable(), True, [0, 1], [2, 3]),
+        # Single, not stateful DataLoader
+        (lambda: CombinedLoader(NotStatefulIterable()), False, [0, 1], [0, 1]),
+        # Single stateful DataLoader
+        (lambda: CombinedLoader(StatefulIterable()), True, [0, 1], [2, 3]),
+        # Multiple stateful DataLoaders
+        (lambda: CombinedLoader([StatefulIterable(3), StatefulIterable(1)]), True, [[3, 1], [4, 2]], [[5, 3], [6, 4]]),
+        # Mix of stateful and not stateful DataLoaders
+        (
+            lambda: CombinedLoader([NotStatefulIterable(3), StatefulIterable(1), NotStatefulIterable(2)]),
+            True,
+            [[3, 1, 2], [4, 2, 3]],
+            [[3, 3, 2], [4, 4, 3]],
+        ),
+    ],
+)
+def test_fit_loop_save_and_restore_dataloaders(
+    train_dataloader_factory, has_state, batches_before, batches_after, tmp_path
+):
+    """Test that the CheckpointConnector saves the state of stateful dataloaders."""
+
+    class DummyModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.seen_data = []
+
+        def training_step(self, batch, batch_idx):
+            self.seen_data.append(batch)
+            print(batch)
+
+        def train_dataloader(self):
+            return train_dataloader_factory()
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "enable_checkpointing": False,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+        "logger": False,
+        "num_sanity_val_steps": 0,
+    }
+
+    # Train for 2 steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=2)
+    trainer.fit(model)
+    assert model.seen_data == batches_before
+
+    # Save a checkpoint
+    trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
+    checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
+    if has_state:
+        assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
+    else:
+        assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]
+
+    # Restore training from step 2 and continue 2 more steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=4)
+    trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
+    assert model.seen_data == batches_after
@@ -14,9 +14,11 @@
 import math
 import pickle
 from typing import Any, NamedTuple, Sequence, get_args
+from unittest.mock import Mock
 
 import pytest
 import torch
+from lightning.fabric.utilities.types import _Stateful
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.utilities.combined_loader import (
@@ -602,3 +604,57 @@ def test_combined_loader_can_be_pickled():
 
     # no error
     pickle.dumps(cl)
+
+
+def test_state_dicts():
+    state1, state2, state3 = Mock(), Mock(), Mock()
+    stateful1 = Mock(spec=_Stateful, state_dict=Mock(return_value=state1))
+    stateful2 = Mock(spec=_Stateful, state_dict=Mock(return_value=state2))
+    stateful3 = Mock(spec=_Stateful, state_dict=Mock(return_value=state3))
+
+    cl = CombinedLoader([])
+    assert cl._state_dicts() == []
+    cl = CombinedLoader([range(2)])
+    assert cl._state_dicts() == []
+    cl = CombinedLoader([stateful1])
+    assert cl._state_dicts() == [state1]
+    cl = CombinedLoader([range(2), stateful1])
+    assert cl._state_dicts() == [state1]
+    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
+    assert cl._state_dicts() == [state1, state2]
+    cl = CombinedLoader({"a": [range(2), stateful1], "b": [stateful2], "c": stateful3})
+    assert cl._state_dicts() == [state1, state2, state3]
+
+
+def test_load_state_dicts():
+    stateful1 = Mock(spec=_Stateful)
+    stateful2 = Mock(spec=_Stateful)
+    state1 = Mock()
+    state2 = Mock()
+
+    # 0 stateful loaders, 1 state to load
+    cl = CombinedLoader([range(2), range(3)])
+    with pytest.raises(RuntimeError, match="has 0 stateful loaders, but found 1 states"):
+        cl._load_state_dicts([{"state": 0}])
+
+    # 1 stateful loader, 0 states to load
+    cl = CombinedLoader([stateful1, range(3)])
+    cl._load_state_dicts([])
+    stateful1.load_state_dict.assert_not_called()
+
+    # 1 stateful loader, 1 state to load
+    cl = CombinedLoader([range(2), stateful1, range(3)])
+    cl._load_state_dicts([state1])
+    stateful1.load_state_dict.assert_called_with(state1)
+    stateful1.reset_mock()
+
+    # 1 stateful loader, 2 states to load
+    cl = CombinedLoader([range(2), stateful1, range(3)])
+    with pytest.raises(RuntimeError, match="has 1 stateful loaders, but found 2 states"):
+        cl._load_state_dicts([state1, state2])
+
+    # 2 stateful loaders, 2 states to load
+    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
+    cl._load_state_dicts([state1, state2])
+    stateful1.load_state_dict.assert_called_with(state1)
+    stateful2.load_state_dict.assert_called_with(state2)