diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index a255766aed..22985c7057 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The TQDM progress bar now respects the env variable `TQDM_MINITERS` for setting the refresh rate ([#19381](https://github.com/Lightning-AI/lightning/pull/19381))
 
+- Added support for saving and loading stateful training DataLoaders ([#19361](https://github.com/Lightning-AI/lightning/pull/19361))
+
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index e02fe70c4d..eb30e32757 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from typing_extensions import override
@@ -94,6 +94,7 @@ class _FitLoop(_Loop):
 
         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
+        self._combined_loader_states_to_load: List[Dict[str, Any]] = []
         self._data_fetcher: Optional[_DataFetcher] = None
         self._last_train_dl_reload_epoch = float("-inf")
 
@@ -255,6 +256,8 @@ class _FitLoop(_Loop):
 
         combined_loader.limits = limits
 
+        self._load_combined_loader_states()
+
         self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
         self._data_fetcher.setup(combined_loader)
         iter(self._data_fetcher)  # creates the iterator inside the fetcher
@@ -409,9 +412,32 @@ class _FitLoop(_Loop):
             self._data_fetcher = None
         self.epoch_loop.teardown()
 
+    @override
+    def on_save_checkpoint(self) -> Dict:
+        """Return the loop state, including the states of stateful train dataloaders, if any."""
+        state_dict = super().on_save_checkpoint()
+        # the key is only written when a combined loader exists and at least one of its iterables reports a state
+        if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
+            state_dict["combined_loader"] = loader_states
+        return state_dict
+
+    @override
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        """Stash the saved dataloader states; they are applied later by ``_load_combined_loader_states()``."""
+        # a checkpoint without the key yields an empty list, which makes the later load a no-op
+        self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
+        super().on_load_checkpoint(state_dict)
+
     def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
 
     def _iteration_based_training(self) -> bool:
         return self.trainer.max_steps != -1
+
+    def _load_combined_loader_states(self) -> None:
+        """Load the stashed dataloader states into the combined loader, if the loop is restarting."""
+        if not self.restarting or not self._combined_loader_states_to_load or self._combined_loader is None:
+            return
+        self._combined_loader._load_state_dicts(self._combined_loader_states_to_load)
+        self._combined_loader_states_to_load = []  # release memory
diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py
index 49b4ccf587..9b0ceb0288 100644
--- a/src/lightning/pytorch/utilities/combined_loader.py
+++ b/src/lightning/pytorch/utilities/combined_loader.py
@@ -19,6 +19,7 @@ from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDat
 from typing_extensions import Self, TypedDict, override
 
 from lightning.fabric.utilities.data import sized_len
+from lightning.fabric.utilities.types import _Stateful
 from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten
 
 _ITERATOR_RETURN = Tuple[Any, int, int]  # batch, batch_idx, dataloader_idx
@@ -374,6 +375,24 @@ class CombinedLoader(Iterable):
         fn = _SUPPORTED_MODES[self._mode]["fn"]
         return fn(lengths)
 
+    def _state_dicts(self) -> List[Dict[str, Any]]:
+        """Returns the list of state dicts for iterables in `self.flattened` that are stateful."""
+        return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)]
+
+    def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None:
+        """Loads the state dicts for iterables in `self.flattened` that are stateful."""
+        if not states:
+            return
+        stateful_loaders = [loader for loader in self.flattened if isinstance(loader, _Stateful)]
+        if len(stateful_loaders) != len(states):
+            raise RuntimeError(
+                f"The CombinedLoader has {len(stateful_loaders)} stateful loaders, but found {len(states)} states"
+                " in the checkpoint. Please make sure you define the same dataloaders that were used when saving"
+                " the checkpoint."
+            )
+        for loader, state_dict in zip(stateful_loaders, states):
+            loader.load_state_dict(state_dict)
+
 
 def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
     if hasattr(dataloader, "_iterator"):
diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py
index 120d7949f6..b27d27d073 100644
--- a/tests/tests_pytorch/loops/test_loops.py
+++ b/tests/tests_pytorch/loops/test_loops.py
@@ -24,6 +24,7 @@ from lightning.pytorch.callbacks import Callback, ModelCheckpoint, OnExceptionCh
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.progress import _BaseProgress
+from lightning.pytorch.utilities import CombinedLoader
 from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter
 
 from tests_pytorch.helpers.runif import RunIf
@@ -882,3 +883,98 @@ def test_validation_during_gradient_accumulation_window(tmp_path):
     )
     trainer.fit(model)
     assert model.ran_assert
+
+
+class NotStatefulIterable:
+    """Sized iterable that tracks the last yielded index but exposes no ``state_dict`` / ``load_state_dict``."""
+
+    def __init__(self, start=0):
+        self.index = start
+
+    def __iter__(self):
+        for i in range(self.index, len(self)):
+            self.index = i
+            yield self.index
+
+    def __len__(self):
+        return 10
+
+
+class StatefulIterable(NotStatefulIterable):
+    """Same iterable, made stateful by saving and restoring its index."""
+
+    def state_dict(self):
+        return {"index": self.index}
+
+    def load_state_dict(self, state_dict):
+        # the saved index is the last item that was yielded, so resume from the following one
+        self.index = state_dict["index"] + 1
+
+
+@pytest.mark.parametrize(
+    ("train_dataloader_factory", "has_state", "batches_before", "batches_after"),
+    [
+        # No dataloader
+        (lambda: [], False, [], []),
+        # Single stateful DataLoader
+        (lambda: StatefulIterable(), True, [0, 1], [2, 3]),
+        # Single, not stateful DataLoader in a CombinedLoader
+        (lambda: CombinedLoader(NotStatefulIterable()), False, [0, 1], [0, 1]),
+        # Single stateful DataLoader in a CombinedLoader
+        (lambda: CombinedLoader(StatefulIterable()), True, [0, 1], [2, 3]),
+        # Multiple stateful DataLoaders
+        (lambda: CombinedLoader([StatefulIterable(3), StatefulIterable(1)]), True, [[3, 1], [4, 2]], [[5, 3], [6, 4]]),
+        # Mix of stateful and not stateful DataLoaders
+        (
+            lambda: CombinedLoader([NotStatefulIterable(3), StatefulIterable(1), NotStatefulIterable(2)]),
+            True,
+            [[3, 1, 2], [4, 2, 3]],
+            [[3, 3, 2], [4, 4, 3]],
+        ),
+    ],
+)
+def test_fit_loop_save_and_restore_dataloaders(
+    train_dataloader_factory, has_state, batches_before, batches_after, tmp_path
+):
+    """Test that the fit loop saves the state of stateful train dataloaders and restores it when resuming."""
+
+    class DummyModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.seen_data = []
+
+        def training_step(self, batch, batch_idx):
+            self.seen_data.append(batch)
+
+        def train_dataloader(self):
+            return train_dataloader_factory()
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "enable_checkpointing": False,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+        "logger": False,
+        "num_sanity_val_steps": 0,
+    }
+
+    # Train for 2 steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=2)
+    trainer.fit(model)
+    assert model.seen_data == batches_before
+
+    # Save a checkpoint
+    trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
+    checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
+    if has_state:
+        assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
+    else:
+        assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]
+
+    # Restore training from step 2 and continue 2 more steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=4)
+    trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
+    assert model.seen_data == batches_after
diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py
index 23517d1f35..71ebc6e69a 100644
--- a/tests/tests_pytorch/utilities/test_combined_loader.py
+++ b/tests/tests_pytorch/utilities/test_combined_loader.py
@@ -14,9 +14,11 @@
 import math
 import pickle
 from typing import Any, NamedTuple, Sequence, get_args
+from unittest.mock import Mock
 
 import pytest
 import torch
+from lightning.fabric.utilities.types import _Stateful
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.utilities.combined_loader import (
@@ -602,3 +604,59 @@ def test_combined_loader_can_be_pickled():
 
     # no error
     pickle.dumps(cl)
+
+
+def test_state_dicts():
+    """`_state_dicts` returns only the states of the stateful iterables, in flattened order."""
+    state1, state2, state3 = Mock(), Mock(), Mock()
+    stateful1 = Mock(spec=_Stateful, state_dict=Mock(return_value=state1))
+    stateful2 = Mock(spec=_Stateful, state_dict=Mock(return_value=state2))
+    stateful3 = Mock(spec=_Stateful, state_dict=Mock(return_value=state3))
+
+    cl = CombinedLoader([])
+    assert cl._state_dicts() == []
+    cl = CombinedLoader([range(2)])
+    assert cl._state_dicts() == []
+    cl = CombinedLoader([stateful1])
+    assert cl._state_dicts() == [state1]
+    cl = CombinedLoader([range(2), stateful1])
+    assert cl._state_dicts() == [state1]
+    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
+    assert cl._state_dicts() == [state1, state2]
+    cl = CombinedLoader({"a": [range(2), stateful1], "b": [stateful2], "c": stateful3})
+    assert cl._state_dicts() == [state1, state2, state3]
+
+
+def test_load_state_dicts():
+    """`_load_state_dicts` forwards each state to its stateful iterable and rejects a count mismatch."""
+    stateful1 = Mock(spec=_Stateful)
+    stateful2 = Mock(spec=_Stateful)
+    state1 = Mock()
+    state2 = Mock()
+
+    # 0 stateful loaders, 1 state to load
+    cl = CombinedLoader([range(2), range(3)])
+    with pytest.raises(RuntimeError, match="has 0 stateful loaders, but found 1 states"):
+        cl._load_state_dicts([{"state": 0}])
+
+    # 1 stateful loader, 0 states to load
+    cl = CombinedLoader([stateful1, range(3)])
+    cl._load_state_dicts([])
+    stateful1.load_state_dict.assert_not_called()
+
+    # 1 stateful loader, 1 state to load
+    cl = CombinedLoader([range(2), stateful1, range(3)])
+    cl._load_state_dicts([state1])
+    stateful1.load_state_dict.assert_called_with(state1)
+    stateful1.reset_mock()
+
+    # 1 stateful loader, 2 states to load
+    cl = CombinedLoader([range(2), stateful1, range(3)])
+    with pytest.raises(RuntimeError, match="has 1 stateful loaders, but found 2 states"):
+        cl._load_state_dicts([state1, state2])
+
+    # 2 stateful loaders, 2 states to load
+    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
+    cl._load_state_dicts([state1, state2])
+    stateful1.load_state_dict.assert_called_with(state1)
+    stateful2.load_state_dict.assert_called_with(state2)