[refactor] Move save_function to accelerator 1/n [DeepSpeed] (#6689)
* move save_checkpoint responsibility to accelerator
* update
parent 3a4c4246ee
commit 646cf2f7d4
```diff
@@ -466,3 +466,6 @@ class Accelerator(object):
             ' It will be removed in v1.5.'
         )
         self.setup_precision_plugin(plugin)
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
+        self.training_type_plugin.save_checkpoint(checkpoint, filepath)
```
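The new `Accelerator.save_checkpoint` is a thin pass-through: the accelerator no longer decides how a checkpoint reaches disk, it forwards the already-dumped state dict to whichever training type plugin is active (DDP, TPU spawn, and later DeepSpeed). A minimal sketch of that delegation, using stand-in classes rather than the real Lightning ones:

```python
from typing import Any, Dict
import torch


class ToyTrainingTypePlugin:
    """Stand-in for a TrainingTypePlugin; it owns the actual file write."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        torch.save(checkpoint, filepath)


class ToyAccelerator:
    """Stand-in mirroring the Accelerator.save_checkpoint added in this diff."""

    def __init__(self, training_type_plugin: ToyTrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # pure delegation: the plugin decides between atomic_save, xm.save, etc.
        self.training_type_plugin.save_checkpoint(checkpoint, filepath)


accelerator = ToyAccelerator(ToyTrainingTypePlugin())
accelerator.save_checkpoint({"state_dict": {}}, "toy.ckpt")
```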
```diff
@@ -106,8 +106,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         trainer.accelerator.setup_optimizers(trainer)
         trainer.precision_plugin.connect(self._model, None, None)
 
-        # replace trainer save_checkpoint to use `xm.save`
-        trainer.save_checkpoint = self.save_checkpoint
         self.barrier("pre-run-stage")
 
         results = trainer.run_stage()
```
```diff
@@ -298,14 +296,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def predict_step(self, *args, **kwargs):
         return self.lightning_module.predict_step(*args, **kwargs)
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
-            trainer: PyTorch Lightning Trainer
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
```
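Two things change on TPU: the plugin stops monkey-patching `trainer.save_checkpoint` (the generic path now reaches the plugin through the accelerator), and `TPUSpawnPlugin.save_checkpoint` receives the already-dumped checkpoint dict instead of building it itself. It only strips the `callbacks` entry (see the `mappingproxy` TODO) before handing the dict to `self.save`, which uses `xm.save` on TPU. A rough, XLA-free sketch of that filtering step; `write_fn` below is a hypothetical stand-in for `self.save`:

```python
from typing import Any, Callable, Dict

import torch


def tpu_style_save(checkpoint: Dict[str, Any], filepath: str,
                   write_fn: Callable[[Dict[str, Any], str], None]) -> None:
    """Mirror of the new TPUSpawnPlugin.save_checkpoint body: drop the
    'callbacks' entry that cannot be serialized, then delegate the write."""
    write_fn({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)


# usage with torch.save standing in for xm.save
tpu_style_save({"state_dict": {}, "callbacks": {"some": "state"}}, "toy_tpu.ckpt", torch.save)
```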
```diff
@@ -22,6 +22,8 @@ from torch.utils.data import DataLoader
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins.base_plugin import Plugin
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import atomic_save
 
 if TYPE_CHECKING:
     from pytorch_lightning.trainer.trainer import Trainer
```
```diff
@@ -192,3 +194,19 @@ class TrainingTypePlugin(Plugin, ABC):
         Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
         """
         return False
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        # dump states as a checkpoint dictionary object
+        if self.is_global_zero:
+            checkpoint = self.on_save(checkpoint)
+            try:
+                # write the checkpoint dictionary on the file
+                atomic_save(checkpoint, filepath)
+            except AttributeError as err:
+                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+                rank_zero_warn(
+                    'Warning, `hyper_parameters` dropped from checkpoint.'
+                    f' An attribute is not picklable {err}'
+                )
+                atomic_save(checkpoint, filepath)
```
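This base-class `save_checkpoint` centralizes logic that previously lived in `CheckpointConnector` (compare the next hunk): only the global-zero rank writes, the plugin's `on_save` hook may transform the dict first, and if `atomic_save` raises an `AttributeError` (typically because something stored under `hyper_parameters` is not picklable), that key is dropped, a warning is emitted, and the save is retried. A hedged sketch of the same retry pattern; `write_fn` and `HYPER_PARAMS_KEY` are illustrative stand-ins for `atomic_save` and `LightningModule.CHECKPOINT_HYPER_PARAMS_KEY`:

```python
from typing import Any, Callable, Dict
import warnings

HYPER_PARAMS_KEY = "hyper_parameters"  # stand-in for LightningModule.CHECKPOINT_HYPER_PARAMS_KEY


def save_with_fallback(checkpoint: Dict[str, Any], filepath: str,
                       write_fn: Callable[[Dict[str, Any], str], None]) -> None:
    """Same shape as TrainingTypePlugin.save_checkpoint: try the write, and on
    an AttributeError drop the hyper-parameters entry and retry once."""
    try:
        write_fn(checkpoint, filepath)
    except AttributeError as err:
        checkpoint.pop(HYPER_PARAMS_KEY, None)
        warnings.warn(f"`hyper_parameters` dropped from checkpoint. An attribute is not picklable: {err}")
        write_fn(checkpoint, filepath)
```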
```diff
@@ -386,27 +386,12 @@ class CheckpointConnector:
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
             filepath: write-target file's path
             weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        checkpoint = self.dump_checkpoint(weights_only)
-        if self.trainer.is_global_zero:
-            # write the checkpoint dictionary on the file
-
-            if self.trainer.training_type_plugin:
-                checkpoint = self.trainer.training_type_plugin.on_save(checkpoint)
-            try:
-                atomic_save(checkpoint, filepath)
-            except AttributeError as err:
-                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
-                rank_zero_warn(
-                    'Warning, `hyper_parameters` dropped from checkpoint.'
-                    f' An attribute is not picklable {err}'
-                )
-                atomic_save(checkpoint, filepath)
+        _checkpoint = self.dump_checkpoint(weights_only)
+        self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
```
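After this hunk, `CheckpointConnector.save_checkpoint` only assembles the state dict and hands it to the accelerator; the rank-zero check and the `atomic_save` fallback it used to carry now live in the training type plugin, so every backend goes through the same connector -> accelerator -> plugin chain. A simplified sketch of the connector side of that chain (stand-in classes again, not Lightning's):

```python
from types import SimpleNamespace
from typing import Any, Dict


class ToyCheckpointConnector:
    """Stand-in mirroring the slimmed-down CheckpointConnector.save_checkpoint."""

    def __init__(self, trainer) -> None:
        self.trainer = trainer

    def dump_checkpoint(self, weights_only: bool = False) -> Dict[str, Any]:
        # Lightning collects model, optimizer and loop state here; stubbed out
        return {"state_dict": {}, "weights_only": weights_only}

    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
        _checkpoint = self.dump_checkpoint(weights_only)
        # rank-zero gating and the actual write now happen behind the accelerator
        self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)


# wire it to a fake accelerator that just records the call
calls = []
trainer = SimpleNamespace(accelerator=SimpleNamespace(
    save_checkpoint=lambda ckpt, path: calls.append((ckpt, path))))
ToyCheckpointConnector(trainer).save_checkpoint("toy.ckpt")
assert calls and calls[0][1] == "toy.ckpt"
```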