Add support for async checkpointing (#13658)

This commit is contained in:
Rohit Gupta 2022-07-26 21:13:19 +05:30 committed by GitHub
parent 9c720c8adf
commit faf7ff57c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 282 additions and 10 deletions

View File

@ -210,6 +210,7 @@ io
:nosignatures:
:template: classtemplate.rst
AsyncCheckpointIO
CheckpointIO
HPUCheckpointIO
TorchCheckpointIO

View File

@ -45,6 +45,10 @@ Built-in Checkpoint IO Plugins
respectively, common for most use cases.
* - :class:`~pytorch_lightning.plugins.io.XLACheckpointIO`
- CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
* - :class:`~pytorch_lightning.plugins.io.HPUCheckpointIO`
- CheckpointIO to save checkpoints for HPU training strategies.
* - :class:`~pytorch_lightning.plugins.io.AsyncCheckpointIO`
- ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
***************************
@ -94,3 +98,36 @@ Custom Checkpoint IO Plugin
.. note::
Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.
**************************
Asynchronous Checkpointing
**************************
.. warning::
This is currently an experimental plugin/feature and API changes are to be expected.
To enable saving the checkpoints asynchronously without blocking your training, you can configure
:class:`~pytorch_lightning.plugins.io.async_plugin.AsyncCheckpointIO` plugin to ``Trainer``.
.. code-block:: python
from pytorch_lightning.plugins.io import AsyncCheckpointIO
async_ckpt_io = AsyncCheckpointIO()
trainer = Trainer(plugins=[async_ckpt_io])
It uses its base ``CheckpointIO`` plugin's saving logic to save the checkpoint but performs this operation asynchronously.
By default, this base ``CheckpointIO`` will be set-up for you and all you need to provide is the ``AsyncCheckpointIO`` instance to the ``Trainer``.
But if you want the plugin to use your own custom base ``CheckpointIO`` and want the base to behave asynchronously, pass it as an argument while initializing ``AsyncCheckpointIO``.
.. code-block:: python
from pytorch_lightning.plugins.io import AsyncCheckpointIO
base_ckpt_io = MyCustomCheckpointIO()
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
trainer = Trainer(plugins=[async_ckpt_io])

View File

@ -87,6 +87,7 @@ Below is a list of built-in plugins for checkpointing.
:nosignatures:
:template: classtemplate.rst
AsyncCheckpointIO
CheckpointIO
HPUCheckpointIO
TorchCheckpointIO

View File

@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for DDP Fork ([#13405](https://github.com/PyTorchLightning/pytorch-lightning/pull/13405))
- Added support for async checkpointing ([#13658](https://github.com/PyTorchLightning/pytorch-lightning/pull/13658))
### Changed
- `accelerator="gpu"` now automatically selects an available GPU backend (CUDA and MPS currently) ([#13642](https://github.com/Lightning-AI/lightning/pull/13642))

View File

@ -1,6 +1,7 @@
from typing import Union
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
@ -38,6 +39,7 @@ PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO, Laye
PLUGIN_INPUT = Union[PLUGIN, str]
__all__ = [
"AsyncCheckpointIO",
"CheckpointIO",
"TorchCheckpointIO",
"XLACheckpointIO",

View File

@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
__all__ = ["AsyncCheckpointIO", "CheckpointIO", "HPUCheckpointIO", "TorchCheckpointIO", "XLACheckpointIO"]

View File

@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
class AsyncCheckpointIO(_WrappingCheckpointIO):
"""``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
.. warning::
This is currently an experimental plugin/feature and API changes are to be expected.
Args:
checkpoint_io: A checkpoint IO plugin that is used as the basis for async checkpointing.
"""
def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
super().__init__(checkpoint_io)
self._executor = ThreadPoolExecutor(max_workers=1)
self._error: Optional[BaseException] = None
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
"""Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
try:
assert self.checkpoint_io is not None
self.checkpoint_io.save_checkpoint(*args, **kwargs)
except BaseException as e:
self._error = e
self._executor.submit(_save_checkpoint, *args, **kwargs)
# if an error was raised between the previous time `save_checkpoint`` was called and now,
# because `executor.submit` is not blocking
if self._error:
raise self._error
def teardown(self) -> None:
"""This method is called to close the threads."""
self._executor.shutdown(wait=True)
# if an error was raised anytime in any of the `executor.submit` calls
if self._error:
raise self._error

View File

@ -43,12 +43,13 @@ class CheckpointIO(ABC):
"""
@abstractmethod
def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Dict[str, Any]:
"""Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
Args:
path: Path to checkpoint
storage_options: Optional parameters when loading the model/training states.
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations.
Returns: The loaded checkpoint.
"""
@ -60,3 +61,6 @@ class CheckpointIO(ABC):
Args:
path: Path to checkpoint
"""
def teardown(self) -> None:
"""This method is called to teardown the process."""

View File

@ -69,7 +69,7 @@ class TorchCheckpointIO(CheckpointIO):
Args:
path: Path to checkpoint
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations.
locations.
Returns: The loaded checkpoint.

View File

@ -0,0 +1,66 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
class _WrappingCheckpointIO(CheckpointIO):
"""``_WrappingCheckpointIO`` is a wrapper checkpoint_io that uses a base checkpoint_io to handle checkpointing.
Args:
checkpoint_io: A checkpoint IO plugin that is used as the basis.
"""
def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
super().__init__()
self._checkpoint_io = checkpoint_io
self._base_checkpoint_io_configured: bool = False
if checkpoint_io is not None:
if isinstance(checkpoint_io, _WrappingCheckpointIO):
self._base_checkpoint_io_configured = checkpoint_io._base_checkpoint_io_configured
else:
self._base_checkpoint_io_configured = True
@property
def checkpoint_io(self) -> Optional["CheckpointIO"]:
return self._checkpoint_io
@checkpoint_io.setter
def checkpoint_io(self, checkpoint_io: "CheckpointIO") -> None:
assert not isinstance(checkpoint_io, _WrappingCheckpointIO)
if self._checkpoint_io is None:
self._base_checkpoint_io_configured = True
self._checkpoint_io = checkpoint_io
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO) and not self._base_checkpoint_io_configured:
self._base_checkpoint_io_configured = True
self._checkpoint_io.checkpoint_io = checkpoint_io
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
"""Uses the base ``checkpoint_io`` to save the checkpoint."""
assert self.checkpoint_io is not None
self.checkpoint_io.save_checkpoint(*args, **kwargs)
def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None:
"""Uses the base ``checkpoint_io`` to remove the checkpoint."""
assert self.checkpoint_io is not None
self.checkpoint_io.remove_checkpoint(*args, **kwargs)
def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Uses the base ``checkpoint_io`` to load the checkpoint."""
assert self.checkpoint_io is not None
return self.checkpoint_io.load_checkpoint(*args, **kwargs)

View File

@ -23,6 +23,7 @@ from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities.distributed import group as _group
@ -78,6 +79,9 @@ class HPUParallelStrategy(DDPStrategy):
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = HPUCheckpointIO()
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
self._checkpoint_io.checkpoint_io = HPUCheckpointIO()
return self._checkpoint_io
@checkpoint_io.setter

View File

@ -17,6 +17,7 @@ from typing import Dict, Optional
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities import _HPU_AVAILABLE
@ -54,6 +55,9 @@ class SingleHPUStrategy(SingleDeviceStrategy):
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = HPUCheckpointIO()
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
self._checkpoint_io.checkpoint_io = HPUCheckpointIO()
return self._checkpoint_io
@checkpoint_io.setter

View File

@ -16,6 +16,7 @@ from typing import Dict, Optional
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
@ -50,6 +51,9 @@ class SingleTPUStrategy(SingleDeviceStrategy):
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = XLACheckpointIO()
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
self._checkpoint_io.checkpoint_io = XLACheckpointIO()
return self._checkpoint_io
@checkpoint_io.setter

View File

@ -27,6 +27,7 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers,
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
@ -84,6 +85,8 @@ class Strategy(ABC):
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = TorchCheckpointIO()
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
self._checkpoint_io.checkpoint_io = TorchCheckpointIO()
return self._checkpoint_io
@ -467,6 +470,7 @@ class Strategy(ABC):
self.precision_plugin.teardown()
assert self.accelerator is not None
self.accelerator.teardown()
self.checkpoint_io.teardown()
@classmethod
def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:

View File

@ -24,6 +24,7 @@ import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
@ -78,6 +79,9 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = XLACheckpointIO()
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
self._checkpoint_io.checkpoint_io = XLACheckpointIO()
return self._checkpoint_io
@checkpoint_io.setter

View File

@ -62,7 +62,6 @@ def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
"""
bytesbuffer = io.BytesIO()
torch.save(checkpoint, bytesbuffer)
with fsspec.open(filepath, "wb") as f:

View File

@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import Any, Dict, Optional
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.strategies import SingleDeviceStrategy
from pytorch_lightning.utilities.types import _PATH
@ -49,9 +52,16 @@ def test_checkpoint_plugin_called(tmpdir):
strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
callbacks=ck,
max_epochs=2,
limit_train_batches=1,
limit_val_batches=0,
limit_test_batches=1,
)
trainer.fit(model)
ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")}
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.remove_checkpoint.call_count == 1
@ -68,12 +78,76 @@ def test_checkpoint_plugin_called(tmpdir):
plugins=[checkpoint_plugin],
callbacks=ck,
max_epochs=2,
limit_train_batches=1,
limit_val_batches=0,
limit_test_batches=1,
)
trainer.fit(model)
ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")}
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.remove_checkpoint.call_count == 1
trainer.test(model, ckpt_path=ck.last_model_path)
checkpoint_plugin.load_checkpoint.assert_called_once()
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last-v1.ckpt")
def test_async_checkpoint_plugin(tmpdir):
"""Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when async saving and
loading."""
checkpoint_plugin = AsyncCheckpointIO()
checkpoint_plugin.save_checkpoint = Mock(wraps=checkpoint_plugin.save_checkpoint)
checkpoint_plugin.remove_checkpoint = Mock(wraps=checkpoint_plugin.remove_checkpoint)
class CustomBoringModel(BoringModel):
def on_fit_start(self):
base_ckpt_io = self.trainer.strategy.checkpoint_io.checkpoint_io
base_ckpt_io.save_checkpoint = Mock(wraps=base_ckpt_io.save_checkpoint)
base_ckpt_io.remove_checkpoint = Mock(wraps=base_ckpt_io.remove_checkpoint)
ck = ModelCheckpoint(dirpath=tmpdir, save_top_k=2, monitor="step", mode="max")
model = CustomBoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
plugins=[checkpoint_plugin],
callbacks=ck,
max_epochs=3,
limit_train_batches=1,
limit_val_batches=0,
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.fit(model)
assert checkpoint_plugin.save_checkpoint.call_count == 3
assert checkpoint_plugin.remove_checkpoint.call_count == 1
base_ckpt_io = trainer.strategy.checkpoint_io.checkpoint_io
assert base_ckpt_io.save_checkpoint.call_count == 3
assert base_ckpt_io.remove_checkpoint.call_count == 1
def test_multi_wrapped_checkpoint_io_initialization():
base_ckpt_io = TorchCheckpointIO()
wrap_ckpt = AsyncCheckpointIO(base_ckpt_io)
ckpt_io = AsyncCheckpointIO(wrap_ckpt)
assert ckpt_io.checkpoint_io is wrap_ckpt
assert ckpt_io.checkpoint_io.checkpoint_io is base_ckpt_io
assert ckpt_io._base_checkpoint_io_configured is True
assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True
wrap_ckpt = AsyncCheckpointIO()
ckpt_io = AsyncCheckpointIO(wrap_ckpt)
trainer = Trainer(accelerator="cpu", plugins=[ckpt_io])
trainer.strategy.checkpoint_io
assert ckpt_io.checkpoint_io is wrap_ckpt
assert isinstance(ckpt_io.checkpoint_io.checkpoint_io, TorchCheckpointIO)
assert ckpt_io._base_checkpoint_io_configured is True
assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True