2022-09-29 16:39:32 +00:00
|
|
|
# Copyright The Lightning AI team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
import os
|
|
|
|
from copy import deepcopy
|
|
|
|
from unittest import mock
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
2023-02-01 20:34:38 +00:00
|
|
|
from lightning.fabric import Fabric
|
|
|
|
from lightning.fabric.plugins import DeepSpeedPrecision
|
|
|
|
from lightning.fabric.strategies import DeepSpeedStrategy
|
2023-03-03 16:55:48 +00:00
|
|
|
from tests_fabric.helpers.models import BoringFabric, RandomDataset, RandomIterableDataset
|
|
|
|
from tests_fabric.helpers.runif import RunIf
|
|
|
|
from tests_fabric.test_fabric import BoringModel
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multiple_models():
    """Check that several models can be set up and trained under a single ZeRO stage 3 Fabric instance.

    Verifies that a model's weights change after training, that two identically seeded models trained on the same
    batches end up with matching weights, and that ``all_gather``/``broadcast`` still work afterwards.
    """
    fabric = Fabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu")
    fabric.launch()

    # --- a single model: the weights must differ between the first and the second step ---
    with fabric.init_module():
        net = BoringModel()
    opt = torch.optim.SGD(net.parameters(), lr=0.0001)
    net, opt = fabric.setup(net, opt)

    weights_after_first_backward = None
    for step_idx in range(2):
        first_step = step_idx == 0
        opt.zero_grad()
        loss = net(torch.randn(1, 32).to(fabric.device)).sum()
        if first_step:
            # the weights are not initialized with stage 3 until backward is run once
            assert all(tensor.nelement() == 0 for tensor in net.state_dict().values())
        fabric.backward(loss, model=net)
        if first_step:
            # keep a copy to compare against after the optimizer has stepped
            weights_after_first_backward = deepcopy(net.state_dict())
        opt.step()

    for before, after in zip(weights_after_first_backward.values(), net.state_dict().values()):
        assert not torch.allclose(before, after)

    # --- two identically seeded models must start (and, on identical data, finish) with identical weights ---
    def seeded_model_and_optimizer():
        fabric.seed_everything(42)
        module = BoringModel()
        return module, torch.optim.SGD(module.parameters(), lr=0.0001)

    model_1, optimizer_1 = seeded_model_and_optimizer()
    model_2, optimizer_2 = seeded_model_and_optimizer()

    for w_1, w_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
        assert torch.allclose(w_1, w_2)

    model_1, optimizer_1 = fabric.setup(model_1, optimizer_1)
    model_2, optimizer_2 = fabric.setup(model_2, optimizer_2)

    # train model_1 first, recording every batch it sees
    fabric.seed_everything(42)
    seen_batches = []
    for _ in range(2):
        optimizer_1.zero_grad()
        batch = torch.randn(1, 32).to(fabric.device)
        seen_batches.append(batch)
        fabric.backward(model_1(batch).sum(), model=model_1)
        optimizer_1.step()

    # model_1's state dict now holds non-trivial tensors, while model_2's are still empty
    assert all(tensor.nelement() > 1 for tensor in model_1.state_dict().values())
    assert all(tensor.nelement() == 0 for tensor in model_2.state_dict().values())

    # replay exactly the same batches on model_2
    for batch in seen_batches:
        optimizer_2.zero_grad()
        fabric.backward(model_2(batch).sum(), model=model_2)
        optimizer_2.step()

    for w_1, w_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
        assert torch.allclose(w_1, w_2)

    # --- collectives ---
    gathered_ranks = fabric.all_gather(torch.tensor([fabric.local_rank]).to(fabric.device))
    assert torch.allclose(gathered_ranks.cpu(), torch.tensor([[0], [1]]))
    assert fabric.broadcast(True)
    assert fabric.is_global_zero == (fabric.local_rank == 0)
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=1, deepspeed=True)
@pytest.mark.parametrize(
    ("dataset_cls", "logging_batch_size_per_gpu", "expected_batch_size"),
    [
        (RandomDataset, None, 1),
        (RandomDataset, 10, 10),
        (RandomIterableDataset, None, 1),
        (RandomIterableDataset, 10, 10),
    ],
)
def test_deepspeed_auto_batch_size_config_select(dataset_cls, logging_batch_size_per_gpu, expected_batch_size):
    """Check the ``train_micro_batch_size_per_gpu`` value DeepSpeed is configured with for logging purposes.

    An explicit ``logging_batch_size_per_gpu`` is expected to be used as-is. Without it, the expected value is 1,
    which is also the batch size of the default ``DataLoader`` built below.
    """
    strategy = DeepSpeedStrategy(logging_batch_size_per_gpu=logging_batch_size_per_gpu, zero_optimization=False)
    fabric = Fabric(accelerator="cuda", devices=1, strategy=strategy)
    fabric.launch()
    assert isinstance(fabric._strategy, DeepSpeedStrategy)

    # the returned dataloader is not needed; only the side effect on the strategy config matters
    fabric.setup_dataloaders(DataLoader(dataset_cls(32, 64)))
    assert fabric._strategy.config["train_micro_batch_size_per_gpu"] == expected_batch_size
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_deepspeed_configure_optimizers():
    """Check that, with default settings, setup wraps the user's optimizer inside a DeepSpeed ZeRO optimizer."""
    from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

    fabric = Fabric(strategy=DeepSpeedStrategy(), accelerator="cuda", devices=1, precision="16-mixed")
    fabric.launch()

    layer = nn.Linear(3, 3)
    sgd = torch.optim.SGD(layer.parameters(), lr=0.01)
    _, wrapped_optimizer = fabric.setup(layer, sgd)

    # Fabric wrapper -> DeepSpeed ZeRO optimizer -> the original torch optimizer
    zero_optimizer = wrapped_optimizer.optimizer
    assert isinstance(zero_optimizer, DeepSpeedZeroOptimizer)
    assert isinstance(zero_optimizer.optimizer, torch.optim.SGD)
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=1, deepspeed=True)
def test_deepspeed_custom_precision_params():
    """Check that FP16 loss-scaling options given to ``DeepSpeedStrategy`` end up in the ``fp16`` config section."""
    # each key is both the DeepSpeedStrategy keyword argument and the key expected under config["fp16"]
    fp16_settings = {
        "loss_scale": 10,
        "initial_scale_power": 11,
        "loss_scale_window": 12,
        "hysteresis": 13,
        "min_loss_scale": 14,
    }
    fabric = Fabric(
        strategy=DeepSpeedStrategy(**fp16_settings),
        precision="16-mixed",
        accelerator="cuda",
        devices=1,
    )
    fabric.launch()

    assert fabric._strategy._config_initialized
    fp16_config = fabric._strategy.config["fp16"]
    for option, expected in fp16_settings.items():
        assert fp16_config[option] == expected
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_deepspeed_custom_activation_checkpointing_params_forwarded():
    """Check that activation-checkpointing options reach ``deepspeed.checkpointing.configure`` with the right names."""
    import deepspeed

    checkpointing_strategy = DeepSpeedStrategy(
        partition_activations=True,
        cpu_checkpointing=True,
        contiguous_memory_optimization=True,
        synchronize_checkpoint_boundary=True,
    )
    fabric = Fabric(strategy=checkpointing_strategy, precision="16-mixed", accelerator="cuda", devices=1)
    fabric.launch()

    layer = nn.Linear(3, 3)
    adam = torch.optim.Adam(layer.parameters())

    # spy on the real function so that setup still performs the actual configuration
    real_configure = deepspeed.checkpointing.configure
    with mock.patch("deepspeed.checkpointing.configure", wraps=real_configure) as configure_spy:
        fabric.setup(layer, adam)

    # cpu_checkpointing -> checkpoint_in_cpu, contiguous_memory_optimization -> contiguous_checkpointing
    configure_spy.assert_called_with(
        mpu_=None,
        partition_activations=True,
        contiguous_checkpointing=True,
        checkpoint_in_cpu=True,
        profile=None,
    )
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
class ModelParallelClassification(BoringFabric):
    """Fabric test harness that trains a stack of linear blocks as a 3-class classifier."""

    # number of hidden 32 -> 32 blocks placed before the classification head
    num_blocks = 5

    def get_model(self):
        """Return ``num_blocks`` hidden blocks followed by a 32 -> 3 linear classification head."""
        return nn.Sequential(*(self._make_block() for _ in range(self.num_blocks)), nn.Linear(32, 3))

    def step(self, model, batch):
        """Return the cross-entropy loss of ``model`` on ``batch`` against an all-ones class target."""
        target = torch.ones(batch.size(0), device=batch.device, dtype=torch.long)
        logits = model(batch)
        # Upcast to float32 before the loss, since under 16-bit mixed precision the output may be half precision.
        # ``F.cross_entropy`` applies log-softmax itself and therefore expects raw logits: applying ``F.softmax``
        # first (as this method previously did) would apply softmax twice and shrink the loss and its gradients.
        return F.cross_entropy(logits.float(), target)

    def _make_block(self):
        """Return one hidden block: a bias-free 32 -> 32 linear layer followed by ReLU."""
        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3():
    """Check that ZeRO stage 3 can train the multi-block classification model end to end on two GPUs."""
    stage_3 = DeepSpeedStrategy(stage=3)
    harness = ModelParallelClassification(strategy=stage_3, accelerator="cuda", devices=2, precision="16-mixed")
    harness.run()
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(deepspeed=True)
@mock.patch("deepspeed.init_distributed", autospec=True)
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
@pytest.mark.parametrize("platform", ["Linux", "Windows"])
def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
    """Test to ensure that we set up distributed communication correctly.

    When using Windows, ranks environment variables should not be set, and DeepSpeed should handle this.

    The whole test runs inside a patched copy of ``os.environ`` with the rank variables removed up front. Values
    written by the Linux case therefore cannot leak into the Windows case (or into unrelated tests), and values
    already present in the outer environment cannot make the Windows assertion fail spuriously.
    """
    dist_env_keys = ("MASTER_PORT", "MASTER_ADDR", "RANK", "WORLD_SIZE", "LOCAL_RANK")
    # patch.dict restores os.environ to its original content when the block exits
    with mock.patch.dict(os.environ):
        for key in dist_env_keys:
            os.environ.pop(key, None)

        fabric = BoringFabric(strategy=DeepSpeedStrategy(stage=3))
        strategy = fabric._strategy
        assert isinstance(strategy, DeepSpeedStrategy)
        with mock.patch("platform.system", return_value=platform) as platform_mock:
            strategy._init_deepspeed_distributed()
        deepspeed_dist_mock.assert_called()
        platform_mock.assert_called()
        if platform == "Windows":
            # assert no env variables have been set within the DeepSpeedStrategy
            assert all(k not in os.environ for k in dist_env_keys)
        else:
            assert os.environ["MASTER_ADDR"] == str(strategy.cluster_environment.main_address)
            assert os.environ["MASTER_PORT"] == str(strategy.cluster_environment.main_port)
            assert os.environ["RANK"] == str(strategy.global_rank)
            assert os.environ["WORLD_SIZE"] == str(strategy.world_size)
            assert os.environ["LOCAL_RANK"] == str(strategy.local_rank)
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_specific_gpu_device_index():
    """Check that the DeepSpeed strategy honors an explicit device index (here GPU 1)."""

    class RunFabric(BoringFabric):
        def step(self, model, batch):
            # the Fabric device, the batch and the model must all sit on cuda:1
            assert self.device == torch.device("cuda", 1)
            assert batch.device.index == 1
            assert model.device.index == 1
            return super().step(model, batch)

    RunFabric(accelerator="cuda", devices=[1], strategy="deepspeed").run()
|
2022-09-29 16:39:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
def test_deepspeed_with_bfloat16_precision():
    """Check that DeepSpeed ZeRO stage 3 trains with ``bf16-mixed`` precision and bfloat16 weights and inputs."""

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def forward(self, x):
            # by the time the input reaches the module it must already be bfloat16
            assert x.dtype == torch.bfloat16
            return self.layer(x)

    class RunFabric(BoringFabric):
        def get_model(self):
            return Model()

        def step(self, model, batch):
            assert self._strategy.config["bf16"]["enabled"]
            # the batch handed to ``step`` is still float32, while the parameters are bfloat16
            assert batch.dtype == torch.float32
            assert model.layer.weight.dtype == torch.bfloat16
            return super().step(model, batch)

    fabric = RunFabric(accelerator="cuda", devices=2, strategy="deepspeed_stage_3", precision="bf16-mixed")
    strategy = fabric._strategy
    assert isinstance(strategy.precision, DeepSpeedPrecision)
    assert strategy.precision.precision == "bf16-mixed"
    assert strategy.config["zero_optimization"]["stage"] == 3
    fabric.run()
|
2023-01-24 17:53:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
    """Assert that the weights stored under ``checkpoint_path`` equal the parameters of ``model``.

    For ZeRO stages 2 and 3 the sharded checkpoint is first consolidated into a single float32 state dict; for the
    other stages the model states file is read directly. Only global rank 0 compares; all ranks meet at a barrier.
    """
    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    assert isinstance(fabric.strategy, DeepSpeedStrategy)

    if fabric.is_global_zero:
        # "checkpoint" is the tag hardcoded in DeepSpeedStrategy
        tag = "checkpoint"
        zero_stage = fabric.strategy.config["zero_optimization"]["stage"]
        if zero_stage in (2, 3):
            consolidated_path = checkpoint_path / "single_model.pt"
            convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, consolidated_path, tag=tag)
            saved_state = torch.load(consolidated_path)
        else:
            saved_state = torch.load(checkpoint_path / tag / "mp_rank_00_model_states.pt")["module"]

        cpu_model = model.cpu()
        for original, saved in zip(cpu_model.parameters(), saved_state.values()):
            # compare in the dtype of the live parameter
            assert torch.equal(original, saved.cpu().to(original.dtype))

    fabric.barrier()
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
@pytest.mark.parametrize("stage", [1, 2, 3])
def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
    """Test that DeepSpeed stage 1, 2, and 3 model checkpoints can be saved and loaded successfully."""
    from deepspeed import DeepSpeedEngine

    def launch_fabric():
        # The saving and the resuming run must be configured identically. Previously the resuming run passed the
        # legacy alias "bf16" instead of "bf16-mixed", so the two runs used different precision inputs.
        new_fabric = Fabric(
            accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16-mixed"
        )
        new_fabric.launch()
        return new_fabric

    fabric = launch_fabric()

    # every rank must use the same checkpoint directory, so take rank 0's path
    checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")

    with fabric.init_module():
        model = BoringModel()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    model, optimizer = fabric.setup(model, optimizer)
    assert isinstance(model._forward_module, DeepSpeedEngine)

    # TODO(fabric): The dtype on the model is not correct, should be torch.bfloat16
    assert model.dtype == torch.float32
    assert next(model.parameters()).dtype == torch.bfloat16

    # dummy training step
    output = model(torch.randn(1, 32).to(fabric.device))
    loss = output.sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    state = {"model": model, "optimizer": optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    # re-init all objects and resume
    fabric = launch_fabric()
    with fabric.init_module():
        model = BoringModel()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    model, optimizer = fabric.setup(model, optimizer)
    state = {"model": model, "optimizer": optimizer, "steps": 0}

    metadata = fabric.load(checkpoint_path, state)

    # check user data in state reloaded
    assert state["steps"] == 1
    # the remainder of the deepspeed checkpoint contains metadata
    assert "ds_version" in metadata

    _assert_saved_model_is_equal(fabric, model, checkpoint_path)
|