Introduce `Stateful` PrecisionPlugin (#11638)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

parent 914f685ed8
commit d69b33f1f0
@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
 
+- Added `_Stateful` support for `PrecisionPlugin` ([#11638](https://github.com/PyTorchLightning/pytorch-lightning/pull/11638))
+
 - Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))
@@ -93,9 +93,21 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
             return optimizer.step(**kwargs)
         return closure_result
 
+    def state_dict(self) -> Dict[str, Any]:
+        return amp.state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        amp.load_state_dict(state_dict)
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if "amp_scaling_state" in checkpoint:
-            amp.load_state_dict(checkpoint["amp_scaling_state"])
+        """``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-restore ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.load_state_dict``
+        instead
+        """
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint["amp_scaling_state"] = amp.state_dict()
+        """``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-save ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.state_dict`` instead
+        """
@@ -108,10 +108,24 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
         with self.autocast_context_manager():
             yield
 
+    def state_dict(self) -> Dict[str, Any]:
+        if self.scaler is not None:
+            return self.scaler.state_dict()
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        if self.scaler is not None:
+            self.scaler.load_state_dict(state_dict)
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if self.scaler is not None and "native_amp_scaling_state" in checkpoint:
-            self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])
+        """``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-restore NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.load_state_dict``
+        instead
+        """
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if self.scaler is not None:
-            checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
+        """``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-save NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.state_dict`` instead
+        """
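Taken together, the native plugin's new hooks simply round-trip the underlying ``torch.cuda.amp.GradScaler`` state. A minimal sketch of that in isolation (the constructor arguments below are assumptions for illustration, not part of this diff):

import torch
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

# Build a 16-bit native AMP plugin with an explicit GradScaler.
# precision/device/scaler values here are illustrative assumptions.
plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda", scaler=torch.cuda.amp.GradScaler())

state = plugin.state_dict()    # delegates to plugin.scaler.state_dict()
plugin.load_state_dict(state)  # restores the scaler in place; a no-op when scaler is None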
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from functools import partial
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -242,3 +242,20 @@ class PrecisionPlugin(CheckpointHooks):
 
         It is the right place to release memory and free other resources.
         """
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate precision plugin state_dict.
+
+        Returns:
+            A dictionary containing precision plugin state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin
+        state_dict.
+
+        Args:
+            state_dict: the precision plugin state returned by ``state_dict``.
+        """
+        pass
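With these base-class hooks in place, a custom precision plugin can opt into checkpointing by overriding ``state_dict`` and ``load_state_dict``. A minimal sketch, assuming a hypothetical ``loss_scale`` attribute (the class and attribute names are illustrative, not part of this PR):

from typing import Any, Dict

from pytorch_lightning.plugins.precision import PrecisionPlugin


class MyPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin that checkpoints a single scalar of internal state."""

    def __init__(self) -> None:
        super().__init__()
        self.loss_scale = 1.0  # illustrative state worth persisting

    def state_dict(self) -> Dict[str, Any]:
        # The returned dict is written into the checkpoint by the CheckpointConnector
        # under this class's ``__qualname__`` (see the connector changes below).
        return {"loss_scale": self.loss_scale}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called automatically when the loaded checkpoint contains this plugin's state.
        self.loss_scale = state_dict["loss_scale"]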
@@ -22,6 +22,7 @@ from torchmetrics import Metric
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -196,7 +197,7 @@ class CheckpointConnector:
             return
 
         # restore precision plugin (scaler etc.)
-        self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        self.restore_precision_plugin_state()
 
         # restore loops and their progress
         self.restore_loops()
@@ -206,6 +207,21 @@ class CheckpointConnector:
         # restore optimizers and schedulers state
         self.restore_optimizers_and_schedulers()
 
+    def restore_precision_plugin_state(self) -> None:
+        """Restore the precision plugin state from the pre-loaded checkpoint."""
+        prec_plugin = self.trainer.precision_plugin
+        prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        if prec_plugin.__class__.__qualname__ in self._loaded_checkpoint:
+            prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__])
+
+        # old checkpoints compatibility
+        if "amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, ApexMixedPrecisionPlugin):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"])
+        if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(
+            prec_plugin, NativeMixedPrecisionPlugin
+        ):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])
+
     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""
         if not self._loaded_checkpoint:
@@ -318,9 +334,8 @@ class CheckpointConnector:
                 'callbacks': "callback specific state"[] # if not weights_only
                 'optimizer_states': "PT optim's state_dict"[] # if not weights_only
                 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
-                'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp
-                'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp
                 'state_dict': Model's state_dict (e.g. network weights)
+                precision_plugin.__class__.__qualname__: precision plugin state_dict # if not weights_only
                 CHECKPOINT_HYPER_PARAMS_NAME:
                 CHECKPOINT_HYPER_PARAMS_KEY:
                 CHECKPOINT_HYPER_PARAMS_TYPE:
@@ -357,7 +372,12 @@ class CheckpointConnector:
                 lr_schedulers.append(config.scheduler.state_dict())
             checkpoint["lr_schedulers"] = lr_schedulers
 
-            self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
+            # precision plugin
+            prec_plugin = self.trainer.precision_plugin
+            prec_plugin_state_dict = prec_plugin.state_dict()
+            if prec_plugin_state_dict:
+                checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
+            prec_plugin.on_save_checkpoint(checkpoint)
 
             # dump hyper-parameters
             if model.hparams:
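For reference, this is roughly how the new checkpoint layout can be inspected after training with native AMP; the checkpoint path is a placeholder and the exact scaler keys come from ``torch.cuda.amp.GradScaler.state_dict()``:

import torch

ckpt = torch.load("example.ckpt", map_location="cpu")  # placeholder path

# The scaler state now lives under the plugin's qualified class name
# rather than the legacy "native_amp_scaling_state" key.
scaler_state = ckpt.get("NativeMixedPrecisionPlugin", {})
print(sorted(scaler_state))  # e.g. ['growth_factor', 'scale', ...] when present

Older checkpoints that still carry ``native_amp_scaling_state`` or ``amp_scaling_state`` remain loadable through the compatibility branch in ``restore_precision_plugin_state`` above.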
@@ -493,10 +493,8 @@ def test_trainer_model_hook_system_fit(tmpdir, kwargs, automatic_optimization):
         "state_dict": ANY,
         "loops": ANY,
     }
-    if kwargs.get("amp_backend") == "native":
-        saved_ckpt["native_amp_scaling_state"] = ANY
-    elif kwargs.get("amp_backend") == "apex":
-        saved_ckpt["amp_scaling_state"] = ANY
+    if kwargs.get("amp_backend") == "native" or kwargs.get("amp_backend") == "apex":
+        saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
     device = torch.device("cuda:0" if "gpus" in kwargs else "cpu")
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),