diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 72361ad023..ad291fec16 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
 
+- Added the parameter `Fabric.load(..., strict=True|False)` to enable non-strict loading of partial checkpoint state ([#17645](https://github.com/Lightning-AI/lightning/pull/17645))
+
+
 - Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
 
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index e17b3cffd2..91389b4dc0 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -651,7 +651,10 @@ class Fabric:
         self.barrier()
 
     def load(
-        self, path: Union[str, Path], state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None
+        self,
+        path: Union[str, Path],
+        state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)
 
@@ -662,13 +665,14 @@ class Fabric:
             path: A path to where the file is located
             state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                 If no state is given, then the checkpoint will be returned in full.
+            strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
 
         Returns:
             The remaining items that were not restored into the given state dictionary. If no state dictionary is
             given, the full checkpoint will be returned.
         """
         unwrapped_state = _unwrap_objects(state)
-        remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state)
+        remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
         self.barrier()
         if state is not None:
             # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index ca84c1e9cc..2ab61764a3 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -428,7 +428,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         engine.save_checkpoint(path, client_state=state, tag="checkpoint")
 
     def load_checkpoint(
-        self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
+        self,
+        path: _PATH,
+        state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
 
@@ -436,6 +439,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             path: A path to where the file is located
             state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                 This should contain exactly one model, and the model must already be set up by DeepSpeed.
+            strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
 
         Returns:
             Dictionary with the state inside DeepSpeed's engine
@@ -452,7 +456,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
             # a consolidated checkpoint
             path = self.broadcast(path)
-            return super().load_checkpoint(path=path, state=state)
+            return super().load_checkpoint(path=path, state=state, strict=strict)
 
         if not state:
             raise ValueError(
@@ -483,13 +487,14 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             tag="checkpoint",
             load_optimizer_states=optimzer_state_requested,
             load_lr_scheduler_states=False,
-            load_module_strict=True,  # TODO(fabric): make strict loading configurable
+            load_module_strict=strict,
         )
         if client_state is None:
             raise RuntimeError(
                 "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
                 " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
             )
+
         for k in client_state.copy():
             if k not in state:
                 continue
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index ed3d89b878..c156cf2a49 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -31,7 +31,12 @@ from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
 from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning.fabric.strategies.parallel import ParallelStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
-from lightning.fabric.strategies.strategy import _BackwardSyncControl, _Sharded, TBroadcast
+from lightning.fabric.strategies.strategy import (
+    _BackwardSyncControl,
+    _Sharded,
+    _validate_keys_for_strict_loading,
+    TBroadcast,
+)
 from lightning.fabric.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
@@ -417,7 +422,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
 
     def load_checkpoint(
-        self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
+        self,
+        path: _PATH,
+        state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
 
@@ -465,7 +473,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             with state_dict_ctx:
                 module_state = {module_key: module.state_dict()}
                 load_state_dict(module_state, reader)
-                module.load_state_dict(module_state[module_key])
+                module.load_state_dict(module_state[module_key], strict=strict)
 
             # the optimizer states must be loaded separately
             for optim_key, optim in optimizers.items():
@@ -483,11 +491,11 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
 
             # Load metadata (anything not a module or optimizer)
             metadata = torch.load(path / _METADATA_FILENAME)
-            for key, obj in state.items():
-                if isinstance(obj, (FSDP, Optimizer)):
-                    continue
+            requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+            _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
+            for key in requested_metadata_keys:
                 if key not in metadata:
-                    raise KeyError(f"'{key}' not found in the checkpoint.")
+                    continue
                 state[key] = metadata.pop(key)
 
             # return the remaining metadata that wasn't requested as part of `state`
@@ -504,14 +512,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             # There is currently no other way because `summon_full_params` does not support write-back from rank 0 only.
             checkpoint = torch.load(path, map_location="cpu")
             with FSDP.summon_full_params(module, writeback=True, rank0_only=False):
-                module.load_state_dict(checkpoint.pop(module_key))
+                module.load_state_dict(checkpoint.pop(module_key), strict=strict)
 
+            requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+            _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
             # Load metadata (anything not a module or optimizer)
-            for key, obj in state.items():
-                if isinstance(obj, (FSDP, Optimizer)):
-                    continue
+            for key in requested_metadata_keys:
                 if key not in checkpoint:
-                    raise KeyError(f"'{key}' not found in the checkpoint.")
+                    continue
                 state[key] = checkpoint.pop(key)
 
             # return the remaining metadata that wasn't requested as part of `state`
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index d8c910ad0f..526e1941db 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -14,7 +14,7 @@
 import logging
 from abc import ABC, abstractmethod
 from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -267,7 +267,10 @@ class Strategy(ABC):
         return optimizer.state_dict()
 
     def load_checkpoint(
-        self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
+        self,
+        path: _PATH,
+        state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
 
@@ -275,6 +278,7 @@ class Strategy(ABC):
             path: A path to where the file is located
             state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                 If no state is given, then the checkpoint will be returned in full.
+            strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
 
         Returns:
             The remaining items that were not restored into the given state dictionary. If no state dictionary is
@@ -285,20 +289,13 @@ class Strategy(ABC):
         if not state:
             return checkpoint
 
-        invalid_keys = [k for k in state if k not in checkpoint]
-        if invalid_keys:
-            # TODO(fabric): Make strict loading configurable to avoid this error if desired.
-            raise KeyError(
-                f"The requested state contains a key '{invalid_keys[0]}' that does not exist in the loaded checkpoint."
-            )
-
+        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
         for name, obj in state.copy().items():
             if name not in checkpoint:
                 continue
             if isinstance(obj, _Stateful):
                 if isinstance(obj, Module):
-                    # TODO(fabric): Make strict loading configurable
-                    obj.load_state_dict(checkpoint.pop(name), strict=True)
+                    obj.load_state_dict(checkpoint.pop(name), strict=strict)
                 else:
                     obj.load_state_dict(checkpoint.pop(name))
             else:
@@ -397,3 +394,14 @@ class _Sharded(ABC):
         By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
         """
         yield
+
+
+def _validate_keys_for_strict_loading(
+    requested_keys: Iterable[str], checkpoint_keys: Iterable[str], strict: bool
+) -> None:
+    invalid_keys = [k for k in requested_keys if k not in checkpoint_keys]
+    if strict and invalid_keys:
+        raise KeyError(
+            f"The requested state contains a key '{invalid_keys[0]}' that does not exist in the loaded checkpoint."
+            f" To disable strict loading, set `strict=False`."
+        )
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 8c3800bfd7..21a5a90fa8 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -119,9 +119,14 @@ def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
 
     # attempt to load a key not in the metadata checkpoint
     state = {"model": fabric.model, "coconut": 11}
-    with pytest.raises(KeyError, match="'coconut' not found in the checkpoint."):
+    with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
         fabric.load(checkpoint_path, state)
 
+    # `strict=False` ignores the missing key
+    state = {"model": fabric.model, "coconut": 11}
+    fabric.load(checkpoint_path, state, strict=False)
+    assert state["coconut"] == 11
+
 
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
 def test_fsdp_save_full_state_dict(tmp_path):
diff --git a/tests/tests_fabric/strategies/test_strategy.py b/tests/tests_fabric/strategies/test_strategy.py
index 2aaf42f9ed..e92a72dfca 100644
--- a/tests/tests_fabric/strategies/test_strategy.py
+++ b/tests/tests_fabric/strategies/test_strategy.py
@@ -119,4 +119,32 @@ def test_load_checkpoint_strict_loading(tmp_path):
     load_checkpoint_mock = Mock(return_value=saved_state)
     strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock
     with pytest.raises(KeyError, match="contains a key 'c' that does not exist"):
-        strategy.load_checkpoint(tmp_path, requested_state)
+        strategy.load_checkpoint(tmp_path, requested_state, strict=True)
+
+
+def test_load_checkpoint_non_strict_loading(tmp_path):
+    """Test that no error is raised if `strict=False` and state is requested that does not exist in the
+    checkpoint."""
+    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
+
+    # objects with initial state
+    saved_model = nn.Linear(2, 2)
+    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
+    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "str": "test"}
+    strategy.save_checkpoint(tmp_path / "checkpoint.ckpt", state=saved_state)
+
+    # same objects with different state
+    model = nn.Linear(2, 2)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
+    state = {"model": model, "optimizer": optimizer, "int": 2, "new": "not_present_in_saved_state"}
+    assert not torch.equal(model.weight, saved_model.weight)
+    assert optimizer.state_dict() != saved_optimizer.state_dict()
+
+    remainder = strategy.load_checkpoint(tmp_path / "checkpoint.ckpt", state, strict=False)
+    assert torch.equal(model.weight, saved_model.weight)
+    assert optimizer.state_dict() == saved_optimizer.state_dict()
+    assert state["int"] == saved_state["int"]
+    assert "str" not in state
+    assert "str" in remainder
+    assert state["new"] == "not_present_in_saved_state"
+    assert "new" not in remainder
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index fb3a0da5af..01ff49170c 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -946,7 +946,7 @@ def test_load_wrapped_objects(setup, tmp_path):
 
     expected_remainder = {"extra": "data"}
 
-    def mocked_load_checkpoint(path, state):
+    def mocked_load_checkpoint(path, state, strict):
         assert not isinstance(state["model"], _FabricModule)
         assert not isinstance(state["optimizer"], _FabricOptimizer)
         state.update({"int": 5, "dict": {"x": 1}})
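Usage sketch of the `strict` flag introduced by this patch, mirroring the behaviour exercised in the tests above. The model, optimizer, learning rate, and checkpoint path are illustrative placeholders, and the snippet assumes a checkpoint previously written with `fabric.save`.

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu")
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    # "extra" is not a key in the saved checkpoint. With the default strict=True
    # this raises a KeyError; with strict=False the missing key is skipped and the
    # object keeps its current value, while checkpoint entries that were not
    # requested are returned as `remainder`.
    state = {"model": model, "optimizer": optimizer, "extra": 123}
    remainder = fabric.load("path/to/checkpoint.ckpt", state, strict=False)
    assert state["extra"] == 123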