Add support for async checkpointing (#13658)

Rohit Gupta 2022-07-26 21:13:19 +05:30 committed by GitHub
parent 9c720c8adf
commit faf7ff57c0
17 changed files with 282 additions and 10 deletions

View File

@@ -210,6 +210,7 @@ io
    :nosignatures:
    :template: classtemplate.rst

+    AsyncCheckpointIO
    CheckpointIO
    HPUCheckpointIO
    TorchCheckpointIO

View File

@@ -45,6 +45,10 @@ Built-in Checkpoint IO Plugins
     respectively, common for most use cases.
   * - :class:`~pytorch_lightning.plugins.io.XLACheckpointIO`
     - CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
+   * - :class:`~pytorch_lightning.plugins.io.HPUCheckpointIO`
+     - CheckpointIO to save checkpoints for HPU training strategies.
+   * - :class:`~pytorch_lightning.plugins.io.AsyncCheckpointIO`
+     - ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.

***************************
@@ -94,3 +98,36 @@ Custom Checkpoint IO Plugin
.. note::

    Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.
+
+**************************
+Asynchronous Checkpointing
+**************************
+
+.. warning::
+
+    This is currently an experimental plugin/feature and API changes are to be expected.
+
+To enable saving the checkpoints asynchronously without blocking your training, you can configure the
+:class:`~pytorch_lightning.plugins.io.async_plugin.AsyncCheckpointIO` plugin and pass it to the ``Trainer``.
+
+.. code-block:: python
+
+    from pytorch_lightning.plugins.io import AsyncCheckpointIO
+
+    async_ckpt_io = AsyncCheckpointIO()
+    trainer = Trainer(plugins=[async_ckpt_io])
+
+It uses its base ``CheckpointIO`` plugin's saving logic to save the checkpoint, but performs this operation asynchronously.
+By default, this base ``CheckpointIO`` will be set up for you; all you need to provide is the ``AsyncCheckpointIO`` instance to the ``Trainer``.
+If you want the plugin to use your own custom base ``CheckpointIO`` and have it save asynchronously, pass the base as an argument when initializing ``AsyncCheckpointIO``.
+
+.. code-block:: python
+
+    from pytorch_lightning.plugins.io import AsyncCheckpointIO
+
+    base_ckpt_io = MyCustomCheckpointIO()
+    async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
+    trainer = Trainer(plugins=[async_ckpt_io])
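``MyCustomCheckpointIO`` in the snippet above stands in for any user-defined plugin and is not defined in this commit. As an illustrative sketch only (the class name and its ``torch.save``-based behavior are assumptions), a minimal base plugin could subclass ``CheckpointIO`` directly:

import os
from typing import Any, Dict, Optional

import torch

from pytorch_lightning.plugins.io import CheckpointIO


class MyCustomCheckpointIO(CheckpointIO):
    """Hypothetical base plugin that persists checkpoints with plain ``torch.save``."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> Dict[str, Any]:
        return torch.load(path, map_location=map_location)

    def remove_checkpoint(self, path: str) -> None:
        if os.path.isfile(path):
            os.remove(path)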

View File

@@ -87,6 +87,7 @@ Below is a list of built-in plugins for checkpointing.
    :nosignatures:
    :template: classtemplate.rst

+    AsyncCheckpointIO
    CheckpointIO
    HPUCheckpointIO
    TorchCheckpointIO

View File

@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for DDP Fork ([#13405](https://github.com/PyTorchLightning/pytorch-lightning/pull/13405))

+- Added support for async checkpointing ([#13658](https://github.com/PyTorchLightning/pytorch-lightning/pull/13658))

### Changed

- `accelerator="gpu"` now automatically selects an available GPU backend (CUDA and MPS currently) ([#13642](https://github.com/Lightning-AI/lightning/pull/13642))

View File

@@ -1,6 +1,7 @@
from typing import Union

from pytorch_lightning.plugins.environments import ClusterEnvironment
+from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO

@@ -38,6 +39,7 @@ PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO, Laye
PLUGIN_INPUT = Union[PLUGIN, str]

__all__ = [
+    "AsyncCheckpointIO",
    "CheckpointIO",
    "TorchCheckpointIO",
    "XLACheckpointIO",

View File

@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO  # noqa: F401
-from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO  # noqa: F401
-from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO  # noqa: F401
-from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO  # noqa: F401
+from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
+from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
+from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
+from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
+
+__all__ = ["AsyncCheckpointIO", "CheckpointIO", "HPUCheckpointIO", "TorchCheckpointIO", "XLACheckpointIO"]

View File

@@ -0,0 +1,62 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Optional
+
+from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
+
+
+class AsyncCheckpointIO(_WrappingCheckpointIO):
+    """``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
+
+    .. warning::
+
+        This is currently an experimental plugin/feature and API changes are to be expected.
+
+    Args:
+        checkpoint_io: A checkpoint IO plugin that is used as the basis for async checkpointing.
+    """
+
+    def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
+        super().__init__(checkpoint_io)
+
+        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._error: Optional[BaseException] = None
+
+    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
+        """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
+
+        def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
+            try:
+                assert self.checkpoint_io is not None
+                self.checkpoint_io.save_checkpoint(*args, **kwargs)
+            except BaseException as e:
+                self._error = e
+
+        self._executor.submit(_save_checkpoint, *args, **kwargs)
+
+        # if an error was raised between the previous time `save_checkpoint` was called and now,
+        # because `executor.submit` is not blocking
+        if self._error:
+            raise self._error
+
+    def teardown(self) -> None:
+        """This method is called to close the threads."""
+        self._executor.shutdown(wait=True)
+
+        # if an error was raised anytime in any of the `executor.submit` calls
+        if self._error:
+            raise self._error
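A hedged usage sketch (not from the commit; the file name and the toy checkpoint dict are made up) showing how the plugin above behaves when driven directly, outside the ``Trainer``:

from pytorch_lightning.plugins.io import AsyncCheckpointIO, TorchCheckpointIO

ckpt_io = AsyncCheckpointIO(checkpoint_io=TorchCheckpointIO())

# save_checkpoint only submits the work to the single-worker executor and returns immediately
ckpt_io.save_checkpoint({"state_dict": {}}, "example.ckpt")

# teardown waits for pending saves; an error raised in the worker thread is re-raised here
# (or on the next save_checkpoint call)
ckpt_io.teardown()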

View File

@@ -43,12 +43,13 @@ class CheckpointIO(ABC):
        """

    @abstractmethod
-    def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
+    def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Dict[str, Any]:
        """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.

        Args:
            path: Path to checkpoint
-            storage_options: Optional parameters when loading the model/training states.
+            map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
+                locations.

        Returns: The loaded checkpoint.
        """

@@ -60,3 +61,6 @@ class CheckpointIO(ABC):
        Args:
            path: Path to checkpoint
        """
+
+    def teardown(self) -> None:
+        """This method is called to teardown the process."""

View File

@@ -69,7 +69,7 @@ class TorchCheckpointIO(CheckpointIO):
        Args:
            path: Path to checkpoint
            map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                locations.

        Returns: The loaded checkpoint.

View File

@@ -0,0 +1,66 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+
+
+class _WrappingCheckpointIO(CheckpointIO):
+    """``_WrappingCheckpointIO`` is a wrapper checkpoint_io that uses a base checkpoint_io to handle checkpointing.
+
+    Args:
+        checkpoint_io: A checkpoint IO plugin that is used as the basis.
+    """
+
+    def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
+        super().__init__()
+
+        self._checkpoint_io = checkpoint_io
+        self._base_checkpoint_io_configured: bool = False
+
+        if checkpoint_io is not None:
+            if isinstance(checkpoint_io, _WrappingCheckpointIO):
+                self._base_checkpoint_io_configured = checkpoint_io._base_checkpoint_io_configured
+            else:
+                self._base_checkpoint_io_configured = True
+
+    @property
+    def checkpoint_io(self) -> Optional["CheckpointIO"]:
+        return self._checkpoint_io
+
+    @checkpoint_io.setter
+    def checkpoint_io(self, checkpoint_io: "CheckpointIO") -> None:
+        assert not isinstance(checkpoint_io, _WrappingCheckpointIO)
+
+        if self._checkpoint_io is None:
+            self._base_checkpoint_io_configured = True
+            self._checkpoint_io = checkpoint_io
+        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO) and not self._base_checkpoint_io_configured:
+            self._base_checkpoint_io_configured = True
+            self._checkpoint_io.checkpoint_io = checkpoint_io
+
+    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
+        """Uses the base ``checkpoint_io`` to save the checkpoint."""
+        assert self.checkpoint_io is not None
+        self.checkpoint_io.save_checkpoint(*args, **kwargs)
+
+    def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None:
+        """Uses the base ``checkpoint_io`` to remove the checkpoint."""
+        assert self.checkpoint_io is not None
+        self.checkpoint_io.remove_checkpoint(*args, **kwargs)
+
+    def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        """Uses the base ``checkpoint_io`` to load the checkpoint."""
+        assert self.checkpoint_io is not None
+        return self.checkpoint_io.load_checkpoint(*args, **kwargs)
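A hedged sketch (not part of the commit) of the contract the setter above implements, mirroring the ``test_multi_wrapped_checkpoint_io_initialization`` test added later in this diff:

from pytorch_lightning.plugins.io import AsyncCheckpointIO, TorchCheckpointIO

# a bare wrapper has no base until a strategy (or the user) assigns one
wrapper = AsyncCheckpointIO()
assert wrapper.checkpoint_io is None
wrapper.checkpoint_io = TorchCheckpointIO()
assert isinstance(wrapper.checkpoint_io, TorchCheckpointIO)

# with nested wrappers, the assignment is forwarded to the innermost unconfigured wrapper
outer = AsyncCheckpointIO(AsyncCheckpointIO())
outer.checkpoint_io = TorchCheckpointIO()
assert isinstance(outer.checkpoint_io, AsyncCheckpointIO)
assert isinstance(outer.checkpoint_io.checkpoint_io, TorchCheckpointIO)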

View File

@@ -23,6 +23,7 @@ from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
+from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities.distributed import group as _group

@@ -78,6 +79,9 @@ class HPUParallelStrategy(DDPStrategy):
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = HPUCheckpointIO()
+        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
+            self._checkpoint_io.checkpoint_io = HPUCheckpointIO()
+
        return self._checkpoint_io

    @checkpoint_io.setter

View File

@@ -17,6 +17,7 @@ from typing import Dict, Optional
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
+from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities import _HPU_AVAILABLE

@@ -54,6 +55,9 @@ class SingleHPUStrategy(SingleDeviceStrategy):
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = HPUCheckpointIO()
+        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
+            self._checkpoint_io.checkpoint_io = HPUCheckpointIO()
+
        return self._checkpoint_io

    @checkpoint_io.setter

View File

@@ -16,6 +16,7 @@ from typing import Dict, Optional
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy

@@ -50,6 +51,9 @@ class SingleTPUStrategy(SingleDeviceStrategy):
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = XLACheckpointIO()
+        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
+            self._checkpoint_io.checkpoint_io = XLACheckpointIO()
+
        return self._checkpoint_io

    @checkpoint_io.setter

View File

@@ -27,6 +27,7 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers,
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn

@@ -84,6 +85,8 @@ class Strategy(ABC):
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = TorchCheckpointIO()
+        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
+            self._checkpoint_io.checkpoint_io = TorchCheckpointIO()

        return self._checkpoint_io

@@ -467,6 +470,7 @@ class Strategy(ABC):
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
+        self.checkpoint_io.teardown()

    @classmethod
    def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
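A hedged end-to-end sketch of the property logic above (assumed behavior, mirroring the tests added in this commit): a bare ``AsyncCheckpointIO`` passed to the ``Trainer`` receives ``TorchCheckpointIO`` as its base the first time ``Strategy.checkpoint_io`` is accessed.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.io import AsyncCheckpointIO, TorchCheckpointIO

async_io = AsyncCheckpointIO()
trainer = Trainer(accelerator="cpu", plugins=[async_io])

# accessing the property fills in the default base for the wrapping plugin
assert trainer.strategy.checkpoint_io is async_io
assert isinstance(async_io.checkpoint_io, TorchCheckpointIO)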

View File

@@ -24,6 +24,7 @@ import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy

@@ -78,6 +79,9 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = XLACheckpointIO()
+        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
+            self._checkpoint_io.checkpoint_io = XLACheckpointIO()
+
        return self._checkpoint_io

    @checkpoint_io.setter

View File

@@ -62,7 +62,6 @@ def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
        filepath: The path to which the checkpoint will be saved.
            This points to the file that the checkpoint will be stored in.
    """
-
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(filepath, "wb") as f:

View File

@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from pathlib import Path
from typing import Any, Dict, Optional
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.plugins import CheckpointIO
+from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
+from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.strategies import SingleDeviceStrategy
from pytorch_lightning.utilities.types import _PATH
@@ -49,9 +52,16 @@ def test_checkpoint_plugin_called(tmpdir):
        strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
        callbacks=ck,
        max_epochs=2,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        limit_test_batches=1,
    )
    trainer.fit(model)

+    ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")}
+    assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
+    assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
+    assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
    assert checkpoint_plugin.save_checkpoint.call_count == 4
    assert checkpoint_plugin.remove_checkpoint.call_count == 1
@@ -68,12 +78,76 @@ def test_checkpoint_plugin_called(tmpdir):
        plugins=[checkpoint_plugin],
        callbacks=ck,
        max_epochs=2,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        limit_test_batches=1,
    )
    trainer.fit(model)

+    ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")}
+    assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
+    assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
+    assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
    assert checkpoint_plugin.save_checkpoint.call_count == 4
    assert checkpoint_plugin.remove_checkpoint.call_count == 1

    trainer.test(model, ckpt_path=ck.last_model_path)
    checkpoint_plugin.load_checkpoint.assert_called_once()
    checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last-v1.ckpt")
+
+
+def test_async_checkpoint_plugin(tmpdir):
+    """Ensure that the custom checkpoint IO plugin and the torch checkpoint IO plugin are called when saving and
+    loading asynchronously."""
+    checkpoint_plugin = AsyncCheckpointIO()
+    checkpoint_plugin.save_checkpoint = Mock(wraps=checkpoint_plugin.save_checkpoint)
+    checkpoint_plugin.remove_checkpoint = Mock(wraps=checkpoint_plugin.remove_checkpoint)
+
+    class CustomBoringModel(BoringModel):
+        def on_fit_start(self):
+            base_ckpt_io = self.trainer.strategy.checkpoint_io.checkpoint_io
+            base_ckpt_io.save_checkpoint = Mock(wraps=base_ckpt_io.save_checkpoint)
+            base_ckpt_io.remove_checkpoint = Mock(wraps=base_ckpt_io.remove_checkpoint)
+
+    ck = ModelCheckpoint(dirpath=tmpdir, save_top_k=2, monitor="step", mode="max")
+
+    model = CustomBoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=[checkpoint_plugin],
+        callbacks=ck,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    trainer.fit(model)
+
+    assert checkpoint_plugin.save_checkpoint.call_count == 3
+    assert checkpoint_plugin.remove_checkpoint.call_count == 1
+
+    base_ckpt_io = trainer.strategy.checkpoint_io.checkpoint_io
+    assert base_ckpt_io.save_checkpoint.call_count == 3
+    assert base_ckpt_io.remove_checkpoint.call_count == 1
+
+
+def test_multi_wrapped_checkpoint_io_initialization():
+    base_ckpt_io = TorchCheckpointIO()
+    wrap_ckpt = AsyncCheckpointIO(base_ckpt_io)
+    ckpt_io = AsyncCheckpointIO(wrap_ckpt)
+
+    assert ckpt_io.checkpoint_io is wrap_ckpt
+    assert ckpt_io.checkpoint_io.checkpoint_io is base_ckpt_io
+    assert ckpt_io._base_checkpoint_io_configured is True
+    assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True
+
+    wrap_ckpt = AsyncCheckpointIO()
+    ckpt_io = AsyncCheckpointIO(wrap_ckpt)
+
+    trainer = Trainer(accelerator="cpu", plugins=[ckpt_io])
+    trainer.strategy.checkpoint_io
+
+    assert ckpt_io.checkpoint_io is wrap_ckpt
+    assert isinstance(ckpt_io.checkpoint_io.checkpoint_io, TorchCheckpointIO)
+    assert ckpt_io._base_checkpoint_io_configured is True
+    assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True