Fabric checkpointing 2/n: DeepSpeed implementation (#16452)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 4a802e00a8
commit 7603dd09cb
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for saving and loading DeepSpeed checkpoints through `Fabric.save/load()` ([#16452](https://github.com/Lightning-AI/lightning/pull/16452))

### Changed
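A minimal usage sketch of the new capability (not part of the diff; the module shape, paths, and hyperparameters are illustrative only):

```python
# Hedged sketch: save and load a DeepSpeed checkpoint through Fabric.save/load().
# Assumes a CUDA machine with 2 GPUs and the deepspeed package installed.
import torch
from lightning_fabric import Fabric
from lightning_fabric.strategies import DeepSpeedStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=3))
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# the model and optimizer must be set up through the strategy before saving or loading
model, optimizer = fabric.setup(model, optimizer)

state = {"model": model, "optimizer": optimizer, "iteration": 0}
fabric.save("checkpoints/run-1", state)             # writes a DeepSpeed checkpoint directory
metadata = fabric.load("checkpoints/run-1", state)  # restores the objects in `state` in place
```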
@@ -17,8 +17,9 @@ import logging
import os
import platform
from contextlib import contextmanager
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.imports import RequirementCache
@@ -31,7 +32,7 @@ from lightning_fabric.plugins.precision import Precision
from lightning_fabric.strategies.ddp import DDPStrategy
from lightning_fabric.strategies.strategy import _Sharded
from lightning_fabric.utilities.distributed import log
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import _PATH
@@ -365,24 +366,124 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
    def save_checkpoint(
        self, path: _PATH, state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None
    ) -> None:
        raise NotImplementedError
        """Save model, optimizer, and other state in a checkpoint directory.

        Args:
            path: A path to where the files should be saved
            state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
                state-dict will be retrieved and converted automatically.
            storage_options: Unused by this strategy, since it doesn't use a ``CheckpointIO`` plugin.

        Raises:
            TypeError:
                If the unused ``storage_options`` gets passed.
            ValueError:
                When no :class:`deepspeed.DeepSpeedEngine` objects were found in the state, or when multiple
                :class:`deepspeed.DeepSpeedEngine` objects were found.
        """
        if storage_options is not None:
            raise TypeError(
                "`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
                " `DeepSpeedStrategy` does not use the `CheckpointIO`."
            )

        engines = _get_deepspeed_engines_from_state(state)
        if len(engines) == 0:
            raise ValueError(
                "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
                " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
                " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
            )
        elif len(engines) > 1:
            raise ValueError(
                "Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
                " currently limited to a single model per checkpoint. To save multiple models, call the"
                " save method for each model separately with a different path."
            )
        engine = engines[0]

        # broadcast the path from rank 0 to ensure all the states are saved in a common path
        path = self.broadcast(path)

        # split the checkpoint into two parts:
        # 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
        # 2) the rest of the user's state, which in deepspeed is called `client state`
        excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
        state = {k: v for k, v in state.items() if v not in excluded_objects}
        _validate_state_keys(state)
        # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
        state = self._convert_stateful_objects_in_state(state)
        # use deepspeed's internal checkpointing function to handle partitioned weights across processes
        engine.save_checkpoint(path, client_state=state, tag="checkpoint")

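For orientation, a hedged sketch (not part of the diff) of the client-state split performed by `save_checkpoint` above; `model`, `optimizer`, `strategy`, and the paths are placeholders:

```python
# Assumes `model` and `optimizer` were set up through DeepSpeedStrategy, so `model`
# is backed by a deepspeed.DeepSpeedEngine and `strategy` is the active DeepSpeedStrategy.
state = {"model": model, "optimizer": optimizer, "iteration": 10}
strategy.save_checkpoint(path="checkpoints/run-1", state=state)

# The engine and its optimizer are excluded from the client state, so DeepSpeed writes the
# partitioned weight/optimizer shards under the hardcoded tag (e.g.
# checkpoints/run-1/checkpoint/mp_rank_00_model_states.pt, depending on the ZeRO stage),
# while {"iteration": 10} is stored as client state next to DeepSpeed's own metadata.
```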
    def load_checkpoint(
        self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
    ) -> Dict[str, Any]:
        raise NotImplementedError
        """Load the contents from a checkpoint and restore the state of the given objects.

    def load_optimizer_state_dict(
        self, optimizers: Union[Optimizer, Iterable[Optimizer]], checkpoint: Mapping[str, Any]
    ) -> None:
        # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
        pass

        Args:
            path: A path to where the file is located
            state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                This should contain exactly one model, and the model must already be set up by DeepSpeed.

    def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any]) -> None:
        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`

        Returns:
            Dictionary with the state inside DeepSpeed's engine

        Raises:
            ValueError:
                If no state is provided, when no :class:`deepspeed.DeepSpeedEngine` objects were found in the
                state, or when multiple :class:`deepspeed.DeepSpeedEngine` objects were found.
            RuntimeError:
                If DeepSpeed was unable to load the checkpoint due to missing files or because the checkpoint is
                not in the expected DeepSpeed format.
        """
        if self.load_full_weights and self.zero_stage_3:
            self.module_to_device(module)
            self._restore_zero_state(module, checkpoint)
            # This code path enables loading a checkpoint from a non-deepspeed checkpoint or from
            # a consolidated checkpoint
            path = self.broadcast(path)
            return super().load_checkpoint(path=path, state=state)

        if not state:
            raise ValueError(
                f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
                f" a model instance to reload is required. Pass it in like so:"
                " DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
            )

        engines = _get_deepspeed_engines_from_state(state)
        if len(engines) == 0:
            raise ValueError(
                "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
                " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
                " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
            )
        elif len(engines) > 1:
            raise ValueError(
                "Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
                " with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
                " states, call the load method for each model checkpoint separately."
            )
        engine = engines[0]
        optimizer_state_requested = any(isinstance(item, Optimizer) for item in state.values())

        torch.cuda.empty_cache()
        _, client_state = engine.load_checkpoint(
            path,
            tag="checkpoint",
            load_optimizer_states=optimizer_state_requested,
            load_lr_scheduler_states=False,
            load_module_strict=True,  # TODO(fabric): make strict loading configurable
        )
        if client_state is None:
            raise RuntimeError(
                "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
                " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
            )
        for k, v in client_state.copy().items():
            if k not in state:
                continue
            state[k] = client_state.pop(k)
        return client_state

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:

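A hedged sketch (not part of the diff) of the restore contract documented above; all names and paths are placeholders:

```python
# Assumes the same setup as the save sketch: `model` is a DeepSpeed engine set up through
# `strategy`, and "checkpoints/run-1" holds a checkpoint previously saved by this strategy.
state = {"model": model, "optimizer": optimizer, "iteration": 0}
remainder = strategy.load_checkpoint(path="checkpoints/run-1", state=state)

print(state["iteration"])           # user value restored in place from the client state
print(remainder.get("ds_version"))  # leftover DeepSpeed metadata not matched by `state`
```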
@@ -645,3 +746,38 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
                config = json.load(f)
        assert isinstance(config, dict) or config is None
        return config


def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["deepspeed.DeepSpeedEngine"]:
    from deepspeed import DeepSpeedEngine

    modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
    engines = [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
    return engines


def _validate_state_keys(state: Dict[str, Any]) -> None:
    # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for
    # colliding keys from the user. We explicitly check it here:
    deepspeed_internal_keys = {
        "module",
        "buffer_names",
        "optimizer",
        "param_shapes",
        "lr_scheduler",
        "sparse_tensor_module_names",
        "skipped_steps",
        "global_steps",
        "global_samples",
        "dp_world_size",
        "mp_world_size",
        "ds_config",
        "ds_version",
    }
    colliding_keys = deepspeed_internal_keys.intersection(state.keys())
    if colliding_keys:
        rank_zero_warn(
            "Your state has keys that collide with DeepSpeed's internal engine state. This could result in your"
            " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
            + ", ".join(colliding_keys)
        )

@@ -20,6 +20,7 @@ from unittest.mock import ANY, Mock
import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from torch.optim import Optimizer

from lightning_fabric.accelerators import CPUAccelerator
from lightning_fabric.strategies import DeepSpeedStrategy
@@ -151,3 +152,172 @@ def test_deepspeed_requires_joint_setup():
        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
    ):
        strategy.setup_optimizer(Mock())


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_storage_options(tmp_path):
    """Test that the DeepSpeed strategy does not accept storage options for saving checkpoints."""
    strategy = DeepSpeedStrategy()
    with pytest.raises(TypeError, match=escape("DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not")):
        strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())

@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_one_deepspeed_engine_required(tmp_path):
    """Test that the DeepSpeed strategy can only save one DeepSpeedEngine per checkpoint."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    # missing DeepSpeedEngine
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.save_checkpoint(path=tmp_path, state={})
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

    # multiple DeepSpeedEngine
    model1 = Mock(spec=torch.nn.Module)
    model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    model2 = Mock(spec=torch.nn.Module)
    model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
        strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
    """Test that the DeepSpeed engine and optimizer get separated from the client state."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    # Model only
    model = Mock(spec=DeepSpeedEngine, optimizer=None)
    model.modules.return_value = [model]
    strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
    # the client_state should not contain any deepspeed engine or deepspeed optimizer
    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")

    # Model and optimizer
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]
    strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
    # the client_state should not contain any deepspeed engine or deepspeed optimizer
    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
    """Test that the strategy warns if there are keys in the user dict that collide internally with DeepSpeed."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]
    # `mp_world_size` is an internal key
    with pytest.warns(UserWarning, match="Your state has keys that collide with DeepSpeed's internal"):
        strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})

@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_no_state(tmp_path):
    """Test that DeepSpeed can't load the full state without access to a model instance from the user."""
    strategy = DeepSpeedStrategy()
    with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state=None")):
        strategy.load_checkpoint(path=tmp_path, state=None)
    with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state={})")):
        strategy.load_checkpoint(path=tmp_path, state={})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(tmp_path):
    """Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    # missing DeepSpeedEngine
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
        strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

    # multiple DeepSpeedEngine
    model1 = Mock(spec=torch.nn.Module)
    model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    model2 = Mock(spec=torch.nn.Module)
    model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
    with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
        strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
    """Test that the DeepSpeed strategy raises a custom error when client state couldn't be loaded by DeepSpeed."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]

    # If the DeepSpeed engine fails to load the checkpoint file (e.g., file not found), it prints a warning and
    # returns None from its function call
    model.load_checkpoint.return_value = [None, None]

    # Check for our custom user error
    with pytest.raises(RuntimeError, match="DeepSpeed was unable to load the checkpoint"):
        strategy.load_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})

@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
    """Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    optimizer = Mock()
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]

    # the client state contains the additional user data that was provided when saving, plus some deepspeed metadata
    loaded_client_state = {"user_data": {"iteration": 5}, "deepspeed_metadata": "data"}
    model.load_checkpoint.return_value = [None, loaded_client_state]

    state = {"model": model, "user_data": {"iteration": 0}}
    metadata = strategy.load_checkpoint(path=tmp_path, state=state)

    # the user's state gets updated with the loaded value
    assert state == {"model": model, "user_data": {"iteration": 5}}
    # additional metadata gets separated from client state
    assert metadata == {"deepspeed_metadata": "data"}


@RunIf(deepspeed=True)
@pytest.mark.parametrize("optimizer_state_requested", [True, False])
def test_deepspeed_load_checkpoint_optimizer_state_requested(optimizer_state_requested, tmp_path):
    """Test that the DeepSpeed strategy loads the optimizer state only when requested."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    optimizer = Mock(spec=Optimizer)
    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
    model.modules.return_value = [model]

    # required, otherwise mock cannot be unpacked
    model.load_checkpoint.return_value = [None, {}]

    state = {"model": model}
    if optimizer_state_requested:
        state["optimizer"] = optimizer

    strategy.load_checkpoint(path=tmp_path, state=state)
    model.load_checkpoint.assert_called_with(
        tmp_path,
        tag="checkpoint",
        load_optimizer_states=optimizer_state_requested,
        load_lr_scheduler_states=False,
        load_module_strict=True,
    )

@@ -241,7 +241,7 @@ class ModelParallelClassification(BoringFabric):


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3(tmpdir):
def test_deepspeed_multigpu_stage_3():
    """Test to ensure ZeRO Stage 3 works with a parallel model."""
    fabric = ModelParallelClassification(
        strategy=DeepSpeedStrategy(stage=3),
@@ -280,7 +280,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_specific_gpu_device_index(tmpdir):
def test_deepspeed_specific_gpu_device_index():
    """Test that the DeepSpeed strategy can run on specific device indices."""

    class RunFabric(BoringFabric):
@@ -296,7 +296,7 @@ def test_deepspeed_specific_gpu_device_index(tmpdir):


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
def test_deepspeed_with_bfloat16_precision(tmpdir):
def test_deepspeed_with_bfloat16_precision():
    """Test that the DeepSpeed strategy works with bfloat16 precision."""

    class Model(nn.Module):
@@ -323,3 +323,88 @@ def test_deepspeed_with_bfloat16_precision(tmpdir):
    assert fabric._strategy.precision.precision == "bf16"
    assert fabric._strategy.config["zero_optimization"]["stage"] == 3
    fabric.run()


def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
    """Convert the saved checkpoint to a single file with the model weights consolidated to easily verify the full
    weights in float32 precision."""
    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    assert isinstance(fabric.strategy, DeepSpeedStrategy)

    # carry out the check only on rank 0
    if fabric.is_global_zero:
        if fabric.strategy.config["zero_optimization"]["stage"] in (2, 3):
            single_ckpt_path = checkpoint_path / "single_model.pt"
            # the tag is hardcoded in DeepSpeedStrategy
            convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path, tag="checkpoint")
            state_dict = torch.load(single_ckpt_path)
        else:
            # 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy
            single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt"
            state_dict = torch.load(single_ckpt_path)["module"]

        model = model.cpu()

        # assert model parameters are identical after loading
        for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
            # perform the equality check in the same precision
            saved_model_param = saved_model_param.cpu().to(orig_param.dtype)
            assert torch.equal(orig_param, saved_model_param)

    fabric.barrier()

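Outside the test suite, the same DeepSpeed utility can consolidate a checkpoint saved by Fabric; a hedged sketch with illustrative paths (the tag is the one hardcoded by `DeepSpeedStrategy`):

```python
# Convert a sharded ZeRO checkpoint directory written by fabric.save() into a single
# fp32 state dict usable with plain torch.load()/load_state_dict().
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    "checkpoints/run-1",           # directory passed to fabric.save()
    "checkpoints/run-1/model.pt",  # consolidated output file (name is illustrative)
    tag="checkpoint",              # tag hardcoded in DeepSpeedStrategy
)
```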
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
@pytest.mark.parametrize("stage", [1, 2, 3])
def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
    """Test that DeepSpeed stage 1, 2, and 3 model checkpoints can be saved and loaded successfully."""
    from deepspeed import DeepSpeedEngine

    fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
    fabric.launch()

    checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")

    with fabric.sharded_model():
        model = BoringModel()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    model, optimizer = fabric.setup(model, optimizer)
    assert isinstance(model._forward_module, DeepSpeedEngine)

    # TODO(fabric): The dtype on the model is not correct, should be torch.bfloat16
    assert model.dtype == torch.float32
    assert next(model.parameters()).dtype == torch.bfloat16

    # dummy training step
    output = model(torch.randn(1, 32).to(fabric.device))
    loss = output.sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    state = {"model": model, "optimizer": optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    fabric.barrier()

    # re-init all objects and resume
    fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
    fabric.launch()
    with fabric.sharded_model():
        model = BoringModel()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    model, optimizer = fabric.setup(model, optimizer)
    state = {"model": model, "optimizer": optimizer, "steps": 0}

    metadata = fabric.load(checkpoint_path, state)
    fabric.barrier()

    # check user data in state reloaded
    assert state["steps"] == 1
    # the remainder of the deepspeed checkpoint contains metadata
    assert "ds_version" in metadata

    _assert_saved_model_is_equal(fabric, model, checkpoint_path)