Added configurable strict loading for Fabric strategies (#17645)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: bas <bas.krahmer@talentflyxpert.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
parent 9c07cb397c
commit 420eb6f248

@@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
 
+
+- Added the parameter `Fabric.load(..., strict=True|False)` to enable non-strict loading of partial checkpoint state ([#17645](https://github.com/Lightning-AI/lightning/pull/17645))
+
 
 - Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
 
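For orientation, a minimal usage sketch of the new flag; the checkpoint path and the `extra_metric` key are illustrative, not taken from this commit:

from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
model = fabric.setup(nn.Linear(2, 2))

# "extra_metric" may be absent from the checkpoint; with strict=False the missing key
# is simply skipped instead of raising a KeyError.
state = {"model": model, "extra_metric": 0.0}
remainder = fabric.load("path/to/checkpoint.ckpt", state, strict=False)

With the default `strict=True`, the same call raises a `KeyError` naming the first key that is missing from the checkpoint.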
@@ -651,7 +651,10 @@ class Fabric:
         self.barrier()
 
     def load(
-        self, path: Union[str, Path], state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None
+        self,
+        path: Union[str, Path],
+        state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)
@@ -662,13 +665,14 @@ class Fabric:
             path: A path to where the file is located
             state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                 If no state is given, then the checkpoint will be returned in full.
+            strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
 
         Returns:
             The remaining items that were not restored into the given state dictionary. If no state dictionary is
             given, the full checkpoint will be returned.
         """
         unwrapped_state = _unwrap_objects(state)
-        remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state)
+        remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
         self.barrier()
         if state is not None:
             # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
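The docstring's return-value contract is easiest to see in a short sketch; the object names, file name, and epoch value are illustrative:

import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

fabric.save("full.ckpt", {"model": model, "optimizer": optimizer, "epoch": 5})

# Restore only the model; everything that was not requested comes back as the remainder.
remainder = fabric.load("full.ckpt", {"model": model})
assert "optimizer" in remainder and remainder["epoch"] == 5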
@@ -428,7 +428,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         engine.save_checkpoint(path, client_state=state, tag="checkpoint")
 
     def load_checkpoint(
-        self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
+        self,
+        path: _PATH,
+        state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
@@ -436,6 +439,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             path: A path to where the file is located
             state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                 This should contain exactly one model, and the model must already be set up by DeepSpeed.
+            strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
 
         Returns:
             Dictionary with the state inside DeepSpeed's engine
@@ -452,7 +456,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
             # a consolidated checkpoint
             path = self.broadcast(path)
-            return super().load_checkpoint(path=path, state=state)
+            return super().load_checkpoint(path=path, state=state, strict=strict)
 
         if not state:
             raise ValueError(
@@ -483,13 +487,14 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             tag="checkpoint",
             load_optimizer_states=optimzer_state_requested,
             load_lr_scheduler_states=False,
-            load_module_strict=True,  # TODO(fabric): make strict loading configurable
+            load_module_strict=strict,
         )
         if client_state is None:
             raise RuntimeError(
                 "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
                 " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
             )
 
         for k in client_state.copy():
             if k not in state:
                 continue
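As the hunk above shows, `strict` is forwarded to DeepSpeed's `engine.load_checkpoint` as `load_module_strict`. A hedged end-to-end sketch, assuming a CUDA machine with the `deepspeed` package installed (the model, sizes, and checkpoint path are illustrative):

import torch
from torch import nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

fabric = Fabric(strategy=DeepSpeedStrategy(stage=3), accelerator="cuda", devices=2)
fabric.launch()

model = nn.Linear(32, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

# strict=False reaches DeepSpeed as load_module_strict=False, so module weights that
# only partially match the checkpoint can still be restored.
fabric.load("path/to/deepspeed-checkpoint", {"model": model, "optimizer": optimizer}, strict=False)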
@@ -31,7 +31,12 @@ from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
 from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning.fabric.strategies.parallel import ParallelStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
-from lightning.fabric.strategies.strategy import _BackwardSyncControl, _Sharded, TBroadcast
+from lightning.fabric.strategies.strategy import (
+    _BackwardSyncControl,
+    _Sharded,
+    _validate_keys_for_strict_loading,
+    TBroadcast,
+)
 from lightning.fabric.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
@@ -417,7 +422,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
 
     def load_checkpoint(
-        self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
+        self,
+        path: _PATH,
+        state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
@@ -465,7 +473,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             with state_dict_ctx:
                 module_state = {module_key: module.state_dict()}
                 load_state_dict(module_state, reader)
-                module.load_state_dict(module_state[module_key])
+                module.load_state_dict(module_state[module_key], strict=strict)
 
             # the optimizer states must be loaded separately
             for optim_key, optim in optimizers.items():
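Both FSDP code paths ultimately pass `strict` to `Module.load_state_dict`. As a reminder of what that flag does in plain PyTorch (nothing Fabric-specific):

import torch
from torch import nn

net = nn.Linear(2, 2)
partial = {"weight": torch.zeros(2, 2)}  # no "bias" entry

# strict=False tolerates missing/unexpected keys and reports them instead of raising
result = net.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []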
@@ -483,11 +491,11 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
 
             # Load metadata (anything not a module or optimizer)
             metadata = torch.load(path / _METADATA_FILENAME)
-            for key, obj in state.items():
-                if isinstance(obj, (FSDP, Optimizer)):
-                    continue
+            requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+            _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
+            for key in requested_metadata_keys:
                 if key not in metadata:
-                    raise KeyError(f"'{key}' not found in the checkpoint.")
+                    continue
                 state[key] = metadata.pop(key)
 
             # return the remaining metadata that wasn't requested as part of `state`
@@ -504,14 +512,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         # There is currently no other way because `summon_full_params` does not support write-back from rank 0 only.
         checkpoint = torch.load(path, map_location="cpu")
         with FSDP.summon_full_params(module, writeback=True, rank0_only=False):
-            module.load_state_dict(checkpoint.pop(module_key))
+            module.load_state_dict(checkpoint.pop(module_key), strict=strict)
 
+        requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+        _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
         # Load metadata (anything not a module or optimizer)
-        for key, obj in state.items():
-            if isinstance(obj, (FSDP, Optimizer)):
-                continue
+        for key in requested_metadata_keys:
             if key not in checkpoint:
-                raise KeyError(f"'{key}' not found in the checkpoint.")
+                continue
             state[key] = checkpoint.pop(key)
 
         # return the remaining metadata that wasn't requested as part of `state`
@@ -14,7 +14,7 @@
 import logging
 from abc import ABC, abstractmethod
 from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -267,7 +267,10 @@ class Strategy(ABC):
         return optimizer.state_dict()
 
     def load_checkpoint(
-        self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
+        self,
+        path: _PATH,
+        state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
+        strict: bool = True,
     ) -> Dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
@@ -275,6 +278,7 @@ class Strategy(ABC):
             path: A path to where the file is located
             state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                 If no state is given, then the checkpoint will be returned in full.
+            strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
 
         Returns:
             The remaining items that were not restored into the given state dictionary. If no state dictionary is
@@ -285,20 +289,13 @@
         if not state:
             return checkpoint
 
-        invalid_keys = [k for k in state if k not in checkpoint]
-        if invalid_keys:
-            # TODO(fabric): Make strict loading configurable to avoid this error if desired.
-            raise KeyError(
-                f"The requested state contains a key '{invalid_keys[0]}' that does not exist in the loaded checkpoint."
-            )
-
+        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
         for name, obj in state.copy().items():
+            if name not in checkpoint:
+                continue
             if isinstance(obj, _Stateful):
                 if isinstance(obj, Module):
-                    # TODO(fabric): Make strict loading configurable
-                    obj.load_state_dict(checkpoint.pop(name), strict=True)
+                    obj.load_state_dict(checkpoint.pop(name), strict=strict)
                 else:
                     obj.load_state_dict(checkpoint.pop(name))
             else:
@@ -397,3 +394,14 @@ class _Sharded(ABC):
         By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
         """
         yield
+
+
+def _validate_keys_for_strict_loading(
+    requested_keys: Iterable[str], checkpoint_keys: Iterable[str], strict: bool
+) -> None:
+    invalid_keys = [k for k in requested_keys if k not in checkpoint_keys]
+    if strict and invalid_keys:
+        raise KeyError(
+            f"The requested state contains a key '{invalid_keys[0]}' that does not exist in the loaded checkpoint."
+            f" To disable strict loading, set `strict=False`."
+        )
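The new helper is easy to exercise on its own; a quick sketch of both branches in the pytest style of the tests below (the key names are made up):

import pytest

from lightning.fabric.strategies.strategy import _validate_keys_for_strict_loading

# strict=False tolerates requested keys that the checkpoint does not contain
_validate_keys_for_strict_loading({"model", "extra"}, {"model"}, strict=False)

# strict=True raises, naming the offending key and how to opt out
with pytest.raises(KeyError, match="contains a key 'extra'"):
    _validate_keys_for_strict_loading({"model", "extra"}, {"model"}, strict=True)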
@@ -119,9 +119,14 @@ def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
 
     # attempt to load a key not in the metadata checkpoint
     state = {"model": fabric.model, "coconut": 11}
-    with pytest.raises(KeyError, match="'coconut' not found in the checkpoint."):
+    with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
         fabric.load(checkpoint_path, state)
 
+    # `strict=False` ignores the missing key
+    state = {"model": fabric.model, "coconut": 11}
+    fabric.load(checkpoint_path, state, strict=False)
+    assert state["coconut"] == 11
+
 
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
 def test_fsdp_save_full_state_dict(tmp_path):
@@ -119,4 +119,32 @@ def test_load_checkpoint_strict_loading(tmp_path):
     load_checkpoint_mock = Mock(return_value=saved_state)
     strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock
     with pytest.raises(KeyError, match="contains a key 'c' that does not exist"):
-        strategy.load_checkpoint(tmp_path, requested_state)
+        strategy.load_checkpoint(tmp_path, requested_state, strict=True)
+
+
+def test_load_checkpoint_non_strict_loading(tmp_path):
+    """Test that no error is raised if `strict=False` and state is requested that does not exist in the
+    checkpoint."""
+    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
+
+    # objects with initial state
+    saved_model = nn.Linear(2, 2)
+    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
+    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "str": "test"}
+    strategy.save_checkpoint(tmp_path / "checkpoint.ckpt", state=saved_state)
+
+    # same objects with different state
+    model = nn.Linear(2, 2)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
+    state = {"model": model, "optimizer": optimizer, "int": 2, "new": "not_present_in_saved_state"}
+    assert not torch.equal(model.weight, saved_model.weight)
+    assert optimizer.state_dict() != saved_optimizer.state_dict()
+
+    remainder = strategy.load_checkpoint(tmp_path / "checkpoint.ckpt", state, strict=False)
+    assert torch.equal(model.weight, saved_model.weight)
+    assert optimizer.state_dict() == saved_optimizer.state_dict()
+    assert state["int"] == saved_state["int"]
+    assert "str" not in state
+    assert "str" in remainder
+    assert state["new"] == "not_present_in_saved_state"
+    assert "new" not in remainder
|
@ -946,7 +946,7 @@ def test_load_wrapped_objects(setup, tmp_path):
|
|||
|
||||
expected_remainder = {"extra": "data"}
|
||||
|
||||
def mocked_load_checkpoint(path, state):
|
||||
def mocked_load_checkpoint(path, state, strict):
|
||||
assert not isinstance(state["model"], _FabricModule)
|
||||
assert not isinstance(state["optimizer"], _FabricOptimizer)
|
||||
state.update({"int": 5, "dict": {"x": 1}})
|
||||
|
|