Added configurable strict loading for Fabric strategies (#17645)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: bas <bas.krahmer@talentflyxpert.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Authored by Bas Krahmer on 2023-06-07 00:26:13 +02:00, committed by GitHub
parent 9c07cb397c
commit 420eb6f248
8 changed files with 92 additions and 31 deletions


@@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
- Added the parameter `Fabric.load(..., strict=True|False)` to enable non-strict loading of partial checkpoint state ([#17645](https://github.com/Lightning-AI/lightning/pull/17645))
- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
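
For context, a minimal usage sketch of the new flag (not part of this commit; the checkpoint path, model, and the extra `warmup` key are purely illustrative):

import torch
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

# Save a full checkpoint, then restore only part of it.
fabric.save("checkpoint.ckpt", {"model": model, "optimizer": optimizer, "epoch": 3})

# strict=True (the default) raises a KeyError because 'warmup' is not in the checkpoint;
# strict=False leaves unknown keys untouched.
state = {"model": model, "warmup": 500}
remainder = fabric.load("checkpoint.ckpt", state, strict=False)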


@@ -651,7 +651,10 @@ class Fabric:
self.barrier()
def load(
self, path: Union[str, Path], state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None
self,
path: Union[str, Path],
state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
"""Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)
@@ -662,13 +665,14 @@ class Fabric:
path: A path to where the file is located
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
If no state is given, then the checkpoint will be returned in full.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
given, the full checkpoint will be returned.
"""
unwrapped_state = _unwrap_objects(state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
self.barrier()
if state is not None:
# We need to unwrap objects (see above) but this creates a new dictionary. In-place updates

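As the docstring above describes, requested objects are restored in place into the passed dictionary and everything not requested is returned. A hedged sketch of that behavior (object names and paths are illustrative, not from this commit):

import torch
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
model = nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)
fabric.save("full.ckpt", {"model": model, "optimizer": optimizer, "iteration": 10})

# Request only the model: the dict is updated in place and the rest comes back.
state = {"model": model}
remainder = fabric.load("full.ckpt", state)
assert "iteration" in remainder and "optimizer" in remainder

# With strict=False, a requested key that the checkpoint lacks is simply skipped.
fabric.load("full.ckpt", {"model": model, "missing": None}, strict=False)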

@@ -428,7 +428,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
engine.save_checkpoint(path, client_state=state, tag="checkpoint")
def load_checkpoint(
self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
self,
path: _PATH,
state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects.
@@ -436,6 +439,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
path: A path to where the file is located
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
This should contain exactly one model, and the model must already be set up by DeepSpeed.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
Returns:
Dictionary with the state inside DeepSpeed's engine
@@ -452,7 +456,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
# This code path enables loading a checkpoint from a non-deepspeed checkpoint or from
# a consolidated checkpoint
path = self.broadcast(path)
return super().load_checkpoint(path=path, state=state)
return super().load_checkpoint(path=path, state=state, strict=strict)
if not state:
raise ValueError(
@@ -483,13 +487,14 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
tag="checkpoint",
load_optimizer_states=optimzer_state_requested,
load_lr_scheduler_states=False,
load_module_strict=True, # TODO(fabric): make strict loading configurable
load_module_strict=strict,
)
if client_state is None:
raise RuntimeError(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
" or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
)
for k in client_state.copy():
if k not in state:
continue

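For the DeepSpeed strategy above, the new `strict` argument is forwarded to the engine call as `load_module_strict`. A rough usage sketch under assumed conditions (two CUDA devices and an illustrative checkpoint directory; not part of this commit):

import torch
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(strategy="deepspeed", accelerator="cuda", devices=2)
fabric.launch()

model = nn.Linear(32, 2)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = fabric.setup(model, optimizer)  # exactly one DeepSpeed-initialized model

# Tolerate missing or unexpected module keys in the DeepSpeed checkpoint.
fabric.load("path/to/deepspeed_checkpoint_dir", {"model": model, "optimizer": optimizer}, strict=False)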

@@ -31,7 +31,12 @@ from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import _BackwardSyncControl, _Sharded, TBroadcast
from lightning.fabric.strategies.strategy import (
_BackwardSyncControl,
_Sharded,
_validate_keys_for_strict_loading,
TBroadcast,
)
from lightning.fabric.utilities.distributed import (
_get_default_process_group_backend_for_device,
_init_dist_connection,
@@ -417,7 +422,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
def load_checkpoint(
self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
self,
path: _PATH,
state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects.
@@ -465,7 +473,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
with state_dict_ctx:
module_state = {module_key: module.state_dict()}
load_state_dict(module_state, reader)
module.load_state_dict(module_state[module_key])
module.load_state_dict(module_state[module_key], strict=strict)
# the optimizer states must be loaded separately
for optim_key, optim in optimizers.items():
@@ -483,11 +491,11 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
for key, obj in state.items():
if isinstance(obj, (FSDP, Optimizer)):
continue
requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
_validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
for key in requested_metadata_keys:
if key not in metadata:
raise KeyError(f"'{key}' not found in the checkpoint.")
continue
state[key] = metadata.pop(key)
# return the remaining metadata that wasn't requested as part of `state`
@@ -504,14 +512,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
# There is currently no other way because `summon_full_params` does not support write-back from rank 0 only.
checkpoint = torch.load(path, map_location="cpu")
with FSDP.summon_full_params(module, writeback=True, rank0_only=False):
module.load_state_dict(checkpoint.pop(module_key))
module.load_state_dict(checkpoint.pop(module_key), strict=strict)
requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
_validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
# Load metadata (anything not a module or optimizer)
for key, obj in state.items():
if isinstance(obj, (FSDP, Optimizer)):
continue
for key in requested_metadata_keys:
if key not in checkpoint:
raise KeyError(f"'{key}' not found in the checkpoint.")
continue
state[key] = checkpoint.pop(key)
# return the remaining metadata that wasn't requested as part of `state`
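
Both FSDP code paths above now derive the metadata keys with a set difference and defer the error to the shared validation helper. A framework-free sketch of that logic (all values are illustrative):

# What the user asked for, minus modules and optimizers, is treated as metadata.
state_keys = {"model", "optimizer", "step", "coconut"}
module_keys, optimizer_keys = {"model"}, {"optimizer"}
metadata = {"step": 100}  # what the checkpoint actually holds

requested_metadata_keys = state_keys - module_keys - optimizer_keys  # {"step", "coconut"}
missing = [k for k in requested_metadata_keys if k not in metadata]

strict = False
if strict and missing:
    raise KeyError(f"The requested state contains a key '{missing[0]}' that does not exist in the loaded checkpoint.")

restored = {k: metadata[k] for k in requested_metadata_keys if k in metadata}  # {"step": 100}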


@@ -14,7 +14,7 @@
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar, Union
import torch
from torch import Tensor
@@ -267,7 +267,10 @@ class Strategy(ABC):
return optimizer.state_dict()
def load_checkpoint(
self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
self,
path: _PATH,
state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects.
@@ -275,6 +278,7 @@ class Strategy(ABC):
path: A path to where the file is located
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
If no state is given, then the checkpoint will be returned in full.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
@@ -285,20 +289,13 @@ class Strategy(ABC):
if not state:
return checkpoint
invalid_keys = [k for k in state if k not in checkpoint]
if invalid_keys:
# TODO(fabric): Make strict loading configurable to avoid this error if desired.
raise KeyError(
f"The requested state contains a key '{invalid_keys[0]}' that does not exist in the loaded checkpoint."
)
_validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
for name, obj in state.copy().items():
if name not in checkpoint:
continue
if isinstance(obj, _Stateful):
if isinstance(obj, Module):
# TODO(fabric): Make strict loading configurable
obj.load_state_dict(checkpoint.pop(name), strict=True)
obj.load_state_dict(checkpoint.pop(name), strict=strict)
else:
obj.load_state_dict(checkpoint.pop(name))
else:
@@ -397,3 +394,14 @@ class _Sharded(ABC):
By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
"""
yield
def _validate_keys_for_strict_loading(
requested_keys: Iterable[str], checkpoint_keys: Iterable[str], strict: bool
) -> None:
invalid_keys = [k for k in requested_keys if k not in checkpoint_keys]
if strict and invalid_keys:
raise KeyError(
f"The requested state contains a key '{invalid_keys[0]}' that does not exist in the loaded checkpoint."
f" To disable strict loading, set `strict=False`."
)
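
For reference, the behavior of the new helper (imported here only for illustration; it is a private utility of lightning.fabric.strategies.strategy):

from lightning.fabric.strategies.strategy import _validate_keys_for_strict_loading

# Tolerated: 'epoch' is missing from the checkpoint, but strict loading is disabled.
_validate_keys_for_strict_loading({"model", "epoch"}, {"model"}, strict=False)

# Rejected: the same mismatch under strict=True raises a KeyError suggesting strict=False.
try:
    _validate_keys_for_strict_loading({"model", "epoch"}, {"model"}, strict=True)
except KeyError as err:
    print(err)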


@@ -119,9 +119,14 @@ def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
# attempt to load a key not in the metadata checkpoint
state = {"model": fabric.model, "coconut": 11}
with pytest.raises(KeyError, match="'coconut' not found in the checkpoint."):
with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
fabric.load(checkpoint_path, state)
# `strict=False` ignores the missing key
state = {"model": fabric.model, "coconut": 11}
fabric.load(checkpoint_path, state, strict=False)
assert state["coconut"] == 11
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_full_state_dict(tmp_path):


@@ -119,4 +119,32 @@ def test_load_checkpoint_strict_loading(tmp_path):
load_checkpoint_mock = Mock(return_value=saved_state)
strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock
with pytest.raises(KeyError, match="contains a key 'c' that does not exist"):
strategy.load_checkpoint(tmp_path, requested_state)
strategy.load_checkpoint(tmp_path, requested_state, strict=True)
def test_load_checkpoint_non_strict_loading(tmp_path):
"""Test that no error is raised if `strict=False` and state is requested that does not exist in the
checkpoint."""
strategy = SingleDeviceStrategy() # surrogate class to test implementation in base class
# objects with initial state
saved_model = nn.Linear(2, 2)
saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "str": "test"}
strategy.save_checkpoint(tmp_path / "checkpoint.ckpt", state=saved_state)
# same objects with different state
model = nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
state = {"model": model, "optimizer": optimizer, "int": 2, "new": "not_present_in_saved_state"}
assert not torch.equal(model.weight, saved_model.weight)
assert optimizer.state_dict() != saved_optimizer.state_dict()
remainder = strategy.load_checkpoint(tmp_path / "checkpoint.ckpt", state, strict=False)
assert torch.equal(model.weight, saved_model.weight)
assert optimizer.state_dict() == saved_optimizer.state_dict()
assert state["int"] == saved_state["int"]
assert "str" not in state
assert "str" in remainder
assert state["new"] == "not_present_in_saved_state"
assert "new" not in remainder


@@ -946,7 +946,7 @@ def test_load_wrapped_objects(setup, tmp_path):
expected_remainder = {"extra": "data"}
def mocked_load_checkpoint(path, state):
def mocked_load_checkpoint(path, state, strict):
assert not isinstance(state["model"], _FabricModule)
assert not isinstance(state["optimizer"], _FabricOptimizer)
state.update({"int": 5, "dict": {"x": 1}})