Add support for async checkpointing (#13658)
This commit is contained in:
parent
9c720c8adf
commit
faf7ff57c0
|
@ -210,6 +210,7 @@ io
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
AsyncCheckpointIO
|
||||||
CheckpointIO
|
CheckpointIO
|
||||||
HPUCheckpointIO
|
HPUCheckpointIO
|
||||||
TorchCheckpointIO
|
TorchCheckpointIO
|
||||||
|
|
|
@ -45,6 +45,10 @@ Built-in Checkpoint IO Plugins
|
||||||
respectively, common for most use cases.
|
respectively, common for most use cases.
|
||||||
* - :class:`~pytorch_lightning.plugins.io.XLACheckpointIO`
|
* - :class:`~pytorch_lightning.plugins.io.XLACheckpointIO`
|
||||||
- CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
|
- CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
|
||||||
|
* - :class:`~pytorch_lightning.plugins.io.HPUCheckpointIO`
|
||||||
|
- CheckpointIO to save checkpoints for HPU training strategies.
|
||||||
|
* - :class:`~pytorch_lightning.plugins.io.AsyncCheckpointIO`
|
||||||
|
- ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
|
||||||
|
|
||||||
|
|
||||||
***************************
|
***************************
|
||||||
|
@ -94,3 +98,36 @@ Custom Checkpoint IO Plugin
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.
|
Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.
|
||||||
|
|
||||||
|
|
||||||
|
**************************
|
||||||
|
Asynchronous Checkpointing
|
||||||
|
**************************
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
This is currently an experimental plugin/feature and API changes are to be expected.
|
||||||
|
|
||||||
|
To enable saving the checkpoints asynchronously without blocking your training, you can configure
|
||||||
|
:class:`~pytorch_lightning.plugins.io.async_plugin.AsyncCheckpointIO` plugin to ``Trainer``.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from pytorch_lightning.plugins.io import AsyncCheckpointIO
|
||||||
|
|
||||||
|
|
||||||
|
async_ckpt_io = AsyncCheckpointIO()
|
||||||
|
trainer = Trainer(plugins=[async_ckpt_io])
|
||||||
|
|
||||||
|
|
||||||
|
It uses its base ``CheckpointIO`` plugin's saving logic to save the checkpoint but performs this operation asynchronously.
|
||||||
|
By default, this base ``CheckpointIO`` will be set-up for you and all you need to provide is the ``AsyncCheckpointIO`` instance to the ``Trainer``.
|
||||||
|
But if you want the plugin to use your own custom base ``CheckpointIO`` and want the base to behave asynchronously, pass it as an argument while initializing ``AsyncCheckpointIO``.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from pytorch_lightning.plugins.io import AsyncCheckpointIO
|
||||||
|
|
||||||
|
base_ckpt_io = MyCustomCheckpointIO()
|
||||||
|
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
|
||||||
|
trainer = Trainer(plugins=[async_ckpt_io])
|
||||||
|
|
|
@ -87,6 +87,7 @@ Below is a list of built-in plugins for checkpointing.
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
AsyncCheckpointIO
|
||||||
CheckpointIO
|
CheckpointIO
|
||||||
HPUCheckpointIO
|
HPUCheckpointIO
|
||||||
TorchCheckpointIO
|
TorchCheckpointIO
|
||||||
|
|
|
@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Added support for DDP Fork ([#13405](https://github.com/PyTorchLightning/pytorch-lightning/pull/13405))
|
- Added support for DDP Fork ([#13405](https://github.com/PyTorchLightning/pytorch-lightning/pull/13405))
|
||||||
|
|
||||||
|
|
||||||
|
- Added support for async checkpointing ([#13658](https://github.com/PyTorchLightning/pytorch-lightning/pull/13658))
|
||||||
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- `accelerator="gpu"` now automatically selects an available GPU backend (CUDA and MPS currently) ([#13642](https://github.com/Lightning-AI/lightning/pull/13642))
|
- `accelerator="gpu"` now automatically selects an available GPU backend (CUDA and MPS currently) ([#13642](https://github.com/Lightning-AI/lightning/pull/13642))
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from pytorch_lightning.plugins.environments import ClusterEnvironment
|
from pytorch_lightning.plugins.environments import ClusterEnvironment
|
||||||
|
from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
|
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
|
||||||
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
|
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
|
||||||
|
@ -38,6 +39,7 @@ PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO, Laye
|
||||||
PLUGIN_INPUT = Union[PLUGIN, str]
|
PLUGIN_INPUT = Union[PLUGIN, str]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AsyncCheckpointIO",
|
||||||
"CheckpointIO",
|
"CheckpointIO",
|
||||||
"TorchCheckpointIO",
|
"TorchCheckpointIO",
|
||||||
"XLACheckpointIO",
|
"XLACheckpointIO",
|
||||||
|
|
|
@ -11,7 +11,10 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401
|
from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
|
||||||
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO # noqa: F401
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401
|
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
|
||||||
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401
|
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
||||||
|
|
||||||
|
__all__ = ["AsyncCheckpointIO", "CheckpointIO", "HPUCheckpointIO", "TorchCheckpointIO", "XLACheckpointIO"]
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Copyright The PyTorch Lightning team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCheckpointIO(_WrappingCheckpointIO):
|
||||||
|
"""``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
This is currently an experimental plugin/feature and API changes are to be expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_io: A checkpoint IO plugin that is used as the basis for async checkpointing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
|
||||||
|
super().__init__(checkpoint_io)
|
||||||
|
|
||||||
|
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
self._error: Optional[BaseException] = None
|
||||||
|
|
||||||
|
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
|
||||||
|
|
||||||
|
def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
|
||||||
|
try:
|
||||||
|
assert self.checkpoint_io is not None
|
||||||
|
self.checkpoint_io.save_checkpoint(*args, **kwargs)
|
||||||
|
except BaseException as e:
|
||||||
|
self._error = e
|
||||||
|
|
||||||
|
self._executor.submit(_save_checkpoint, *args, **kwargs)
|
||||||
|
|
||||||
|
# if an error was raised between the previous time `save_checkpoint`` was called and now,
|
||||||
|
# because `executor.submit` is not blocking
|
||||||
|
if self._error:
|
||||||
|
raise self._error
|
||||||
|
|
||||||
|
def teardown(self) -> None:
|
||||||
|
"""This method is called to close the threads."""
|
||||||
|
self._executor.shutdown(wait=True)
|
||||||
|
|
||||||
|
# if an error was raised anytime in any of the `executor.submit` calls
|
||||||
|
if self._error:
|
||||||
|
raise self._error
|
|
@ -43,12 +43,13 @@ class CheckpointIO(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
|
def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Dict[str, Any]:
|
||||||
"""Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
|
"""Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: Path to checkpoint
|
path: Path to checkpoint
|
||||||
storage_options: Optional parameters when loading the model/training states.
|
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
|
||||||
|
locations.
|
||||||
|
|
||||||
Returns: The loaded checkpoint.
|
Returns: The loaded checkpoint.
|
||||||
"""
|
"""
|
||||||
|
@ -60,3 +61,6 @@ class CheckpointIO(ABC):
|
||||||
Args:
|
Args:
|
||||||
path: Path to checkpoint
|
path: Path to checkpoint
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def teardown(self) -> None:
|
||||||
|
"""This method is called to teardown the process."""
|
||||||
|
|
|
@ -69,7 +69,7 @@ class TorchCheckpointIO(CheckpointIO):
|
||||||
Args:
|
Args:
|
||||||
path: Path to checkpoint
|
path: Path to checkpoint
|
||||||
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
|
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
|
||||||
locations.
|
locations.
|
||||||
|
|
||||||
Returns: The loaded checkpoint.
|
Returns: The loaded checkpoint.
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
# Copyright The PyTorch Lightning team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
|
|
||||||
|
|
||||||
|
class _WrappingCheckpointIO(CheckpointIO):
|
||||||
|
"""``_WrappingCheckpointIO`` is a wrapper checkpoint_io that uses a base checkpoint_io to handle checkpointing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_io: A checkpoint IO plugin that is used as the basis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._checkpoint_io = checkpoint_io
|
||||||
|
self._base_checkpoint_io_configured: bool = False
|
||||||
|
|
||||||
|
if checkpoint_io is not None:
|
||||||
|
if isinstance(checkpoint_io, _WrappingCheckpointIO):
|
||||||
|
self._base_checkpoint_io_configured = checkpoint_io._base_checkpoint_io_configured
|
||||||
|
else:
|
||||||
|
self._base_checkpoint_io_configured = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def checkpoint_io(self) -> Optional["CheckpointIO"]:
|
||||||
|
return self._checkpoint_io
|
||||||
|
|
||||||
|
@checkpoint_io.setter
|
||||||
|
def checkpoint_io(self, checkpoint_io: "CheckpointIO") -> None:
|
||||||
|
assert not isinstance(checkpoint_io, _WrappingCheckpointIO)
|
||||||
|
|
||||||
|
if self._checkpoint_io is None:
|
||||||
|
self._base_checkpoint_io_configured = True
|
||||||
|
self._checkpoint_io = checkpoint_io
|
||||||
|
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO) and not self._base_checkpoint_io_configured:
|
||||||
|
self._base_checkpoint_io_configured = True
|
||||||
|
self._checkpoint_io.checkpoint_io = checkpoint_io
|
||||||
|
|
||||||
|
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Uses the base ``checkpoint_io`` to save the checkpoint."""
|
||||||
|
assert self.checkpoint_io is not None
|
||||||
|
self.checkpoint_io.save_checkpoint(*args, **kwargs)
|
||||||
|
|
||||||
|
def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Uses the base ``checkpoint_io`` to remove the checkpoint."""
|
||||||
|
assert self.checkpoint_io is not None
|
||||||
|
self.checkpoint_io.remove_checkpoint(*args, **kwargs)
|
||||||
|
|
||||||
|
def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
"""Uses the base ``checkpoint_io`` to load the checkpoint."""
|
||||||
|
assert self.checkpoint_io is not None
|
||||||
|
return self.checkpoint_io.load_checkpoint(*args, **kwargs)
|
|
@ -23,6 +23,7 @@ from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
|
||||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
|
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
|
||||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||||
from pytorch_lightning.strategies.ddp import DDPStrategy
|
from pytorch_lightning.strategies.ddp import DDPStrategy
|
||||||
from pytorch_lightning.utilities.distributed import group as _group
|
from pytorch_lightning.utilities.distributed import group as _group
|
||||||
|
@ -78,6 +79,9 @@ class HPUParallelStrategy(DDPStrategy):
|
||||||
def checkpoint_io(self) -> CheckpointIO:
|
def checkpoint_io(self) -> CheckpointIO:
|
||||||
if self._checkpoint_io is None:
|
if self._checkpoint_io is None:
|
||||||
self._checkpoint_io = HPUCheckpointIO()
|
self._checkpoint_io = HPUCheckpointIO()
|
||||||
|
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
|
||||||
|
self._checkpoint_io.checkpoint_io = HPUCheckpointIO()
|
||||||
|
|
||||||
return self._checkpoint_io
|
return self._checkpoint_io
|
||||||
|
|
||||||
@checkpoint_io.setter
|
@checkpoint_io.setter
|
||||||
|
|
|
@ -17,6 +17,7 @@ from typing import Dict, Optional
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
|
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
|
||||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||||
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
|
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
|
||||||
from pytorch_lightning.utilities import _HPU_AVAILABLE
|
from pytorch_lightning.utilities import _HPU_AVAILABLE
|
||||||
|
@ -54,6 +55,9 @@ class SingleHPUStrategy(SingleDeviceStrategy):
|
||||||
def checkpoint_io(self) -> CheckpointIO:
|
def checkpoint_io(self) -> CheckpointIO:
|
||||||
if self._checkpoint_io is None:
|
if self._checkpoint_io is None:
|
||||||
self._checkpoint_io = HPUCheckpointIO()
|
self._checkpoint_io = HPUCheckpointIO()
|
||||||
|
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
|
||||||
|
self._checkpoint_io.checkpoint_io = HPUCheckpointIO()
|
||||||
|
|
||||||
return self._checkpoint_io
|
return self._checkpoint_io
|
||||||
|
|
||||||
@checkpoint_io.setter
|
@checkpoint_io.setter
|
||||||
|
|
|
@ -16,6 +16,7 @@ from typing import Dict, Optional
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
|
||||||
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
||||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||||
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
|
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
|
||||||
|
@ -50,6 +51,9 @@ class SingleTPUStrategy(SingleDeviceStrategy):
|
||||||
def checkpoint_io(self) -> CheckpointIO:
|
def checkpoint_io(self) -> CheckpointIO:
|
||||||
if self._checkpoint_io is None:
|
if self._checkpoint_io is None:
|
||||||
self._checkpoint_io = XLACheckpointIO()
|
self._checkpoint_io = XLACheckpointIO()
|
||||||
|
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
|
||||||
|
self._checkpoint_io.checkpoint_io = XLACheckpointIO()
|
||||||
|
|
||||||
return self._checkpoint_io
|
return self._checkpoint_io
|
||||||
|
|
||||||
@checkpoint_io.setter
|
@checkpoint_io.setter
|
||||||
|
|
|
@ -27,6 +27,7 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers,
|
||||||
from pytorch_lightning.overrides.base import unwrap_lightning_module
|
from pytorch_lightning.overrides.base import unwrap_lightning_module
|
||||||
from pytorch_lightning.plugins import TorchCheckpointIO
|
from pytorch_lightning.plugins import TorchCheckpointIO
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
|
||||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||||
from pytorch_lightning.strategies.launchers.base import _Launcher
|
from pytorch_lightning.strategies.launchers.base import _Launcher
|
||||||
from pytorch_lightning.trainer.states import TrainerFn
|
from pytorch_lightning.trainer.states import TrainerFn
|
||||||
|
@ -84,6 +85,8 @@ class Strategy(ABC):
|
||||||
def checkpoint_io(self) -> CheckpointIO:
|
def checkpoint_io(self) -> CheckpointIO:
|
||||||
if self._checkpoint_io is None:
|
if self._checkpoint_io is None:
|
||||||
self._checkpoint_io = TorchCheckpointIO()
|
self._checkpoint_io = TorchCheckpointIO()
|
||||||
|
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
|
||||||
|
self._checkpoint_io.checkpoint_io = TorchCheckpointIO()
|
||||||
|
|
||||||
return self._checkpoint_io
|
return self._checkpoint_io
|
||||||
|
|
||||||
|
@ -467,6 +470,7 @@ class Strategy(ABC):
|
||||||
self.precision_plugin.teardown()
|
self.precision_plugin.teardown()
|
||||||
assert self.accelerator is not None
|
assert self.accelerator is not None
|
||||||
self.accelerator.teardown()
|
self.accelerator.teardown()
|
||||||
|
self.checkpoint_io.teardown()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
|
def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
|
||||||
|
|
|
@ -24,6 +24,7 @@ import pytorch_lightning as pl
|
||||||
from pytorch_lightning.overrides import LightningDistributedModule
|
from pytorch_lightning.overrides import LightningDistributedModule
|
||||||
from pytorch_lightning.plugins.environments import XLAEnvironment
|
from pytorch_lightning.plugins.environments import XLAEnvironment
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
|
||||||
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
||||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||||
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
|
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
|
||||||
|
@ -78,6 +79,9 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
|
||||||
def checkpoint_io(self) -> CheckpointIO:
|
def checkpoint_io(self) -> CheckpointIO:
|
||||||
if self._checkpoint_io is None:
|
if self._checkpoint_io is None:
|
||||||
self._checkpoint_io = XLACheckpointIO()
|
self._checkpoint_io = XLACheckpointIO()
|
||||||
|
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
|
||||||
|
self._checkpoint_io.checkpoint_io = XLACheckpointIO()
|
||||||
|
|
||||||
return self._checkpoint_io
|
return self._checkpoint_io
|
||||||
|
|
||||||
@checkpoint_io.setter
|
@checkpoint_io.setter
|
||||||
|
|
|
@ -62,7 +62,6 @@ def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
|
||||||
filepath: The path to which the checkpoint will be saved.
|
filepath: The path to which the checkpoint will be saved.
|
||||||
This points to the file that the checkpoint will be stored in.
|
This points to the file that the checkpoint will be stored in.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
bytesbuffer = io.BytesIO()
|
bytesbuffer = io.BytesIO()
|
||||||
torch.save(checkpoint, bytesbuffer)
|
torch.save(checkpoint, bytesbuffer)
|
||||||
with fsspec.open(filepath, "wb") as f:
|
with fsspec.open(filepath, "wb") as f:
|
||||||
|
|
|
@ -12,15 +12,18 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, Mock
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||||
from pytorch_lightning.plugins import CheckpointIO
|
from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
|
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
|
||||||
from pytorch_lightning.strategies import SingleDeviceStrategy
|
from pytorch_lightning.strategies import SingleDeviceStrategy
|
||||||
from pytorch_lightning.utilities.types import _PATH
|
from pytorch_lightning.utilities.types import _PATH
|
||||||
|
|
||||||
|
@ -49,9 +52,16 @@ def test_checkpoint_plugin_called(tmpdir):
|
||||||
strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
|
strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
|
||||||
callbacks=ck,
|
callbacks=ck,
|
||||||
max_epochs=2,
|
max_epochs=2,
|
||||||
|
limit_train_batches=1,
|
||||||
|
limit_val_batches=0,
|
||||||
|
limit_test_batches=1,
|
||||||
)
|
)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
|
ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")}
|
||||||
|
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
|
||||||
|
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
|
||||||
|
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
|
||||||
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
||||||
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
||||||
|
|
||||||
|
@ -68,12 +78,76 @@ def test_checkpoint_plugin_called(tmpdir):
|
||||||
plugins=[checkpoint_plugin],
|
plugins=[checkpoint_plugin],
|
||||||
callbacks=ck,
|
callbacks=ck,
|
||||||
max_epochs=2,
|
max_epochs=2,
|
||||||
|
limit_train_batches=1,
|
||||||
|
limit_val_batches=0,
|
||||||
|
limit_test_batches=1,
|
||||||
)
|
)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
|
ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")}
|
||||||
|
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
|
||||||
|
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
|
||||||
|
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
|
||||||
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
assert checkpoint_plugin.save_checkpoint.call_count == 4
|
||||||
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
||||||
|
|
||||||
trainer.test(model, ckpt_path=ck.last_model_path)
|
trainer.test(model, ckpt_path=ck.last_model_path)
|
||||||
checkpoint_plugin.load_checkpoint.assert_called_once()
|
checkpoint_plugin.load_checkpoint.assert_called_once()
|
||||||
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last-v1.ckpt")
|
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last-v1.ckpt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_checkpoint_plugin(tmpdir):
|
||||||
|
"""Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when async saving and
|
||||||
|
loading."""
|
||||||
|
|
||||||
|
checkpoint_plugin = AsyncCheckpointIO()
|
||||||
|
|
||||||
|
checkpoint_plugin.save_checkpoint = Mock(wraps=checkpoint_plugin.save_checkpoint)
|
||||||
|
checkpoint_plugin.remove_checkpoint = Mock(wraps=checkpoint_plugin.remove_checkpoint)
|
||||||
|
|
||||||
|
class CustomBoringModel(BoringModel):
|
||||||
|
def on_fit_start(self):
|
||||||
|
base_ckpt_io = self.trainer.strategy.checkpoint_io.checkpoint_io
|
||||||
|
base_ckpt_io.save_checkpoint = Mock(wraps=base_ckpt_io.save_checkpoint)
|
||||||
|
base_ckpt_io.remove_checkpoint = Mock(wraps=base_ckpt_io.remove_checkpoint)
|
||||||
|
|
||||||
|
ck = ModelCheckpoint(dirpath=tmpdir, save_top_k=2, monitor="step", mode="max")
|
||||||
|
|
||||||
|
model = CustomBoringModel()
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
plugins=[checkpoint_plugin],
|
||||||
|
callbacks=ck,
|
||||||
|
max_epochs=3,
|
||||||
|
limit_train_batches=1,
|
||||||
|
limit_val_batches=0,
|
||||||
|
enable_progress_bar=False,
|
||||||
|
enable_model_summary=False,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
assert checkpoint_plugin.save_checkpoint.call_count == 3
|
||||||
|
assert checkpoint_plugin.remove_checkpoint.call_count == 1
|
||||||
|
|
||||||
|
base_ckpt_io = trainer.strategy.checkpoint_io.checkpoint_io
|
||||||
|
assert base_ckpt_io.save_checkpoint.call_count == 3
|
||||||
|
assert base_ckpt_io.remove_checkpoint.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_wrapped_checkpoint_io_initialization():
|
||||||
|
base_ckpt_io = TorchCheckpointIO()
|
||||||
|
wrap_ckpt = AsyncCheckpointIO(base_ckpt_io)
|
||||||
|
ckpt_io = AsyncCheckpointIO(wrap_ckpt)
|
||||||
|
assert ckpt_io.checkpoint_io is wrap_ckpt
|
||||||
|
assert ckpt_io.checkpoint_io.checkpoint_io is base_ckpt_io
|
||||||
|
assert ckpt_io._base_checkpoint_io_configured is True
|
||||||
|
assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True
|
||||||
|
|
||||||
|
wrap_ckpt = AsyncCheckpointIO()
|
||||||
|
ckpt_io = AsyncCheckpointIO(wrap_ckpt)
|
||||||
|
trainer = Trainer(accelerator="cpu", plugins=[ckpt_io])
|
||||||
|
trainer.strategy.checkpoint_io
|
||||||
|
assert ckpt_io.checkpoint_io is wrap_ckpt
|
||||||
|
assert isinstance(ckpt_io.checkpoint_io.checkpoint_io, TorchCheckpointIO)
|
||||||
|
assert ckpt_io._base_checkpoint_io_configured is True
|
||||||
|
assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True
|
||||||
|
|
Loading…
Reference in New Issue