diff --git a/CHANGELOG.md b/CHANGELOG.md index 2004c21bdc..e0d5ac7caf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,7 +93,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613)) -- +- Fixed save/load/resume from checkpoint for DeepSpeed Plugin ( + [#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397), + [#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644), + [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627)) ## [1.4.0] - 2021-07-27 diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 55111acd31..8ba3c51f68 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -507,7 +507,7 @@ class ModelCheckpoint(Callback): def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None: if trainer.should_rank_save_checkpoint and self._fs.exists(filepath): - self._fs.rm(filepath) + self._fs.rm(filepath, recursive=True) log.debug(f"Removed checkpoint: {filepath}") def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None: diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index eebd450db8..ee023e4d40 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -29,13 +29,16 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.types import LRSchedulerTypeTuple -from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning +from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn, WarningCache + +warning_cache = WarningCache() if _DEEPSPEED_AVAILABLE: import deepspeed @@ -119,7 +122,7 @@ class DeepSpeedPlugin(DDPPlugin): cpu_checkpointing: bool = False, contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, - save_full_weights: bool = True, + load_full_weights: bool = False, cpu_offload: bool = False, cpu_offload_params: bool = False, cpu_offload_use_pin_memory: bool = False, @@ -250,10 +253,9 @@ class DeepSpeedPlugin(DDPPlugin): synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. - save_full_weights: Gathers weights across all processes before saving to disk - when using ZeRO Stage 3. This allows a single weight file to contain the entire model, - rather than individual sharded weight files. - Disable to save sharded states individually. + load_full_weights: True when loading a single checkpoint file containing the model state dict + when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards + per worker. """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -313,7 +315,7 @@ class DeepSpeedPlugin(DDPPlugin): deepspeed.utils.logging.logger.setLevel(logging_level) self.remote_device = remote_device - self.save_full_weights = save_full_weights + self.load_full_weights = load_full_weights # default FP16 parameters. self.loss_scale = loss_scale @@ -365,6 +367,10 @@ class DeepSpeedPlugin(DDPPlugin): os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_RANK"] = str(self.local_rank) + @property + def restore_checkpoint_after_pre_dispatch(self) -> bool: + return True + def pre_dispatch(self): self.init_deepspeed() self.barrier() @@ -657,13 +663,14 @@ class DeepSpeedPlugin(DDPPlugin): cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} return cfg - def _filepath_to_dir(self, filepath: str) -> str: - return os.path.dirname(filepath) - @property def deepspeed_engine(self): return self.model + @property + def _multi_device(self) -> bool: + return self.num_processes > 1 or self.num_nodes > 1 + def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -671,29 +678,21 @@ class DeepSpeedPlugin(DDPPlugin): checkpoint: The checkpoint state dictionary filepath: write-target file's path """ - if self.world_size > 1 and self.zero_stage_3: - if self.save_full_weights: - # todo: expose this as general function in deepspeed - state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict() - if self.is_global_zero: - # State dict keys will include reference to wrapper LightningDeepSpeedModule - # Delete `module` prefix before saving. - state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()} - checkpoint["state_dict"] = state_dict - return super().save_checkpoint(checkpoint, filepath) - return + if self.zero_stage_3 and self._multi_device and self.is_global_zero: + # todo (sean): Add link to docs once docs are merged. + warning_cache.warn( + "When saving the DeepSpeed Stage 3 checkpoint, " + "each worker will save a shard of the checkpoint within a directory. " + "If a single file is required after training, see for instructions." + ) + # Use deepspeed's internal checkpointing function to handle partitioned weights across processes + # dump states as a checkpoint dictionary object + _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] + checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} + self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) - # Use deepspeed's internal checkpointing function to handle partitioned weights across processes - # dump states as a checkpoint dictionary object - save_dir = self._filepath_to_dir(filepath) - _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] - checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} - self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint) - else: - super().save_checkpoint(checkpoint, filepath) - - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - if self.save_full_weights or self.world_size == 1: + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: + if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing checkpoint_path = self.broadcast(checkpoint_path) @@ -703,20 +702,78 @@ class DeepSpeedPlugin(DDPPlugin): from pytorch_lightning.trainer.states import TrainerFn is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING - save_dir = self._filepath_to_dir(checkpoint_path) - - if self.zero_stage_3: - # TODO: Currently required as this call is missing within the deepspeed engine. - self.deepspeed_engine.optimizer._partition_all_parameters() - _, client_state = self.deepspeed_engine.load_checkpoint( - save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting + checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting ) + if client_state is None: + raise MisconfigurationException( + "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " + "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`." + ) return client_state + @property + def lightning_restore_optimizer_and_schedulers(self) -> bool: + # managed by DeepSpeed + if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + rank_zero_warn( + "A single checkpoint file has been given. This means optimizer states and " + "scheduler states can not be restored. If you'd like to restore these states, you must " + "provide a path to the originally saved DeepSpeed checkpoint." + ) + return False + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` - pass + if self.load_full_weights and self.zero_stage_3: + self.model_to_device() + self._restore_zero_state(checkpoint) + + def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: + """ + Overrides the normal load_state_dict behaviour in PyTorch to ensure + we gather parameters that may be sharded across processes before loading + the state dictionary when using ZeRO stage 3. + This is then automatically synced across processes. + + Args: + ckpt: The ckpt file. + """ + + def load(module: torch.nn.Module, prefix=""): + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + state_dict = ckpt["state_dict"] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if self.is_global_zero: + module._load_from_state_dict( + state_dict=state_dict, + prefix=prefix, + local_metadata=local_metadata, + strict=True, + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + error_msgs=error_msgs, + ) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self.lightning_module, prefix="") def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index afe9f119b6..a8b444de0b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -231,7 +231,9 @@ class TrainingTypePlugin(Plugin, ABC): Override to delay setting optimizers and schedulers till after dispatch. This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. However this may break certain precision plugins such as APEX which require optimizers to be set. - Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. + + Returns: + If True, delay setup optimizers till pre_dispatch, else call within setup. """ return False diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 4107fb70df..f0c1d7d49b 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -16,10 +16,14 @@ from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelChec from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf +if _DEEPSPEED_AVAILABLE: + from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict + class ModelParallelBoringModel(BoringModel): def __init__(self): @@ -329,8 +333,6 @@ def test_deepspeed_config(tmpdir, deepspeed_zero_config): trainer.fit(model) trainer.test(model) - _assert_save_model_is_equal(model, tmpdir, trainer) - @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_custom_precision_params(tmpdir): @@ -394,17 +396,13 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_co @RunIf(min_gpus=2, deepspeed=True, special=True) -def test_deepspeed_multigpu(tmpdir, deepspeed_config): +def test_deepspeed_multigpu(tmpdir): """ - Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation. + Test to ensure that DeepSpeed with multiple GPUs works. """ model = BoringModel() trainer = Trainer( - default_root_dir=tmpdir, - plugins=[DeepSpeedPlugin(zero_optimization=False, stage=2)], - gpus=2, - fast_dev_run=True, - precision=16, + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 ) trainer.fit(model) trainer.test(model) @@ -419,6 +417,54 @@ def test_deepspeed_fp32_works(tmpdir): trainer.fit(model) +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_stage_3_save_warning(tmpdir): + """ + Test to ensure that DeepSpeed Stage 3 gives a warning when saving. + """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 + ) + trainer.fit(model) + checkpoint_path = os.path.join(tmpdir, "model.pt") + with pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory."): + trainer.save_checkpoint(checkpoint_path) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_multigpu_single_file(tmpdir): + """ + Test to ensure that DeepSpeed loads from a single file checkpoint. + """ + model = BoringModel() + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model) + trainer.save_checkpoint(checkpoint_path) + + trainer = Trainer( + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16 + ) + plugin = trainer.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert not plugin.load_full_weights + with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."): + trainer.test(model, ckpt_path=checkpoint_path) + + trainer = Trainer( + default_root_dir=tmpdir, + plugins=[DeepSpeedPlugin(stage=3, load_full_weights=True)], + gpus=1, + fast_dev_run=True, + precision=16, + ) + plugin = trainer.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert plugin.load_full_weights + trainer.test(model, ckpt_path=checkpoint_path) + + class ModelParallelClassificationModel(LightningModule): def __init__(self, lr: float = 0.01, num_blocks: int = 5): super().__init__() @@ -474,6 +520,10 @@ class ModelParallelClassificationModel(LightningModule): lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99) return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + if not hasattr(self, "model"): + self.configure_sharded_model() + class ManualModelParallelClassificationModel(ModelParallelClassificationModel): @property @@ -504,7 +554,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): trainer.fit(model) trainer.test(model) - _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel) + _assert_save_model_is_equal(model, tmpdir, trainer) @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -520,12 +570,10 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config trainer.fit(model) trainer.test(model) - _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim) + _assert_save_model_is_equal(model, tmpdir, trainer) -def run_checkpoint_test( - tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2 -): +def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumulate_grad_batches: int = 2): seed_everything(1) if automatic_optimization: model = ModelParallelClassificationModel() @@ -535,9 +583,8 @@ def run_checkpoint_test( ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, - progress_bar_refresh_rate=0, max_epochs=10, - plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], + plugins=[DeepSpeedPlugin(stage=3)], gpus=2, precision=16, accumulate_grad_batches=accumulate_grad_batches, @@ -545,29 +592,20 @@ def run_checkpoint_test( ) trainer.fit(model, datamodule=dm) - results = trainer.test(model, datamodule=dm) + results = trainer.test(datamodule=dm) assert results[0]["test_acc"] > 0.7 - saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm) assert saved_results[0]["test_acc"] > 0.7 assert saved_results == results - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=10, - plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], - gpus=2, - precision=16, - accumulate_grad_batches=2, - callbacks=[ck], - resume_from_checkpoint=ck.best_model_path, - ) - results = trainer.test(model, datamodule=dm) - assert results[0]["test_acc"] > 0.7 + if automatic_optimization: + model = ModelParallelClassificationModel() + else: + model = ManualModelParallelClassificationModel() + trainer = Trainer(default_root_dir=tmpdir, gpus=2, plugins=[DeepSpeedPlugin(stage=3)], precision=16) - dm.predict_dataloader = dm.test_dataloader - results = trainer.predict(datamodule=dm) - assert results[-1] > 0.7 + results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path) + assert results[0]["test_acc"] > 0.7 @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -576,16 +614,94 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, and see convergence. """ - run_checkpoint_test(tmpdir, save_full_weights=False) + run_checkpoint_test(tmpdir) -@RunIf(min_gpus=2, deepspeed=True, special=True) -def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir): +@RunIf(min_gpus=1, deepspeed=True, special=False) +def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): """ - Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, - where we save the full weights to one file. + Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning + that the optimizer state and scheduler states cannot be restored. """ - run_checkpoint_test(tmpdir, save_full_weights=True) + dm = ClassifDataModule() + model = BoringModel() + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model) + trainer.save_checkpoint(checkpoint_path) + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins=DeepSpeedPlugin(stage=3, load_full_weights=True), + gpus=1, + precision=16, + resume_from_checkpoint=checkpoint_path, + ) + with pytest.warns( + UserWarning, + match="A single checkpoint file has been given. This means optimizer states and " + "scheduler states can not be restored. If you'd like to restore these states, you must " + "provide a path to the originally saved DeepSpeed checkpoint.", + ): + trainer.fit(model, datamodule=dm) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): + """ + Test to ensure with Stage 3 and multiple GPUs that we can resume training. + """ + initial_model = ModelParallelClassificationModel() + dm = ClassifDataModule() + + ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) + initial_trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + plugins=DeepSpeedPlugin(stage=3), + gpus=1, + precision=16, + callbacks=[ck], + ) + initial_trainer.fit(initial_model, datamodule=dm) + + class TestCallback(Callback): + def on_train_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin + current_deepspeed_plugin = trainer.accelerator.training_type_plugin + + assert isinstance(original_deepspeed_plugin, DeepSpeedPlugin) + assert isinstance(current_deepspeed_plugin, DeepSpeedPlugin) + # assert optimizer states are the correctly loaded + original_optimizer_dict = original_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() + current_optimizer_dict = current_deepspeed_plugin.deepspeed_engine.optimizer.state_dict() + for orig_tensor, current_tensor in zip( + original_optimizer_dict["fp32_flat_groups"], current_optimizer_dict["fp32_flat_groups"] + ): + assert torch.all(orig_tensor.eq(current_tensor)) + # assert model state is loaded correctly + for current_param, initial_param in zip(pl_module.parameters(), initial_model.parameters()): + assert torch.equal(current_param.cpu(), initial_param.cpu()) + # assert epoch has correctly been restored + assert trainer.current_epoch == 1 + + model = ModelParallelClassificationModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins=DeepSpeedPlugin(stage=3), + gpus=1, + precision=16, + resume_from_checkpoint=ck.best_model_path, + callbacks=TestCallback(), + ) + trainer.fit(model, datamodule=dm) @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -594,7 +710,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, where we save the full weights to one file. """ - run_checkpoint_test(tmpdir, save_full_weights=True, automatic_optimization=False, accumulate_grad_batches=1) + run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1) @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -680,18 +796,25 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat assert os.environ["LOCAL_RANK"] == str(trainer.training_type_plugin.local_rank) -def _assert_save_model_is_equal(model, tmpdir, trainer, cls=BoringModel): +def _assert_save_model_is_equal(model, tmpdir, trainer): checkpoint_path = os.path.join(tmpdir, "model.pt") + checkpoint_path = trainer.accelerator.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) + trainer.accelerator.barrier() + # carry out the check only on rank 0 - if trainer.global_rank == 0: - saved_model = cls.load_from_checkpoint(checkpoint_path) - if model.dtype == torch.half: - saved_model = saved_model.half() # model is loaded in float32 as default, move it to float16 + if trainer.is_global_zero: + single_ckpt_path = os.path.join(tmpdir, "single_model.pt") + convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path) + state_dict = torch.load(single_ckpt_path) + model = model.cpu() # Assert model parameters are identical after loading - for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): - assert torch.equal(orig_param, trained_model_param) + for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()): + if model.dtype == torch.half: + # moved model to float32 for comparison with single fp32 saved weights + saved_model_param = saved_model_param.half() + assert torch.equal(orig_param, saved_model_param) @RunIf(min_gpus=2, deepspeed=True, special=True) @@ -705,4 +828,4 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): ) trainer.fit(model) - _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelNoSchedulers) + _assert_save_model_is_equal(model, tmpdir, trainer)