Fix save/load/resume from checkpoint for DeepSpeed Plugin (#8397)
This commit is contained in: parent d01d8334b5, commit e5d9e21dea
CHANGELOG.md

@@ -93,7 +93,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613))

-
+- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
+    [#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397),
+    [#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644),
+    [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


 ## [1.4.0] - 2021-07-27
pytorch_lightning/callbacks/model_checkpoint.py

@@ -507,7 +507,7 @@ class ModelCheckpoint(Callback):

     def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
         if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
-            self._fs.rm(filepath)
+            self._fs.rm(filepath, recursive=True)
             log.debug(f"Removed checkpoint: {filepath}")

     def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
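Aside: the `recursive=True` flag matters because a ZeRO Stage 3 checkpoint is a directory of shards rather than a single file. A minimal sketch of the same idea with fsspec (the path is illustrative, not taken from this diff):

import fsspec

fs = fsspec.filesystem("file")
ckpt = "/tmp/checkpoints/epoch=0-step=10.ckpt"  # a directory when saved by DeepSpeed ZeRO Stage 3

if fs.exists(ckpt):
    # recursive=True removes a sharded checkpoint directory as well as a plain file
    fs.rm(ckpt, recursive=True)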
pytorch_lightning/plugins/training_type/deepspeed.py

@@ -29,13 +29,16 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
-from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
+from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn, WarningCache
+
+warning_cache = WarningCache()

 if _DEEPSPEED_AVAILABLE:
     import deepspeed

@@ -119,7 +122,7 @@ class DeepSpeedPlugin(DDPPlugin):
         cpu_checkpointing: bool = False,
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
-        save_full_weights: bool = True,
+        load_full_weights: bool = False,
         cpu_offload: bool = False,
         cpu_offload_params: bool = False,
         cpu_offload_use_pin_memory: bool = False,

@@ -250,10 +253,9 @@ class DeepSpeedPlugin(DDPPlugin):

             synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.

-            save_full_weights: Gathers weights across all processes before saving to disk
-                when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
-                rather than individual sharded weight files.
-                Disable to save sharded states individually.
+            load_full_weights: True when loading a single checkpoint file containing the model state dict
+                when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
+                per worker.
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
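To make the `save_full_weights` to `load_full_weights` switch concrete, here is a hedged usage sketch (the Trainer arguments are illustrative): saving under ZeRO Stage 3 now always produces a sharded DeepSpeed checkpoint directory, and `load_full_weights=True` is only needed when the checkpoint being loaded is a single consolidated state-dict file.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Default behaviour: ZeRO Stage 3 saves a DeepSpeed checkpoint directory
# (one shard per worker) and can fully restore optimizer/scheduler state from it.
trainer = Trainer(gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3))

# Opt-in: load a single consolidated state-dict file instead of a sharded directory.
# Only the model weights can be restored in this case.
trainer = Trainer(gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3, load_full_weights=True))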
@@ -313,7 +315,7 @@ class DeepSpeedPlugin(DDPPlugin):
         deepspeed.utils.logging.logger.setLevel(logging_level)

         self.remote_device = remote_device
-        self.save_full_weights = save_full_weights
+        self.load_full_weights = load_full_weights

         # default FP16 parameters.
         self.loss_scale = loss_scale

@@ -365,6 +367,10 @@ class DeepSpeedPlugin(DDPPlugin):
         os.environ["WORLD_SIZE"] = str(world_size)
         os.environ["LOCAL_RANK"] = str(self.local_rank)

+    @property
+    def restore_checkpoint_after_pre_dispatch(self) -> bool:
+        return True
+
     def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()
@@ -657,13 +663,14 @@ class DeepSpeedPlugin(DDPPlugin):
             cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
         return cfg

-    def _filepath_to_dir(self, filepath: str) -> str:
-        return os.path.dirname(filepath)
-
     @property
     def deepspeed_engine(self):
         return self.model

+    @property
+    def _multi_device(self) -> bool:
+        return self.num_processes > 1 or self.num_nodes > 1
+
     def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

@@ -671,29 +678,21 @@
             checkpoint: The checkpoint state dictionary
             filepath: write-target file's path
         """
-        if self.world_size > 1 and self.zero_stage_3:
-            if self.save_full_weights:
-                # todo: expose this as general function in deepspeed
-                state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
-                if self.is_global_zero:
-                    # State dict keys will include reference to wrapper LightningDeepSpeedModule
-                    # Delete `module` prefix before saving.
-                    state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
-                    checkpoint["state_dict"] = state_dict
-                    return super().save_checkpoint(checkpoint, filepath)
-                return
-
-            # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
-            # dump states as a checkpoint dictionary object
-            save_dir = self._filepath_to_dir(filepath)
-            _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
-            checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
-            self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
-        else:
-            super().save_checkpoint(checkpoint, filepath)
+        if self.zero_stage_3 and self._multi_device and self.is_global_zero:
+            # todo (sean): Add link to docs once docs are merged.
+            warning_cache.warn(
+                "When saving the DeepSpeed Stage 3 checkpoint, "
+                "each worker will save a shard of the checkpoint within a directory. "
+                "If a single file is required after training, see <TODO> for instructions."
+            )
+        # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
+        # dump states as a checkpoint dictionary object
+        _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
+        checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
+        self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)

-    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
-        if self.save_full_weights or self.world_size == 1:
+    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
+        if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
             checkpoint_path = self.broadcast(checkpoint_path)
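The warning above leaves consolidation to a post-processing step: each rank writes a shard into the checkpoint directory, and a single file is produced offline. A hedged sketch using DeepSpeed's converter, the same helper the updated tests import below (paths are illustrative):

import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# Directory written by trainer.save_checkpoint(...) when running ZeRO Stage 3
checkpoint_dir = "lightning_logs/version_0/checkpoints/epoch=9-step=999.ckpt"
single_file = "consolidated_fp32.pt"

# Gather the ZeRO-partitioned shards into one fp32 state dict and write it to disk
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, single_file)

state_dict = torch.load(single_file)  # plain PyTorch state dict, loadable anywhere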
@@ -703,20 +702,78 @@
         from pytorch_lightning.trainer.states import TrainerFn

         is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-        save_dir = self._filepath_to_dir(checkpoint_path)
-
-        if self.zero_stage_3:
-            # TODO: Currently required as this call is missing within the deepspeed engine.
-            self.deepspeed_engine.optimizer._partition_all_parameters()
-
         _, client_state = self.deepspeed_engine.load_checkpoint(
-            save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
+            checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
         )
+        if client_state is None:
+            raise MisconfigurationException(
+                "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
+                "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`."
+            )
         return client_state

+    @property
+    def lightning_restore_optimizer_and_schedulers(self) -> bool:
+        # managed by DeepSpeed
+        if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            rank_zero_warn(
+                "A single checkpoint file has been given. This means optimizer states and "
+                "scheduler states can not be restored. If you'd like to restore these states, you must "
+                "provide a path to the originally saved DeepSpeed checkpoint."
+            )
+        return False
+
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
-        pass
+        if self.load_full_weights and self.zero_stage_3:
+            self.model_to_device()
+            self._restore_zero_state(checkpoint)
+
+    def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
+        """
+        Overrides the normal load_state_dict behaviour in PyTorch to ensure
+        we gather parameters that may be sharded across processes before loading
+        the state dictionary when using ZeRO stage 3.
+        This is then automatically synced across processes.
+
+        Args:
+            ckpt: The ckpt file.
+        """
+
+        def load(module: torch.nn.Module, prefix=""):
+
+            missing_keys = []
+            unexpected_keys = []
+            error_msgs = []
+            state_dict = ckpt["state_dict"]
+
+            # copy state_dict so _load_from_state_dict can modify it
+            metadata = getattr(state_dict, "_metadata", None)
+            state_dict = state_dict.copy()
+            if metadata is not None:
+                state_dict._metadata = metadata
+
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            # because zero3 puts placeholders in model params, this context
+            # manager gathers (unpartitions) the params of the current layer, then loads from
+            # the state dict and then re-partitions them again
+            with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+                if self.is_global_zero:
+                    module._load_from_state_dict(
+                        state_dict=state_dict,
+                        prefix=prefix,
+                        local_metadata=local_metadata,
+                        strict=True,
+                        missing_keys=missing_keys,
+                        unexpected_keys=unexpected_keys,
+                        error_msgs=error_msgs,
+                    )
+
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, prefix + name + ".")
+
+        load(self.lightning_module, prefix="")
+
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
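Putting the new load path together, a hedged sketch of the two ways a user can resume after this change (`MyLightningModule` and the checkpoint paths are hypothetical): pointing `resume_from_checkpoint` at the sharded DeepSpeed directory restores optimizer and scheduler state through `deepspeed_engine.load_checkpoint`, while pointing it at a single consolidated file with `load_full_weights=True` restores weights only and emits the warning above.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

model = MyLightningModule()  # hypothetical LightningModule

# Resume from the sharded DeepSpeed checkpoint directory written during training.
trainer = Trainer(
    gpus=2,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3),
    resume_from_checkpoint="lightning_logs/version_0/checkpoints/last.ckpt",
)
trainer.fit(model)

# Resume from a single consolidated state-dict file: weights only, with a warning
# that optimizer/scheduler states cannot be restored.
trainer = Trainer(
    gpus=2,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3, load_full_weights=True),
    resume_from_checkpoint="consolidated_fp32.pt",
)
trainer.fit(model)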
pytorch_lightning/plugins/training_type/training_type_plugin.py

@@ -231,7 +231,9 @@ class TrainingTypePlugin(Plugin, ABC):
         Override to delay setting optimizers and schedulers till after dispatch.
         This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
         However this may break certain precision plugins such as APEX which require optimizers to be set.
-        Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
+
+        Returns:
+            If True, delay setup optimizers till pre_dispatch, else call within setup.
         """
         return False
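For illustration only, a hedged sketch of a custom plugin built on these hooks (the subclass is hypothetical, not part of this diff): `setup_optimizers_in_pre_dispatch` is the property documented above, and `restore_checkpoint_after_pre_dispatch` is the analogous hook that `DeepSpeedPlugin` overrides earlier so that the checkpoint is only restored once the wrapped engine exists.

from pytorch_lightning.plugins import DDPPlugin


class MyDeferredSetupPlugin(DDPPlugin):  # hypothetical example plugin
    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Delay optimizer creation until the model has been wrapped in pre_dispatch.
        return True

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # Delay checkpoint restore until after pre_dispatch, as the DeepSpeed plugin does.
        return True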
tests/plugins/test_deepspeed_plugin.py

@@ -16,10 +16,14 @@ from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelChec
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf

+if _DEEPSPEED_AVAILABLE:
+    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+

 class ModelParallelBoringModel(BoringModel):
     def __init__(self):
@@ -329,8 +333,6 @@ def test_deepspeed_config(tmpdir, deepspeed_zero_config):
     trainer.fit(model)
     trainer.test(model)

-    _assert_save_model_is_equal(model, tmpdir, trainer)
-

 @RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_custom_precision_params(tmpdir):
@@ -394,17 +396,13 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_co


 @RunIf(min_gpus=2, deepspeed=True, special=True)
-def test_deepspeed_multigpu(tmpdir, deepspeed_config):
+def test_deepspeed_multigpu(tmpdir):
     """
-    Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation.
+    Test to ensure that DeepSpeed with multiple GPUs works.
     """
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
-        plugins=[DeepSpeedPlugin(zero_optimization=False, stage=2)],
-        gpus=2,
-        fast_dev_run=True,
-        precision=16,
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
     )
     trainer.fit(model)
     trainer.test(model)
@@ -419,6 +417,54 @@ def test_deepspeed_fp32_works(tmpdir):
     trainer.fit(model)


+@RunIf(min_gpus=2, deepspeed=True, special=True)
+def test_deepspeed_stage_3_save_warning(tmpdir):
+    """
+    Test to ensure that DeepSpeed Stage 3 gives a warning when saving.
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
+    )
+    trainer.fit(model)
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    with pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory."):
+        trainer.save_checkpoint(checkpoint_path)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_multigpu_single_file(tmpdir):
+    """
+    Test to ensure that DeepSpeed loads from a single file checkpoint.
+    """
+    model = BoringModel()
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.save_checkpoint(checkpoint_path)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16
+    )
+    plugin = trainer.training_type_plugin
+    assert isinstance(plugin, DeepSpeedPlugin)
+    assert not plugin.load_full_weights
+    with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."):
+        trainer.test(model, ckpt_path=checkpoint_path)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=[DeepSpeedPlugin(stage=3, load_full_weights=True)],
+        gpus=1,
+        fast_dev_run=True,
+        precision=16,
+    )
+    plugin = trainer.training_type_plugin
+    assert isinstance(plugin, DeepSpeedPlugin)
+    assert plugin.load_full_weights
+    trainer.test(model, ckpt_path=checkpoint_path)
+
+
 class ModelParallelClassificationModel(LightningModule):
     def __init__(self, lr: float = 0.01, num_blocks: int = 5):
         super().__init__()
@@ -474,6 +520,10 @@ class ModelParallelClassificationModel(LightningModule):
         lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
         return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        if not hasattr(self, "model"):
+            self.configure_sharded_model()
+

 class ManualModelParallelClassificationModel(ModelParallelClassificationModel):
     @property
@@ -504,7 +554,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
     trainer.fit(model)
     trainer.test(model)

-    _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel)
+    _assert_save_model_is_equal(model, tmpdir, trainer)


 @RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -520,12 +570,10 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config
     trainer.fit(model)
     trainer.test(model)

-    _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim)
+    _assert_save_model_is_equal(model, tmpdir, trainer)


-def run_checkpoint_test(
-    tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2
-):
+def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumulate_grad_batches: int = 2):
     seed_everything(1)
     if automatic_optimization:
         model = ModelParallelClassificationModel()
@@ -535,9 +583,8 @@ def run_checkpoint_test(
     ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
     trainer = Trainer(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=10,
-        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
+        plugins=[DeepSpeedPlugin(stage=3)],
         gpus=2,
         precision=16,
         accumulate_grad_batches=accumulate_grad_batches,
@@ -545,29 +592,20 @@
     )
     trainer.fit(model, datamodule=dm)

-    results = trainer.test(model, datamodule=dm)
+    results = trainer.test(datamodule=dm)
     assert results[0]["test_acc"] > 0.7

     saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm)
     assert saved_results[0]["test_acc"] > 0.7
     assert saved_results == results

-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=10,
-        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
-        gpus=2,
-        precision=16,
-        accumulate_grad_batches=2,
-        callbacks=[ck],
-        resume_from_checkpoint=ck.best_model_path,
-    )
-    results = trainer.test(model, datamodule=dm)
-    assert results[0]["test_acc"] > 0.7
-
-    dm.predict_dataloader = dm.test_dataloader
-    results = trainer.predict(datamodule=dm)
-    assert results[-1] > 0.7
+    if automatic_optimization:
+        model = ModelParallelClassificationModel()
+    else:
+        model = ManualModelParallelClassificationModel()
+    trainer = Trainer(default_root_dir=tmpdir, gpus=2, plugins=[DeepSpeedPlugin(stage=3)], precision=16)
+    results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path)
+    assert results[0]["test_acc"] > 0.7


 @RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -576,16 +614,94 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
     Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
     and see convergence.
     """
-    run_checkpoint_test(tmpdir, save_full_weights=False)
+    run_checkpoint_test(tmpdir)


-@RunIf(min_gpus=2, deepspeed=True, special=True)
-def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir):
+@RunIf(min_gpus=1, deepspeed=True, special=False)
+def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
     """
-    Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
-    where we save the full weights to one file.
+    Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning
+    that the optimizer state and scheduler states cannot be restored.
     """
-    run_checkpoint_test(tmpdir, save_full_weights=True)
+    dm = ClassifDataModule()
+    model = BoringModel()
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.save_checkpoint(checkpoint_path)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        plugins=DeepSpeedPlugin(stage=3, load_full_weights=True),
+        gpus=1,
+        precision=16,
+        resume_from_checkpoint=checkpoint_path,
+    )
+    with pytest.warns(
+        UserWarning,
+        match="A single checkpoint file has been given. This means optimizer states and "
+        "scheduler states can not be restored. If you'd like to restore these states, you must "
+        "provide a path to the originally saved DeepSpeed checkpoint.",
+    ):
+        trainer.fit(model, datamodule=dm)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_multigpu_stage_3_resume_training(tmpdir):
+    """
+    Test to ensure with Stage 3 and multiple GPUs that we can resume training.
+    """
+    initial_model = ModelParallelClassificationModel()
+    dm = ClassifDataModule()
+
+    ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
+    initial_trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        plugins=DeepSpeedPlugin(stage=3),
+        gpus=1,
+        precision=16,
+        callbacks=[ck],
+    )
+    initial_trainer.fit(initial_model, datamodule=dm)
+
+    class TestCallback(Callback):
+        def on_train_batch_start(
+            self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
+        ) -> None:
+            original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin
+            current_deepspeed_plugin = trainer.accelerator.training_type_plugin
+
+            assert isinstance(original_deepspeed_plugin, DeepSpeedPlugin)
+            assert isinstance(current_deepspeed_plugin, DeepSpeedPlugin)
+            # assert optimizer states are the correctly loaded
+            original_optimizer_dict = original_deepspeed_plugin.deepspeed_engine.optimizer.state_dict()
+            current_optimizer_dict = current_deepspeed_plugin.deepspeed_engine.optimizer.state_dict()
+            for orig_tensor, current_tensor in zip(
+                original_optimizer_dict["fp32_flat_groups"], current_optimizer_dict["fp32_flat_groups"]
+            ):
+                assert torch.all(orig_tensor.eq(current_tensor))
+            # assert model state is loaded correctly
+            for current_param, initial_param in zip(pl_module.parameters(), initial_model.parameters()):
+                assert torch.equal(current_param.cpu(), initial_param.cpu())
+            # assert epoch has correctly been restored
+            assert trainer.current_epoch == 1
+
+    model = ModelParallelClassificationModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        plugins=DeepSpeedPlugin(stage=3),
+        gpus=1,
+        precision=16,
+        resume_from_checkpoint=ck.best_model_path,
+        callbacks=TestCallback(),
+    )
+    trainer.fit(model, datamodule=dm)


 @RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -594,7 +710,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir):
     Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
     where we save the full weights to one file.
     """
-    run_checkpoint_test(tmpdir, save_full_weights=True, automatic_optimization=False, accumulate_grad_batches=1)
+    run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1)


 @RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -680,18 +796,25 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat
     assert os.environ["LOCAL_RANK"] == str(trainer.training_type_plugin.local_rank)


-def _assert_save_model_is_equal(model, tmpdir, trainer, cls=BoringModel):
+def _assert_save_model_is_equal(model, tmpdir, trainer):
     checkpoint_path = os.path.join(tmpdir, "model.pt")
+    checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
     trainer.save_checkpoint(checkpoint_path)
     trainer.accelerator.barrier()

     # carry out the check only on rank 0
-    if trainer.global_rank == 0:
-        saved_model = cls.load_from_checkpoint(checkpoint_path)
-        if model.dtype == torch.half:
-            saved_model = saved_model.half()  # model is loaded in float32 as default, move it to float16
+    if trainer.is_global_zero:
+        single_ckpt_path = os.path.join(tmpdir, "single_model.pt")
+        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path)
+        state_dict = torch.load(single_ckpt_path)
+
         model = model.cpu()
         # Assert model parameters are identical after loading
-        for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
-            assert torch.equal(orig_param, trained_model_param)
+        for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
+            if model.dtype == torch.half:
+                # moved model to float32 for comparison with single fp32 saved weights
+                saved_model_param = saved_model_param.half()
+            assert torch.equal(orig_param, saved_model_param)


 @RunIf(min_gpus=2, deepspeed=True, special=True)
@@ -705,4 +828,4 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir):
     )
     trainer.fit(model)

-    _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelNoSchedulers)
+    _assert_save_model_is_equal(model, tmpdir, trainer)