Fabric checkpointing 2/n: DeepSpeed implementation (#16452)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli 2023-01-24 18:53:26 +01:00 committed by GitHub
parent 4a802e00a8
commit 7603dd09cb
4 changed files with 408 additions and 17 deletions


@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
-
- Added support for saving and loading DeepSpeed checkpoints through `Fabric.save/load()` ([#16452](https://github.com/Lightning-AI/lightning/pull/16452))
### Changed


@@ -17,8 +17,9 @@ import logging
import os
import platform
from contextlib import contextmanager
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
import torch
from lightning_utilities.core.imports import RequirementCache
@@ -31,7 +32,7 @@ from lightning_fabric.plugins.precision import Precision
from lightning_fabric.strategies.ddp import DDPStrategy
from lightning_fabric.strategies.strategy import _Sharded
from lightning_fabric.utilities.distributed import log
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import _PATH
@@ -365,24 +366,124 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
def save_checkpoint(
self, path: _PATH, state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None
) -> None:
raise NotImplementedError
"""Save model, optimizer, and other state in a checkpoint directory.
Args:
path: A path to where the files should be saved
state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
state-dict will be retrieved and converted automatically.
storage_options: Unused by this strategy, since it doesn't use a ``CheckpointIO`` plugin.
Raises:
TypeError:
If the unused ``storage_options`` gets passed.
ValueError:
When no :class:`deepspeed.DeepSpeedEngine` objects were found in the state, or when multiple
:class:`deepspeed.DeepSpeedEngine` objects were found.
"""
if storage_options is not None:
raise TypeError(
"`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
" `DeepSpeedStrategy` does not use the `CheckpointIO`."
)
engines = _get_deepspeed_engines_from_state(state)
if len(engines) == 0:
raise ValueError(
"Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
" part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
" you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
)
elif len(engines) > 1:
raise ValueError(
"Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
" currently limited to a single model per checkpoint. To save multiple models, call the"
" save method for each model separately with a different path."
)
engine = engines[0]
# broadcast the path from rank 0 to ensure all the states are saved in a common path
path = self.broadcast(path)
# split the checkpoint into two parts:
# 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
# 2) the rest of the user's state, which in deepspeed is called `client state`
excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
state = {k: v for k, v in state.items() if v not in excluded_objects}
_validate_state_keys(state)
# there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
state = self._convert_stateful_objects_in_state(state)
# use deepspeed's internal checkpointing function to handle partitioned weights across processes
engine.save_checkpoint(path, client_state=state, tag="checkpoint")
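
(For context, a minimal usage sketch of how this method is reached through `Fabric.save`; this is not part of the diff, and the model, hyperparameters, and paths are illustrative placeholders.)

import torch
from lightning_fabric import Fabric
from lightning_fabric.strategies import DeepSpeedStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=3))
fabric.launch()

model = torch.nn.Linear(32, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# setup() wraps the model in a deepspeed.DeepSpeedEngine, which save_checkpoint() searches for
model, optimizer = fabric.setup(model, optimizer)

# exactly one engine may appear in the state; the remaining entries become DeepSpeed "client state"
state = {"model": model, "optimizer": optimizer, "steps": 0}
fabric.save("checkpoints/run-1", state)  # DeepSpeed writes a checkpoint *directory* at this path
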
def load_checkpoint(
self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
) -> Dict[str, Any]:
        raise NotImplementedError

    def load_optimizer_state_dict(
        self, optimizers: Union[Optimizer, Iterable[Optimizer]], checkpoint: Mapping[str, Any]
    ) -> None:
        # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
        pass

    def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any]) -> None:
        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
        if self.load_full_weights and self.zero_stage_3:
            self.module_to_device(module)
            self._restore_zero_state(module, checkpoint)

        """Load the contents from a checkpoint and restore the state of the given objects.

        Args:
            path: A path to where the file is located
            state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                This should contain exactly one model, and the model must already be set up by DeepSpeed.

        Returns:
            Dictionary with the state inside DeepSpeed's engine

        Raises:
            ValueError:
                If no state is provided, when no :class:`deepspeed.DeepSpeedEngine` objects were found in the
                state, or when multiple :class:`deepspeed.DeepSpeedEngine` objects were found.
            RuntimeError:
                If DeepSpeed was unable to load the checkpoint due to missing files or because the checkpoint is
                not in the expected DeepSpeed format.
        """
        if self.load_full_weights and self.zero_stage_3:
            # This code path enables loading a checkpoint from a non-DeepSpeed checkpoint or from
            # a consolidated checkpoint
            path = self.broadcast(path)
            return super().load_checkpoint(path=path, state=state)
if not state:
raise ValueError(
f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
f" a model instance to reload is required. Pass it in like so:"
" DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
)
engines = _get_deepspeed_engines_from_state(state)
if len(engines) == 0:
raise ValueError(
"Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
" part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
" you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
)
elif len(engines) > 1:
raise ValueError(
"Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
" with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
" states, call the load method for each model checkpoint separately."
)
engine = engines[0]
        optimizer_state_requested = any(isinstance(item, Optimizer) for item in state.values())
torch.cuda.empty_cache()
_, client_state = engine.load_checkpoint(
path,
tag="checkpoint",
            load_optimizer_states=optimizer_state_requested,
load_lr_scheduler_states=False,
load_module_strict=True, # TODO(fabric): make strict loading configurable
)
if client_state is None:
raise RuntimeError(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
" or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
)
        for k in list(client_state):
            if k in state:
                state[k] = client_state.pop(k)
return client_state
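
(The matching load side, again an illustrative sketch rather than part of the diff, continuing from the save example above.)

# re-create and set up the objects first; DeepSpeed restores their state in place
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

state = {"model": model, "optimizer": optimizer, "steps": 0}
metadata = fabric.load("checkpoints/run-1", state)
# user entries such as "steps" are overwritten from the saved client state, and whatever
# DeepSpeed bookkeeping remains (e.g. "ds_version") is handed back as metadata
assert "ds_version" in metadata
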
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
@@ -645,3 +746,38 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
config = json.load(f)
assert isinstance(config, dict) or config is None
return config
def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["deepspeed.DeepSpeedEngine"]:
from deepspeed import DeepSpeedEngine
modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
engines = [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
return engines
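
(A sketch of what this helper sees, with `wrapped_model` standing in for the object returned by `fabric.setup`: the wrapper is an `nn.Module`, so its `.modules()` traversal yields the nested `DeepSpeedEngine`, while non-module values are skipped.)

state = {"model": wrapped_model, "optimizer": optimizer, "steps": 0}
engines = _get_deepspeed_engines_from_state(state)
assert len(engines) == 1  # "optimizer" and "steps" are not Modules and are ignored
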
def _validate_state_keys(state: Dict[str, Any]) -> None:
# DeepSpeed merges the client state into its internal engine state when saving, but it does not check for
# colliding keys from the user. We explicitly check it here:
deepspeed_internal_keys = {
"module",
"buffer_names",
"optimizer",
"param_shapes",
"lr_scheduler",
"sparse_tensor_module_names",
"skipped_steps",
"global_steps",
"global_samples",
"dp_world_size",
"mp_world_size",
"ds_config",
"ds_version",
}
colliding_keys = deepspeed_internal_keys.intersection(state.keys())
if colliding_keys:
rank_zero_warn(
"Your state has keys that collide with DeepSpeed's internal engine state. This could result in your"
" values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
+ ", ".join(colliding_keys)
)
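
(For example, hypothetical user code that names a state entry after one of DeepSpeed's internal keys would trigger this warning.)

state = {"model": model, "global_steps": 100}  # "global_steps" is a DeepSpeed-internal key
fabric.save("checkpoints/run-1", state)  # warns; rename the entry (e.g. "my_steps") to avoid it
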


@@ -20,6 +20,7 @@ from unittest.mock import ANY, Mock
import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from torch.optim import Optimizer
from lightning_fabric.accelerators import CPUAccelerator
from lightning_fabric.strategies import DeepSpeedStrategy
@@ -151,3 +152,172 @@ def test_deepspeed_requires_joint_setup():
NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
):
strategy.setup_optimizer(Mock())
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_storage_options(tmp_path):
"""Test that the DeepSpeed strategy does not accept storage options for saving checkpoints."""
strategy = DeepSpeedStrategy()
with pytest.raises(TypeError, match=escape("DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not")):
strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_one_deepspeed_engine_required(tmp_path):
"""Test that the DeepSpeed strategy can only save one DeepSpeedEngine per checkpoint."""
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy()
# missing DeepSpeedEngine
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.save_checkpoint(path=tmp_path, state={})
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
# multiple DeepSpeedEngine
model1 = Mock(spec=torch.nn.Module)
model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
model2 = Mock(spec=torch.nn.Module)
model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
"""Test that the DeepSpeed engine and optimizer get separated from the client state."""
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy()
# Model only
model = Mock(spec=DeepSpeedEngine, optimizer=None)
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
# Model and optimizer
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
"""Test that the strategy warns if there are keys in the user dict that collide internally with DeepSpeed."""
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy()
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
# `mp_world_size` is an internal key
with pytest.warns(UserWarning, match="Your state has keys that collide with DeepSpeed's internal"):
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_no_state(tmp_path):
"""Test that DeepSpeed can't load the full state without access to a model instance from the user."""
strategy = DeepSpeedStrategy()
with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state=None")):
strategy.load_checkpoint(path=tmp_path, state=None)
with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state={})")):
strategy.load_checkpoint(path=tmp_path, state={})
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(tmp_path):
"""Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy()
# missing DeepSpeedEngine
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
# multiple DeepSpeedEngine
model1 = Mock(spec=torch.nn.Module)
model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
model2 = Mock(spec=torch.nn.Module)
model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
"""Test that the DeepSpeed strategy raises a custom error when client state couldn't be loaded by DeepSpeed."""
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy()
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
    # If the DeepSpeed engine fails to load the checkpoint file (e.g., file not found), it emits a warning and
    # returns `None` as the client state
model.load_checkpoint.return_value = [None, None]
# Check for our custom user error
with pytest.raises(RuntimeError, match="DeepSpeed was unable to load the checkpoint"):
strategy.load_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
"""Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy()
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
    # the client state contains the additional user data that was provided when saving, plus some deepspeed metadata
loaded_client_state = {"user_data": {"iteration": 5}, "deepspeed_metadata": "data"}
model.load_checkpoint.return_value = [None, loaded_client_state]
state = {"model": model, "user_data": {"iteration": 0}}
metadata = strategy.load_checkpoint(path=tmp_path, state=state)
# the user's state gets updated with the loaded value
assert state == {"model": model, "user_data": {"iteration": 5}}
# additional metadata gets separated from client state
assert metadata == {"deepspeed_metadata": "data"}
@RunIf(deepspeed=True)
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_requested, tmp_path):
"""Test that the DeepSpeed strategy loads the optimizer state only when requested."""
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy()
optimizer = Mock(spec=Optimizer)
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
# required, otherwise mock cannot be unpacked
model.load_checkpoint.return_value = [None, {}]
state = {"model": model}
    if optimizer_state_requested:
state["optimizer"] = optimizer
strategy.load_checkpoint(path=tmp_path, state=state)
model.load_checkpoint.assert_called_with(
tmp_path,
tag="checkpoint",
        load_optimizer_states=optimizer_state_requested,
load_lr_scheduler_states=False,
load_module_strict=True,
)


@@ -241,7 +241,7 @@ class ModelParallelClassification(BoringFabric):
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3(tmpdir):
def test_deepspeed_multigpu_stage_3():
"""Test to ensure ZeRO Stage 3 works with a parallel model."""
fabric = ModelParallelClassification(
strategy=DeepSpeedStrategy(stage=3),
@@ -280,7 +280,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_specific_gpu_device_index(tmpdir):
def test_deepspeed_specific_gpu_device_index():
"""Test that the DeepSpeed strategy can run on specific device indices."""
class RunFabric(BoringFabric):
@@ -296,7 +296,7 @@ def test_deepspeed_specific_gpu_device_index(tmpdir):
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
def test_deepspeed_with_bfloat16_precision(tmpdir):
def test_deepspeed_with_bfloat16_precision():
"""Test that the DeepSpeed strategy works with bfloat16 precision."""
class Model(nn.Module):
@@ -323,3 +323,88 @@ def test_deepspeed_with_bfloat16_precision(tmpdir):
assert fabric._strategy.precision.precision == "bf16"
assert fabric._strategy.config["zero_optimization"]["stage"] == 3
fabric.run()
def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
"""Convert the saved checkpoint to a single file with the model weights consolidated to easily verify the full
weights in float32 precision."""
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
assert isinstance(fabric.strategy, DeepSpeedStrategy)
# carry out the check only on rank 0
if fabric.is_global_zero:
if fabric.strategy.config["zero_optimization"]["stage"] in (2, 3):
single_ckpt_path = checkpoint_path / "single_model.pt"
# the tag is hardcoded in DeepSpeedStrategy
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path, tag="checkpoint")
state_dict = torch.load(single_ckpt_path)
else:
# 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy
single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt"
state_dict = torch.load(single_ckpt_path)["module"]
model = model.cpu()
# assert model parameters are identical after loading
for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
# perform the equality check in the same precision
saved_model_param = saved_model_param.cpu().to(orig_param.dtype)
assert torch.equal(orig_param, saved_model_param)
fabric.barrier()
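
(As the helper above shows, a saved DeepSpeed checkpoint directory can also be consolidated offline into a single float32 file; a minimal sketch with illustrative paths, using the tag hardcoded by DeepSpeedStrategy.)

import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    "checkpoints/run-1",  # the checkpoint directory written by fabric.save()
    "checkpoints/consolidated.pt",  # single-file output with full fp32 weights
    tag="checkpoint",  # the tag DeepSpeedStrategy hardcodes
)
state_dict = torch.load("checkpoints/consolidated.pt")
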
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
@pytest.mark.parametrize("stage", [1, 2, 3])
def test_deepspeed_save_load_checkpoint(stage, tmp_path):
"""Test that DeepSpeed stage 1, 2, and 3 model checkpoints can be saved and loaded successfully."""
from deepspeed import DeepSpeedEngine
fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
fabric.launch()
checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")
with fabric.sharded_model():
model = BoringModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
model, optimizer = fabric.setup(model, optimizer)
assert isinstance(model._forward_module, DeepSpeedEngine)
# TODO(fabric): The dtype on the model is not correct, should be torch.bfloat16
assert model.dtype == torch.float32
assert next(model.parameters()).dtype == torch.bfloat16
# dummy training step
output = model(torch.randn(1, 32).to(fabric.device))
loss = output.sum()
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
state = {"model": model, "optimizer": optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
fabric.barrier()
# re-init all objects and resume
fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
fabric.launch()
with fabric.sharded_model():
model = BoringModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
model, optimizer = fabric.setup(model, optimizer)
state = {"model": model, "optimizer": optimizer, "steps": 0}
metadata = fabric.load(checkpoint_path, state)
fabric.barrier()
# check user data in state reloaded
assert state["steps"] == 1
# the remainder of the deepspeed checkpoint contains metadata
assert "ds_version" in metadata
_assert_saved_model_is_equal(fabric, model, checkpoint_path)