324 lines
13 KiB
Python
324 lines
13 KiB
Python
# Copyright The Lightning AI team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import json
|
|
import os
|
|
from re import escape
|
|
from unittest import mock
|
|
from unittest.mock import ANY, Mock
|
|
|
|
import pytest
|
|
import torch
|
|
from tests_fabric.helpers.runif import RunIf
|
|
from torch.optim import Optimizer
|
|
|
|
from lightning.fabric.accelerators import CPUAccelerator
|
|
from lightning.fabric.strategies import DeepSpeedStrategy
|
|
|
|
|
|
@pytest.fixture
def deepspeed_config():
    """Return a minimal DeepSpeed config with an SGD optimizer and a WarmupLR scheduler.

    A fresh dict is built on every call, so tests may mutate it freely.
    """
    optimizer_section = {"type": "SGD", "params": {"lr": 3e-5}}
    warmup_params = {
        "last_batch_iteration": -1,
        "warmup_min_lr": 0,
        "warmup_max_lr": 3e-5,
        "warmup_num_steps": 100,
    }
    scheduler_section = {"type": "WarmupLR", "params": warmup_params}
    return {"optimizer": optimizer_section, "scheduler": scheduler_section}
|
|
|
|
|
|
@pytest.fixture
def deepspeed_zero_config(deepspeed_config):
    """Return ``deepspeed_config`` extended with ZeRO stage-2 settings.

    The top level is a shallow copy of the base config; the ``zero_optimization`` section is a new dict.
    """
    zero_config = dict(deepspeed_config)
    zero_config["zero_allow_untested_optimizer"] = True
    zero_config["zero_optimization"] = {"stage": 2}
    return zero_config
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_only_compatible_with_cuda():
    """Setting up the environment with a CPU accelerator must fail, since the strategy requires CUDA."""
    cpu_strategy = DeepSpeedStrategy(accelerator=CPUAccelerator())
    expected_message = "The DeepSpeed strategy is only supported on CUDA GPUs"
    with pytest.raises(RuntimeError, match=expected_message):
        cpu_strategy.setup_environment()
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_with_invalid_config_path():
    """A config path that does not exist is rejected when the strategy is constructed."""
    missing_path = "invalid_path.json"
    expected_message = "You passed in a path to a DeepSpeed config but the path does not exist"
    with pytest.raises(FileNotFoundError, match=expected_message):
        DeepSpeedStrategy(config=missing_path)
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config):
    """Without an explicit config, the strategy reads the JSON file named by ``PL_DEEPSPEED_CONFIG_PATH``."""
    config_file = os.path.join(tmpdir, "temp.json")
    with open(config_file, "w") as handle:
        json.dump(deepspeed_config, handle)

    monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_file)

    assert DeepSpeedStrategy().config == deepspeed_config
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_defaults():
    """With no arguments, a default config containing a ZeRO section is built and backward-sync control is off."""
    default_strategy = DeepSpeedStrategy()
    built_config = default_strategy.config

    assert built_config is not None
    assert isinstance(built_config["zero_optimization"], dict)
    assert default_strategy._backward_sync_control is None
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_custom_activation_checkpointing_params(tmpdir):
    """Activation-checkpointing flags passed to the constructor show up in the generated config."""
    flag_names = (
        "partition_activations",
        "cpu_checkpointing",
        "contiguous_memory_optimization",
        "synchronize_checkpoint_boundary",
    )
    ds = DeepSpeedStrategy(**{flag: True for flag in flag_names})

    checkpoint_section = ds.config["activation_checkpointing"]
    for flag in flag_names:
        assert checkpoint_section[flag], flag
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_config_zero_offload(deepspeed_zero_config):
    """Check the different ways in which optimizer offloading can be configured."""
    # A user-provided ZeRO config without any offload settings does not get them added
    from_user_config = DeepSpeedStrategy(config=deepspeed_zero_config)
    assert "offload_optimizer" not in from_user_config.config["zero_optimization"]

    # The built-in default config does not enable offloading either
    from_defaults = DeepSpeedStrategy()
    assert "offload_optimizer" not in from_defaults.config["zero_optimization"]

    # Requesting offloading through the constructor argument fills in these offload settings
    expected_offload = {
        "buffer_count": 4,
        "device": "cpu",
        "nvme_path": "/local_nvme",
        "pin_memory": False,
    }
    from_argument = DeepSpeedStrategy(offload_optimizer=True)
    assert from_argument.config["zero_optimization"]["offload_optimizer"] == expected_offload

    # A value placed directly in the user config is preserved as given
    deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False
    from_explicit_config = DeepSpeedStrategy(config=deepspeed_zero_config)
    assert from_explicit_config.config["zero_optimization"]["offload_optimizer"] is False
|
|
|
|
|
|
@RunIf(deepspeed=True)
@mock.patch("deepspeed.initialize")
def test_deepspeed_setup_module(init_mock):
    """Setting up only a module (e.g. for inference) calls ``deepspeed.initialize`` with ``optimizer=None``."""
    module = Mock()
    module.parameters.return_value = []

    strategy = DeepSpeedStrategy()
    strategy.parallel_devices = [torch.device("cuda", 1)]

    # The patched initializer must return something the strategy can unpack into four values;
    # the same mock object is reused for every slot
    placeholder = Mock()
    init_mock.return_value = [placeholder] * 4

    strategy.setup_module(module)

    expected_kwargs = dict(
        args=ANY,
        config=strategy.config,
        model=module,
        model_parameters=ANY,
        optimizer=None,
        dist_init_required=False,
    )
    init_mock.assert_called_with(**expected_kwargs)
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_requires_joint_setup():
    """Setting up an optimizer on its own, separately from the module, is rejected."""
    strategy = DeepSpeedStrategy()
    expected_message = escape("does not support setting up the module and optimizer(s) independently")
    with pytest.raises(NotImplementedError, match=expected_message):
        strategy.setup_optimizer(Mock())
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_storage_options(tmp_path):
    """Passing ``storage_options`` to ``save_checkpoint`` raises a ``TypeError``."""
    strategy = DeepSpeedStrategy()
    expected_message = escape("DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not")
    with pytest.raises(TypeError, match=expected_message):
        strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_one_deepspeed_engine_required(tmp_path):
    """Saving requires exactly one DeepSpeedEngine in the state; zero or several engines are errors."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    def _assert_save_fails(state, message):
        # Saving the given state must raise a ValueError matching `message`
        with pytest.raises(ValueError, match=message):
            strategy.save_checkpoint(path=tmp_path, state=state)

    def _module_wrapping_engine():
        # A torch module mock whose submodules include a DeepSpeedEngine mock
        wrapper = Mock(spec=torch.nn.Module)
        wrapper.modules.return_value = [Mock(spec=DeepSpeedEngine)]
        return wrapper

    # No engine at all: an empty state and a state holding only a plain torch module
    missing_message = "Could not find a DeepSpeed model in the provided checkpoint state."
    _assert_save_fails({}, missing_message)
    _assert_save_fails({"model": torch.nn.Linear(3, 3)}, missing_message)

    # Two modules that each contain an engine
    _assert_save_fails(
        {"model1": _module_wrapping_engine(), "model2": _module_wrapping_engine()},
        "Found multiple DeepSpeed engine modules in the given state.",
    )
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
    """Engine and optimizer are stripped from the state before it is handed to DeepSpeed as ``client_state``."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    def _engine_mock(attached_optimizer):
        # A DeepSpeedEngine mock that reports itself as its only module
        mocked = Mock(spec=DeepSpeedEngine, optimizer=attached_optimizer)
        mocked.modules.return_value = [mocked]
        return mocked

    # Engine without an optimizer, plus user data
    engine = _engine_mock(None)
    strategy.save_checkpoint(path=tmp_path, state={"model": engine, "test": "data"})
    # Only the user data may reach DeepSpeed's client_state
    engine.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")

    # Engine with its optimizer, both present in the state, plus user data
    opt = Mock()
    engine = _engine_mock(opt)
    strategy.save_checkpoint(path=tmp_path, state={"model": engine, "optimizer": opt, "test": "data"})
    # Neither the engine nor its optimizer may reach DeepSpeed's client_state
    engine.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
    """A warning is emitted when user state keys clash with keys DeepSpeed uses internally."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    opt = Mock()
    engine = Mock(spec=DeepSpeedEngine, optimizer=opt)
    engine.modules.return_value = [engine]

    # `mp_world_size` is one of DeepSpeed's internal keys
    colliding_state = {"model": engine, "optimizer": opt, "mp_world_size": 2}
    with pytest.warns(UserWarning, match="Your state has keys that collide with DeepSpeed's internal"):
        strategy.save_checkpoint(path=tmp_path, state=colliding_state)
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_no_state(tmp_path):
    """Loading needs the user's model in ``state``; ``None`` or an empty dict is rejected."""
    strategy = DeepSpeedStrategy()
    rejected_cases = (
        (None, "Got DeepSpeedStrategy.load_checkpoint(..., state=None"),
        ({}, "Got DeepSpeedStrategy.load_checkpoint(..., state={})"),
    )
    for bad_state, message in rejected_cases:
        with pytest.raises(ValueError, match=escape(message)):
            strategy.load_checkpoint(path=tmp_path, state=bad_state)
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(tmp_path):
    """Loading requires exactly one DeepSpeedEngine in the state; zero or several engines are errors."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()

    def _assert_load_fails(state, message):
        # Loading into the given state must raise a ValueError matching `message`
        with pytest.raises(ValueError, match=message):
            strategy.load_checkpoint(path=tmp_path, state=state)

    def _module_wrapping_engine():
        # A torch module mock whose submodules include a DeepSpeedEngine mock
        wrapper = Mock(spec=torch.nn.Module)
        wrapper.modules.return_value = [Mock(spec=DeepSpeedEngine)]
        return wrapper

    # No engine at all: plain user data, or only a plain torch module
    missing_message = "Could not find a DeepSpeed model in the provided checkpoint state."
    _assert_load_fails({"other": "data"}, missing_message)
    _assert_load_fails({"model": torch.nn.Linear(3, 3)}, missing_message)

    # Two modules that each contain an engine
    _assert_load_fails(
        {"model1": _module_wrapping_engine(), "model2": _module_wrapping_engine()},
        "Found multiple DeepSpeed engine modules in the given state.",
    )
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
    """A custom error is raised when DeepSpeed could not load the client state."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    opt = Mock()
    engine = Mock(spec=DeepSpeedEngine, optimizer=opt)
    engine.modules.return_value = [engine]

    # When the DeepSpeed engine fails to load the checkpoint (e.g. the file is not found), it prints a
    # warning instead of raising; simulate that outcome with a pair of `None` values
    engine.load_checkpoint.return_value = [None, None]

    # The strategy must turn this silent failure into its own user-facing error
    state = {"model": engine, "optimizer": opt, "test": "data"}
    with pytest.raises(RuntimeError, match="DeepSpeed was unable to load the checkpoint"):
        strategy.load_checkpoint(path=tmp_path, state=state)
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
    """Loaded values update the user's state in place, and the remaining entries are returned as metadata."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    opt = Mock()
    engine = Mock(spec=DeepSpeedEngine, optimizer=opt)
    engine.modules.return_value = [engine]

    # The client state holds the user data provided at save time, plus some DeepSpeed metadata
    engine.load_checkpoint.return_value = [
        None,
        {"user_data": {"iteration": 5}, "deepspeed_metadata": "data"},
    ]

    user_state = {"model": engine, "user_data": {"iteration": 0}}
    returned_metadata = strategy.load_checkpoint(path=tmp_path, state=user_state)

    # The user's entry is overwritten with the loaded value...
    assert user_state == {"model": engine, "user_data": {"iteration": 5}}
    # ...and the extra DeepSpeed metadata is split off and returned separately
    assert returned_metadata == {"deepspeed_metadata": "data"}
|
|
|
|
|
|
@RunIf(deepspeed=True)
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_requested, tmp_path):
    """Optimizer states are loaded only if the user's state contains the optimizer."""
    from deepspeed import DeepSpeedEngine

    strategy = DeepSpeedStrategy()
    opt = Mock(spec=Optimizer)
    engine = Mock(spec=DeepSpeedEngine, optimizer=opt)
    engine.modules.return_value = [engine]
    # An explicit two-element return value is needed, otherwise the mock result cannot be unpacked
    engine.load_checkpoint.return_value = [None, {}]

    # Including the optimizer in the state is how the user requests its state to be restored
    state = {"model": engine, "optimizer": opt} if optimzer_state_requested else {"model": engine}
    strategy.load_checkpoint(path=tmp_path, state=state)

    engine.load_checkpoint.assert_called_with(
        tmp_path,
        tag="checkpoint",
        load_optimizer_states=optimzer_state_requested,
        load_lr_scheduler_states=False,
        load_module_strict=True,
    )
|