Fix save/load/resume from checkpoint for DeepSpeed Plugin (#8397)

Sean Naren 2021-08-02 23:31:05 +01:00 committed by GitHub
parent d01d8334b5
commit e5d9e21dea
5 changed files with 276 additions and 91 deletions

View File

@@ -93,7 +93,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613))
-
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
[#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397),
[#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644),
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))
## [1.4.0] - 2021-07-27

View File

@@ -507,7 +507,7 @@ class ModelCheckpoint(Callback):
def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
self._fs.rm(filepath)
self._fs.rm(filepath, recursive=True)
log.debug(f"Removed checkpoint: {filepath}")
def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
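
The switch to `recursive=True` is needed because, as the plugin change below warns, a DeepSpeed Stage 3 checkpoint path is a directory of per-worker shards rather than a single file. A minimal sketch of the behaviour on a local filesystem, assuming `ModelCheckpoint._fs` is an fsspec filesystem and using made-up shard names:

```python
import os
import fsspec

# Fake the layout of a sharded checkpoint; directory and file names are illustrative.
ckpt_path = "epoch=0-step=9.ckpt"
os.makedirs(os.path.join(ckpt_path, "checkpoint"), exist_ok=True)
open(os.path.join(ckpt_path, "checkpoint", "zero_rank_0_states.pt"), "w").close()

fs = fsspec.filesystem("file")
# A plain fs.rm(ckpt_path) would choke on the non-empty directory;
# recursive=True removes the whole tree, which is what _del_model now does.
fs.rm(ckpt_path, recursive=True)
assert not fs.exists(ckpt_path)
```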

View File

@@ -29,13 +29,16 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn, WarningCache
warning_cache = WarningCache()
if _DEEPSPEED_AVAILABLE:
import deepspeed
@@ -119,7 +122,7 @@ class DeepSpeedPlugin(DDPPlugin):
cpu_checkpointing: bool = False,
contiguous_memory_optimization: bool = False,
synchronize_checkpoint_boundary: bool = False,
save_full_weights: bool = True,
load_full_weights: bool = False,
cpu_offload: bool = False,
cpu_offload_params: bool = False,
cpu_offload_use_pin_memory: bool = False,
@@ -250,10 +253,9 @@ class DeepSpeedPlugin(DDPPlugin):
synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
save_full_weights: Gathers weights across all processes before saving to disk
when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
rather than individual sharded weight files.
Disable to save sharded states individually.
load_full_weights: True when loading a single checkpoint file containing the model state dict
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
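
A hedged usage sketch of the renamed flag (not part of the diff): `load_full_weights=True` tells the plugin that the checkpoint to load is one consolidated state-dict file rather than a DeepSpeed shard directory, matching the `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))` hint in the error message further down. Assumes a CUDA machine with deepspeed installed:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Default: checkpoints are DeepSpeed shard directories, so optimizer and
# scheduler state can be restored alongside the weights.
sharded_ckpts = Trainer(gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3))

# Single consolidated checkpoint file: only the model weights are restored.
single_file_ckpt = Trainer(
    gpus=1,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3, load_full_weights=True),
)
```

This is the single-file path exercised by `test_deepspeed_multigpu_single_file` in the test changes below.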
@@ -313,7 +315,7 @@ class DeepSpeedPlugin(DDPPlugin):
deepspeed.utils.logging.logger.setLevel(logging_level)
self.remote_device = remote_device
self.save_full_weights = save_full_weights
self.load_full_weights = load_full_weights
# default FP16 parameters.
self.loss_scale = loss_scale
@@ -365,6 +367,10 @@ class DeepSpeedPlugin(DDPPlugin):
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(self.local_rank)
@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return True
def pre_dispatch(self):
self.init_deepspeed()
self.barrier()
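
For orientation, the new `restore_checkpoint_after_pre_dispatch` property tells Lightning to restore the checkpoint only after `pre_dispatch` has run `init_deepspeed()`, since the DeepSpeed engine, not the bare `LightningModule`, owns the (possibly sharded) state. A rough, hand-written sketch of the ordering the flag implies; it is not the actual Trainer control flow:

```python
def dispatch_with_restore(plugin, checkpoint_path):
    # Pseudo-sequence only; everything except the two plugin hooks is made up.
    if not plugin.restore_checkpoint_after_pre_dispatch:
        plugin.load_checkpoint_file(checkpoint_path)  # old ordering: before the engine exists
    plugin.pre_dispatch()  # init_deepspeed() wraps the model in a DeepSpeed engine
    if plugin.restore_checkpoint_after_pre_dispatch:
        plugin.load_checkpoint_file(checkpoint_path)  # DeepSpeed loads into the live engine
```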
@@ -657,13 +663,14 @@ class DeepSpeedPlugin(DDPPlugin):
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
return cfg
def _filepath_to_dir(self, filepath: str) -> str:
return os.path.dirname(filepath)
@property
def deepspeed_engine(self):
return self.model
@property
def _multi_device(self) -> bool:
return self.num_processes > 1 or self.num_nodes > 1
def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
@@ -671,29 +678,21 @@ class DeepSpeedPlugin(DDPPlugin):
checkpoint: The checkpoint state dictionary
filepath: write-target file's path
"""
if self.world_size > 1 and self.zero_stage_3:
if self.save_full_weights:
# todo: expose this as general function in deepspeed
state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
if self.is_global_zero:
# State dict keys will include reference to wrapper LightningDeepSpeedModule
# Delete `module` prefix before saving.
state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
checkpoint["state_dict"] = state_dict
return super().save_checkpoint(checkpoint, filepath)
return
if self.zero_stage_3 and self._multi_device and self.is_global_zero:
# todo (sean): Add link to docs once docs are merged.
warning_cache.warn(
"When saving the DeepSpeed Stage 3 checkpoint, "
"each worker will save a shard of the checkpoint within a directory. "
"If a single file is required after training, see <TODO> for instructions."
)
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
save_dir = self._filepath_to_dir(filepath)
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
else:
super().save_checkpoint(checkpoint, filepath)
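
The `<TODO>` in the warning above points at consolidating the per-worker shards once training is done. The test changes further down already do this with DeepSpeed's `zero_to_fp32` helper; a hedged sketch with illustrative paths:

```python
import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# The DeepSpeed checkpoint directory produced during training; path is illustrative.
sharded_ckpt_dir = "lightning_logs/version_0/checkpoints/epoch=9.ckpt"
single_file = "single_model.pt"

# Gathers the ZeRO shards into one fp32 state dict and writes it with torch.save.
convert_zero_checkpoint_to_fp32_state_dict(sharded_ckpt_dir, single_file)
state_dict = torch.load(single_file)  # {parameter name: fp32 tensor}
```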
def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
if self.save_full_weights or self.world_size == 1:
def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.broadcast(checkpoint_path)
@@ -703,20 +702,78 @@ class DeepSpeedPlugin(DDPPlugin):
from pytorch_lightning.trainer.states import TrainerFn
is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
save_dir = self._filepath_to_dir(checkpoint_path)
if self.zero_stage_3:
# TODO: Currently required as this call is missing within the deepspeed engine.
self.deepspeed_engine.optimizer._partition_all_parameters()
_, client_state = self.deepspeed_engine.load_checkpoint(
save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
)
if client_state is None:
raise MisconfigurationException(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
"or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`."
)
return client_state
@property
def lightning_restore_optimizer_and_schedulers(self) -> bool:
# managed by DeepSpeed
if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
rank_zero_warn(
"A single checkpoint file has been given. This means optimizer states and "
"scheduler states can not be restored. If you'd like to restore these states, you must "
"provide a path to the originally saved DeepSpeed checkpoint."
)
return False
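
Put differently, restoring optimizer and scheduler state on resume requires handing `resume_from_checkpoint` the DeepSpeed shard directory saved during training; a consolidated single file together with `load_full_weights=True` restores weights only and triggers the warning above. A hedged sketch, with illustrative paths:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Full resume: weights, optimizer and scheduler state come back from the shard directory.
full_resume = Trainer(
    gpus=2,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3),
    resume_from_checkpoint="checkpoints/epoch=9.ckpt",  # DeepSpeed shard directory
)

# Weights-only resume: the consolidated file carries no DeepSpeed optimizer shards.
weights_only_resume = Trainer(
    gpus=1,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3, load_full_weights=True),
    resume_from_checkpoint="single_model.pt",
)
```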
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
pass
if self.load_full_weights and self.zero_stage_3:
self.model_to_device()
self._restore_zero_state(checkpoint)
def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
"""
Overrides the normal load_state_dict behaviour in PyTorch to ensure
we gather parameters that may be sharded across processes before loading
the state dictionary when using ZeRO stage 3.
This is then automatically synced across processes.
Args:
ckpt: The checkpoint state dictionary to restore from.
"""
def load(module: torch.nn.Module, prefix=""):
missing_keys = []
unexpected_keys = []
error_msgs = []
state_dict = ckpt["state_dict"]
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if self.is_global_zero:
module._load_from_state_dict(
state_dict=state_dict,
prefix=prefix,
local_metadata=local_metadata,
strict=True,
missing_keys=missing_keys,
unexpected_keys=unexpected_keys,
error_msgs=error_msgs,
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(self.lightning_module, prefix="")
def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`

View File

@@ -231,7 +231,9 @@ class TrainingTypePlugin(Plugin, ABC):
Override to delay setting optimizers and schedulers till after dispatch.
This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
However this may break certain precision plugins such as APEX which require optimizers to be set.
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
Returns:
If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return False
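
The hook this docstring belongs to is not shown by the hunk; judging from the wording it is the `setup_optimizers_in_pre_dispatch` property (an inference, not stated in the diff). A hedged sketch of a custom plugin opting in, purely for illustration:

```python
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin

class LateOptimizerDDPPlugin(DDPPlugin):  # hypothetical subclass
    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Defer optimizer/scheduler creation to pre_dispatch so they are built
        # against the wrapped accelerator model rather than the raw module.
        return True
```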

View File

@@ -16,10 +16,14 @@ from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelChec
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
if _DEEPSPEED_AVAILABLE:
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
class ModelParallelBoringModel(BoringModel):
def __init__(self):
@@ -329,8 +333,6 @@ def test_deepspeed_config(tmpdir, deepspeed_zero_config):
trainer.fit(model)
trainer.test(model)
_assert_save_model_is_equal(model, tmpdir, trainer)
@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_custom_precision_params(tmpdir):
@@ -394,17 +396,13 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_co
@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu(tmpdir, deepspeed_config):
def test_deepspeed_multigpu(tmpdir):
"""
Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation.
Test to ensure that DeepSpeed with multiple GPUs works.
"""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
plugins=[DeepSpeedPlugin(zero_optimization=False, stage=2)],
gpus=2,
fast_dev_run=True,
precision=16,
default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
)
trainer.fit(model)
trainer.test(model)
@@ -419,6 +417,54 @@ def test_deepspeed_fp32_works(tmpdir):
trainer.fit(model)
@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_stage_3_save_warning(tmpdir):
"""
Test to ensure that DeepSpeed Stage 3 gives a warning when saving.
"""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
with pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory."):
trainer.save_checkpoint(checkpoint_path)
@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_multigpu_single_file(tmpdir):
"""
Test to ensure that DeepSpeed loads from a single file checkpoint.
"""
model = BoringModel()
checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint(checkpoint_path)
trainer = Trainer(
default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16
)
plugin = trainer.training_type_plugin
assert isinstance(plugin, DeepSpeedPlugin)
assert not plugin.load_full_weights
with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."):
trainer.test(model, ckpt_path=checkpoint_path)
trainer = Trainer(
default_root_dir=tmpdir,
plugins=[DeepSpeedPlugin(stage=3, load_full_weights=True)],
gpus=1,
fast_dev_run=True,
precision=16,
)
plugin = trainer.training_type_plugin
assert isinstance(plugin, DeepSpeedPlugin)
assert plugin.load_full_weights
trainer.test(model, ckpt_path=checkpoint_path)
class ModelParallelClassificationModel(LightningModule):
def __init__(self, lr: float = 0.01, num_blocks: int = 5):
super().__init__()
@@ -474,6 +520,10 @@ class ModelParallelClassificationModel(LightningModule):
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if not hasattr(self, "model"):
self.configure_sharded_model()
class ManualModelParallelClassificationModel(ModelParallelClassificationModel):
@property
@@ -504,7 +554,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
trainer.fit(model)
trainer.test(model)
_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel)
_assert_save_model_is_equal(model, tmpdir, trainer)
@RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -520,12 +570,10 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config
trainer.fit(model)
trainer.test(model)
_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim)
_assert_save_model_is_equal(model, tmpdir, trainer)
def run_checkpoint_test(
tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2
):
def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumulate_grad_batches: int = 2):
seed_everything(1)
if automatic_optimization:
model = ModelParallelClassificationModel()
@@ -535,9 +583,8 @@ def run_checkpoint_test(
ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=10,
plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
plugins=[DeepSpeedPlugin(stage=3)],
gpus=2,
precision=16,
accumulate_grad_batches=accumulate_grad_batches,
@@ -545,29 +592,20 @@ def run_checkpoint_test(
)
trainer.fit(model, datamodule=dm)
results = trainer.test(model, datamodule=dm)
results = trainer.test(datamodule=dm)
assert results[0]["test_acc"] > 0.7
saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm)
assert saved_results[0]["test_acc"] > 0.7
assert saved_results == results
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=10,
plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
gpus=2,
precision=16,
accumulate_grad_batches=2,
callbacks=[ck],
resume_from_checkpoint=ck.best_model_path,
)
results = trainer.test(model, datamodule=dm)
assert results[0]["test_acc"] > 0.7
if automatic_optimization:
model = ModelParallelClassificationModel()
else:
model = ManualModelParallelClassificationModel()
trainer = Trainer(default_root_dir=tmpdir, gpus=2, plugins=[DeepSpeedPlugin(stage=3)], precision=16)
dm.predict_dataloader = dm.test_dataloader
results = trainer.predict(datamodule=dm)
assert results[-1] > 0.7
results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path)
assert results[0]["test_acc"] > 0.7
@RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -576,16 +614,94 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
and see convergence.
"""
run_checkpoint_test(tmpdir, save_full_weights=False)
run_checkpoint_test(tmpdir)
@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir):
@RunIf(min_gpus=1, deepspeed=True, special=False)
def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
"""
Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
where we save the full weights to one file.
Test to ensure that with Stage 3 and a single GPU we can resume training, raising a warning
that the optimizer and scheduler states cannot be restored.
"""
run_checkpoint_test(tmpdir, save_full_weights=True)
dm = ClassifDataModule()
model = BoringModel()
checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint(checkpoint_path)
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
plugins=DeepSpeedPlugin(stage=3, load_full_weights=True),
gpus=1,
precision=16,
resume_from_checkpoint=checkpoint_path,
)
with pytest.warns(
UserWarning,
match="A single checkpoint file has been given. This means optimizer states and "
"scheduler states can not be restored. If you'd like to restore these states, you must "
"provide a path to the originally saved DeepSpeed checkpoint.",
):
trainer.fit(model, datamodule=dm)
@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_resume_training(tmpdir):
"""
Test to ensure that with Stage 3 and a single GPU we can resume training.
"""
initial_model = ModelParallelClassificationModel()
dm = ClassifDataModule()
ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
initial_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
plugins=DeepSpeedPlugin(stage=3),
gpus=1,
precision=16,
callbacks=[ck],
)
initial_trainer.fit(initial_model, datamodule=dm)
class TestCallback(Callback):
def on_train_batch_start(
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin
current_deepspeed_plugin = trainer.accelerator.training_type_plugin
assert isinstance(original_deepspeed_plugin, DeepSpeedPlugin)
assert isinstance(current_deepspeed_plugin, DeepSpeedPlugin)
# assert the optimizer states are correctly loaded
original_optimizer_dict = original_deepspeed_plugin.deepspeed_engine.optimizer.state_dict()
current_optimizer_dict = current_deepspeed_plugin.deepspeed_engine.optimizer.state_dict()
for orig_tensor, current_tensor in zip(
original_optimizer_dict["fp32_flat_groups"], current_optimizer_dict["fp32_flat_groups"]
):
assert torch.all(orig_tensor.eq(current_tensor))
# assert model state is loaded correctly
for current_param, initial_param in zip(pl_module.parameters(), initial_model.parameters()):
assert torch.equal(current_param.cpu(), initial_param.cpu())
# assert epoch has correctly been restored
assert trainer.current_epoch == 1
model = ModelParallelClassificationModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
plugins=DeepSpeedPlugin(stage=3),
gpus=1,
precision=16,
resume_from_checkpoint=ck.best_model_path,
callbacks=TestCallback(),
)
trainer.fit(model, datamodule=dm)
@RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -594,7 +710,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir):
Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
where we save the full weights to one file.
"""
run_checkpoint_test(tmpdir, save_full_weights=True, automatic_optimization=False, accumulate_grad_batches=1)
run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1)
@RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -680,18 +796,25 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat
assert os.environ["LOCAL_RANK"] == str(trainer.training_type_plugin.local_rank)
def _assert_save_model_is_equal(model, tmpdir, trainer, cls=BoringModel):
def _assert_save_model_is_equal(model, tmpdir, trainer):
checkpoint_path = os.path.join(tmpdir, "model.pt")
checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
trainer.save_checkpoint(checkpoint_path)
trainer.accelerator.barrier()
# carry out the check only on rank 0
if trainer.global_rank == 0:
saved_model = cls.load_from_checkpoint(checkpoint_path)
if model.dtype == torch.half:
saved_model = saved_model.half() # model is loaded in float32 as default, move it to float16
if trainer.is_global_zero:
single_ckpt_path = os.path.join(tmpdir, "single_model.pt")
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path)
state_dict = torch.load(single_ckpt_path)
model = model.cpu()
# Assert model parameters are identical after loading
for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(orig_param, trained_model_param)
for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
if model.dtype == torch.half:
# the live model is half precision, so cast the saved fp32 weights to half for comparison
saved_model_param = saved_model_param.half()
assert torch.equal(orig_param, saved_model_param)
@RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -705,4 +828,4 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir):
)
trainer.fit(model)
_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelNoSchedulers)
_assert_save_model_is_equal(model, tmpdir, trainer)