[refactor] Move save_function to accelerator 1/n [DeepSpeed] (#6689)

* move save_checkpoint responsibility to accelerator

* update
Authored by thomas chaton on 2021-03-29 20:02:37 +01:00, committed by GitHub
Parent: 3a4c4246ee
Commit: 646cf2f7d4
4 changed files with 27 additions and 24 deletions


@@ -466,3 +466,6 @@ class Accelerator(object):
             ' It will be removed in v1.5.'
         )
         self.setup_precision_plugin(plugin)
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
+        self.training_type_plugin.save_checkpoint(checkpoint, filepath)
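The accelerator-side hook above is pure delegation: the accelerator no longer knows how a checkpoint is written, only which training type plugin owns that decision. A minimal sketch of that delegation, using toy stand-in classes (ToyAccelerator and ToyTrainingTypePlugin are illustrative names, not Lightning APIs) and plain torch.save in place of the plugin's real write path:

from typing import Any, Dict

import torch


class ToyTrainingTypePlugin:
    """Stand-in for a training type plugin: it owns the actual file write."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        torch.save(checkpoint, filepath)


class ToyAccelerator:
    """Stand-in for the accelerator: it only forwards to its training type plugin."""

    def __init__(self, training_type_plugin: ToyTrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        self.training_type_plugin.save_checkpoint(checkpoint, filepath)


ToyAccelerator(ToyTrainingTypePlugin()).save_checkpoint({"state_dict": {}, "epoch": 0}, "toy.ckpt")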


@@ -106,8 +106,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         trainer.accelerator.setup_optimizers(trainer)
         trainer.precision_plugin.connect(self._model, None, None)
 
-        # replace trainer save_checkpoint to use `xm.save`
-        trainer.save_checkpoint = self.save_checkpoint
         self.barrier("pre-run-stage")
 
         results = trainer.run_stage()
@@ -298,14 +296,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def predict_step(self, *args, **kwargs):
         return self.lightning_module.predict_step(*args, **kwargs)
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
-            trainer: PyTorch Lightning Trainer
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
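On TPU the plugin now receives the already-dumped checkpoint dictionary and only strips the entry it cannot serialize ("callbacks") before handing off to its backend-aware save. A small sketch of that filtering step; save_filtered_checkpoint is an illustrative helper, and plain torch.save stands in for the XLA-aware self.save so the sketch runs without a TPU:

from typing import Any, Dict

import torch


def save_filtered_checkpoint(checkpoint: Dict[str, Any], filepath: str) -> None:
    # drop entries the backend cannot serialize, mirroring the "callbacks" skip above
    serializable = {k: v for k, v in checkpoint.items() if k != "callbacks"}
    torch.save(serializable, filepath)


save_filtered_checkpoint({"state_dict": {}, "epoch": 0, "callbacks": object()}, "tpu_style.ckpt")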


@@ -22,6 +22,8 @@ from torch.utils.data import DataLoader
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins.base_plugin import Plugin
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import atomic_save
 
 if TYPE_CHECKING:
     from pytorch_lightning.trainer.trainer import Trainer
@@ -192,3 +194,19 @@ class TrainingTypePlugin(Plugin, ABC):
         Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
         """
         return False
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        # dump states as a checkpoint dictionary object
+        if self.is_global_zero:
+            checkpoint = self.on_save(checkpoint)
+            try:
+                # write the checkpoint dictionary on the file
+                atomic_save(checkpoint, filepath)
+            except AttributeError as err:
+                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+                rank_zero_warn(
+                    'Warning, `hyper_parameters` dropped from checkpoint.'
+                    f' An attribute is not picklable {err}'
+                )
+                atomic_save(checkpoint, filepath)
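The base-plugin implementation keeps the old fallback behaviour: write the whole dictionary, and if pickling trips over an unpicklable hyperparameter, drop the hyper_parameters section, warn, and write again. A self-contained sketch of that retry pattern, using warnings.warn and torch.save in place of rank_zero_warn and atomic_save; HYPER_PARAMS_KEY is an illustrative stand-in for LightningModule.CHECKPOINT_HYPER_PARAMS_KEY:

import warnings
from typing import Any, Dict

import torch

HYPER_PARAMS_KEY = "hyper_parameters"  # stand-in for LightningModule.CHECKPOINT_HYPER_PARAMS_KEY


def save_with_hparams_fallback(checkpoint: Dict[str, Any], filepath: str) -> None:
    try:
        # first attempt: persist the full checkpoint dictionary
        torch.save(checkpoint, filepath)
    except AttributeError as err:
        # some hyperparameter is not picklable: drop that section, warn, and retry
        checkpoint.pop(HYPER_PARAMS_KEY, None)
        warnings.warn(f"`hyper_parameters` dropped from checkpoint, an attribute is not picklable: {err}")
        torch.save(checkpoint, filepath)


save_with_hparams_fallback({"state_dict": {}, HYPER_PARAMS_KEY: {"lr": 1e-3}}, "base_plugin_style.ckpt")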


@@ -386,27 +386,12 @@ class CheckpointConnector:
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
             filepath: write-target file's path
             weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        checkpoint = self.dump_checkpoint(weights_only)
-        if self.trainer.is_global_zero:
-            # write the checkpoint dictionary on the file
-            if self.trainer.training_type_plugin:
-                checkpoint = self.trainer.training_type_plugin.on_save(checkpoint)
-            try:
-                atomic_save(checkpoint, filepath)
-            except AttributeError as err:
-                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
-                rank_zero_warn(
-                    'Warning, `hyper_parameters` dropped from checkpoint.'
-                    f' An attribute is not picklable {err}'
-                )
-                atomic_save(checkpoint, filepath)
+        _checkpoint = self.dump_checkpoint(weights_only)
+        self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
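With the connector reduced to "dump, then delegate", persistence can be specialised per backend, which is what the DeepSpeed follow-up in the title needs (for example, writing sharded state rather than one monolithic file). A self-contained sketch of the new call path; ToyCheckpointConnector and sharded_write are toy stand-ins for the real connector and for a backend-owned write, not Lightning or DeepSpeed APIs:

from typing import Any, Callable, Dict

import torch


def sharded_write(checkpoint: Dict[str, Any], filepath: str) -> None:
    # toy backend-owned write: one file per top-level entry, to suggest a sharded layout
    for key, value in checkpoint.items():
        torch.save({key: value}, f"{filepath}.{key}")


class ToyCheckpointConnector:
    """Mirrors the refactored connector: dump states, then delegate the write."""

    def __init__(self, accelerator_save: Callable[[Dict[str, Any], str], None]) -> None:
        self._accelerator_save = accelerator_save

    def dump_checkpoint(self, weights_only: bool = False) -> Dict[str, Any]:
        state: Dict[str, Any] = {"state_dict": {"weight": torch.zeros(2)}}
        if not weights_only:
            state["epoch"] = 1
        return state

    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
        checkpoint = self.dump_checkpoint(weights_only)
        # persistence is now fully owned by the accelerator / training type plugin
        self._accelerator_save(checkpoint, filepath)


ToyCheckpointConnector(sharded_write).save_checkpoint("toy.ckpt")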