diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 96a061a941..db4fc1e2c4 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -210,6 +210,7 @@ io :nosignatures: :template: classtemplate.rst + AsyncCheckpointIO CheckpointIO HPUCheckpointIO TorchCheckpointIO diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst index c4a948a34c..665acfeef5 100644 --- a/docs/source-pytorch/common/checkpointing_expert.rst +++ b/docs/source-pytorch/common/checkpointing_expert.rst @@ -45,6 +45,10 @@ Built-in Checkpoint IO Plugins respectively, common for most use cases. * - :class:`~pytorch_lightning.plugins.io.XLACheckpointIO` - CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies. + * - :class:`~pytorch_lightning.plugins.io.HPUCheckpointIO` + - CheckpointIO to save checkpoints for HPU training strategies. + * - :class:`~pytorch_lightning.plugins.io.AsyncCheckpointIO` + - ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread. *************************** @@ -94,3 +98,36 @@ Custom Checkpoint IO Plugin .. note:: Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable. + + +************************** +Asynchronous Checkpointing +************************** + +.. warning:: + + This is currently an experimental plugin/feature and API changes are to be expected. + +To enable saving the checkpoints asynchronously without blocking your training, you can configure +:class:`~pytorch_lightning.plugins.io.async_plugin.AsyncCheckpointIO` plugin to ``Trainer``. + +.. code-block:: python + + from pytorch_lightning.plugins.io import AsyncCheckpointIO + + + async_ckpt_io = AsyncCheckpointIO() + trainer = Trainer(plugins=[async_ckpt_io]) + + +It uses its base ``CheckpointIO`` plugin's saving logic to save the checkpoint but performs this operation asynchronously. +By default, this base ``CheckpointIO`` will be set-up for you and all you need to provide is the ``AsyncCheckpointIO`` instance to the ``Trainer``. +But if you want the plugin to use your own custom base ``CheckpointIO`` and want the base to behave asynchronously, pass it as an argument while initializing ``AsyncCheckpointIO``. + +.. code-block:: python + + from pytorch_lightning.plugins.io import AsyncCheckpointIO + + base_ckpt_io = MyCustomCheckpointIO() + async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io) + trainer = Trainer(plugins=[async_ckpt_io]) diff --git a/docs/source-pytorch/extensions/plugins.rst b/docs/source-pytorch/extensions/plugins.rst index 6ea8d42815..a0dbefd141 100644 --- a/docs/source-pytorch/extensions/plugins.rst +++ b/docs/source-pytorch/extensions/plugins.rst @@ -87,6 +87,7 @@ Below is a list of built-in plugins for checkpointing. :nosignatures: :template: classtemplate.rst + AsyncCheckpointIO CheckpointIO HPUCheckpointIO TorchCheckpointIO diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 1c3a3b9d5a..327b03c3aa 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for DDP Fork ([#13405](https://github.com/PyTorchLightning/pytorch-lightning/pull/13405)) +- Added support for async checkpointing ([#13658](https://github.com/PyTorchLightning/pytorch-lightning/pull/13658)) + + ### Changed - `accelerator="gpu"` now automatically selects an available GPU backend (CUDA and MPS currently) ([#13642](https://github.com/Lightning-AI/lightning/pull/13642)) diff --git a/src/pytorch_lightning/plugins/__init__.py b/src/pytorch_lightning/plugins/__init__.py index 0f1c4ca85e..afd10c88c9 100644 --- a/src/pytorch_lightning/plugins/__init__.py +++ b/src/pytorch_lightning/plugins/__init__.py @@ -1,6 +1,7 @@ from typing import Union from pytorch_lightning.plugins.environments import ClusterEnvironment +from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO @@ -38,6 +39,7 @@ PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO, Laye PLUGIN_INPUT = Union[PLUGIN, str] __all__ = [ + "AsyncCheckpointIO", "CheckpointIO", "TorchCheckpointIO", "XLACheckpointIO", diff --git a/src/pytorch_lightning/plugins/io/__init__.py b/src/pytorch_lightning/plugins/io/__init__.py index abd196eb2b..19a556bddf 100644 --- a/src/pytorch_lightning/plugins/io/__init__.py +++ b/src/pytorch_lightning/plugins/io/__init__.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401 -from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO # noqa: F401 -from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401 -from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401 +from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO +from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO + +__all__ = ["AsyncCheckpointIO", "CheckpointIO", "HPUCheckpointIO", "TorchCheckpointIO", "XLACheckpointIO"] diff --git a/src/pytorch_lightning/plugins/io/async_plugin.py b/src/pytorch_lightning/plugins/io/async_plugin.py new file mode 100644 index 0000000000..1146bc373a --- /dev/null +++ b/src/pytorch_lightning/plugins/io/async_plugin.py @@ -0,0 +1,62 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional + +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO + + +class AsyncCheckpointIO(_WrappingCheckpointIO): + """``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread. + + .. warning:: + + This is currently an experimental plugin/feature and API changes are to be expected. + + Args: + checkpoint_io: A checkpoint IO plugin that is used as the basis for async checkpointing. + """ + + def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None: + super().__init__(checkpoint_io) + + self._executor = ThreadPoolExecutor(max_workers=1) + self._error: Optional[BaseException] = None + + def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: + """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``.""" + + def _save_checkpoint(*args: Any, **kwargs: Any) -> None: + try: + assert self.checkpoint_io is not None + self.checkpoint_io.save_checkpoint(*args, **kwargs) + except BaseException as e: + self._error = e + + self._executor.submit(_save_checkpoint, *args, **kwargs) + + # if an error was raised between the previous time `save_checkpoint`` was called and now, + # because `executor.submit` is not blocking + if self._error: + raise self._error + + def teardown(self) -> None: + """This method is called to close the threads.""" + self._executor.shutdown(wait=True) + + # if an error was raised anytime in any of the `executor.submit` calls + if self._error: + raise self._error diff --git a/src/pytorch_lightning/plugins/io/checkpoint_plugin.py b/src/pytorch_lightning/plugins/io/checkpoint_plugin.py index 1425a22996..7dcc850424 100644 --- a/src/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/src/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -43,12 +43,13 @@ class CheckpointIO(ABC): """ @abstractmethod - def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: path: Path to checkpoint - storage_options: Optional parameters when loading the model/training states. + map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage + locations. Returns: The loaded checkpoint. """ @@ -60,3 +61,6 @@ class CheckpointIO(ABC): Args: path: Path to checkpoint """ + + def teardown(self) -> None: + """This method is called to teardown the process.""" diff --git a/src/pytorch_lightning/plugins/io/torch_plugin.py b/src/pytorch_lightning/plugins/io/torch_plugin.py index 8791249e7d..0e5cba3837 100644 --- a/src/pytorch_lightning/plugins/io/torch_plugin.py +++ b/src/pytorch_lightning/plugins/io/torch_plugin.py @@ -69,7 +69,7 @@ class TorchCheckpointIO(CheckpointIO): Args: path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage - locations. + locations. Returns: The loaded checkpoint. diff --git a/src/pytorch_lightning/plugins/io/wrapper.py b/src/pytorch_lightning/plugins/io/wrapper.py new file mode 100644 index 0000000000..eb46990def --- /dev/null +++ b/src/pytorch_lightning/plugins/io/wrapper.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO + + +class _WrappingCheckpointIO(CheckpointIO): + """``_WrappingCheckpointIO`` is a wrapper checkpoint_io that uses a base checkpoint_io to handle checkpointing. + + Args: + checkpoint_io: A checkpoint IO plugin that is used as the basis. + """ + + def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None: + super().__init__() + + self._checkpoint_io = checkpoint_io + self._base_checkpoint_io_configured: bool = False + + if checkpoint_io is not None: + if isinstance(checkpoint_io, _WrappingCheckpointIO): + self._base_checkpoint_io_configured = checkpoint_io._base_checkpoint_io_configured + else: + self._base_checkpoint_io_configured = True + + @property + def checkpoint_io(self) -> Optional["CheckpointIO"]: + return self._checkpoint_io + + @checkpoint_io.setter + def checkpoint_io(self, checkpoint_io: "CheckpointIO") -> None: + assert not isinstance(checkpoint_io, _WrappingCheckpointIO) + + if self._checkpoint_io is None: + self._base_checkpoint_io_configured = True + self._checkpoint_io = checkpoint_io + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO) and not self._base_checkpoint_io_configured: + self._base_checkpoint_io_configured = True + self._checkpoint_io.checkpoint_io = checkpoint_io + + def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: + """Uses the base ``checkpoint_io`` to save the checkpoint.""" + assert self.checkpoint_io is not None + self.checkpoint_io.save_checkpoint(*args, **kwargs) + + def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None: + """Uses the base ``checkpoint_io`` to remove the checkpoint.""" + assert self.checkpoint_io is not None + self.checkpoint_io.remove_checkpoint(*args, **kwargs) + + def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Uses the base ``checkpoint_io`` to load the checkpoint.""" + assert self.checkpoint_io is not None + return self.checkpoint_io.load_checkpoint(*args, **kwargs) diff --git a/src/pytorch_lightning/strategies/hpu_parallel.py b/src/pytorch_lightning/strategies/hpu_parallel.py index 591664e93e..3e6f8e932e 100644 --- a/src/pytorch_lightning/strategies/hpu_parallel.py +++ b/src/pytorch_lightning/strategies/hpu_parallel.py @@ -23,6 +23,7 @@ from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.utilities.distributed import group as _group @@ -78,6 +79,9 @@ class HPUParallelStrategy(DDPStrategy): def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: self._checkpoint_io = HPUCheckpointIO() + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): + self._checkpoint_io.checkpoint_io = HPUCheckpointIO() + return self._checkpoint_io @checkpoint_io.setter diff --git a/src/pytorch_lightning/strategies/single_hpu.py b/src/pytorch_lightning/strategies/single_hpu.py index bbba3904f6..45eb8c58f2 100644 --- a/src/pytorch_lightning/strategies/single_hpu.py +++ b/src/pytorch_lightning/strategies/single_hpu.py @@ -17,6 +17,7 @@ from typing import Dict, Optional import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.single_device import SingleDeviceStrategy from pytorch_lightning.utilities import _HPU_AVAILABLE @@ -54,6 +55,9 @@ class SingleHPUStrategy(SingleDeviceStrategy): def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: self._checkpoint_io = HPUCheckpointIO() + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): + self._checkpoint_io.checkpoint_io = HPUCheckpointIO() + return self._checkpoint_io @checkpoint_io.setter diff --git a/src/pytorch_lightning/strategies/single_tpu.py b/src/pytorch_lightning/strategies/single_tpu.py index caf153ace0..3084f17430 100644 --- a/src/pytorch_lightning/strategies/single_tpu.py +++ b/src/pytorch_lightning/strategies/single_tpu.py @@ -16,6 +16,7 @@ from typing import Dict, Optional import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.single_device import SingleDeviceStrategy @@ -50,6 +51,9 @@ class SingleTPUStrategy(SingleDeviceStrategy): def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: self._checkpoint_io = XLACheckpointIO() + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): + self._checkpoint_io.checkpoint_io = XLACheckpointIO() + return self._checkpoint_io @checkpoint_io.setter diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 3d45c61abb..f47afc890b 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -27,6 +27,7 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.trainer.states import TrainerFn @@ -84,6 +85,8 @@ class Strategy(ABC): def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: self._checkpoint_io = TorchCheckpointIO() + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): + self._checkpoint_io.checkpoint_io = TorchCheckpointIO() return self._checkpoint_io @@ -467,6 +470,7 @@ class Strategy(ABC): self.precision_plugin.teardown() assert self.accelerator is not None self.accelerator.teardown() + self.checkpoint_io.teardown() @classmethod def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None: diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 0c02c82084..f4953a9f64 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -24,6 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.environments import XLAEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy @@ -78,6 +79,9 @@ class TPUSpawnStrategy(DDPSpawnStrategy): def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: self._checkpoint_io = XLACheckpointIO() + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): + self._checkpoint_io.checkpoint_io = XLACheckpointIO() + return self._checkpoint_io @checkpoint_io.setter diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py index 9055ff50c8..81482a8ab2 100644 --- a/src/pytorch_lightning/utilities/cloud_io.py +++ b/src/pytorch_lightning/utilities/cloud_io.py @@ -62,7 +62,6 @@ def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: filepath: The path to which the checkpoint will be saved. This points to the file that the checkpoint will be stored in. """ - bytesbuffer = io.BytesIO() torch.save(checkpoint, bytesbuffer) with fsspec.open(filepath, "wb") as f: diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 651ab1cc4f..ae618ffa33 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from pathlib import Path from typing import Any, Dict, Optional -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.plugins import CheckpointIO +from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.strategies import SingleDeviceStrategy from pytorch_lightning.utilities.types import _PATH @@ -49,9 +52,16 @@ def test_checkpoint_plugin_called(tmpdir): strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin), callbacks=ck, max_epochs=2, + limit_train_batches=1, + limit_val_batches=0, + limit_test_batches=1, ) trainer.fit(model) + ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")} + assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"} + assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt" + assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt" assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 @@ -68,12 +78,76 @@ def test_checkpoint_plugin_called(tmpdir): plugins=[checkpoint_plugin], callbacks=ck, max_epochs=2, + limit_train_batches=1, + limit_val_batches=0, + limit_test_batches=1, ) trainer.fit(model) + ckpt_files = {fn.name for fn in Path(tmpdir).glob("*.ckpt")} + assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"} + assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt" + assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt" assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) checkpoint_plugin.load_checkpoint.assert_called_once() checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last-v1.ckpt") + + +def test_async_checkpoint_plugin(tmpdir): + """Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when async saving and + loading.""" + + checkpoint_plugin = AsyncCheckpointIO() + + checkpoint_plugin.save_checkpoint = Mock(wraps=checkpoint_plugin.save_checkpoint) + checkpoint_plugin.remove_checkpoint = Mock(wraps=checkpoint_plugin.remove_checkpoint) + + class CustomBoringModel(BoringModel): + def on_fit_start(self): + base_ckpt_io = self.trainer.strategy.checkpoint_io.checkpoint_io + base_ckpt_io.save_checkpoint = Mock(wraps=base_ckpt_io.save_checkpoint) + base_ckpt_io.remove_checkpoint = Mock(wraps=base_ckpt_io.remove_checkpoint) + + ck = ModelCheckpoint(dirpath=tmpdir, save_top_k=2, monitor="step", mode="max") + + model = CustomBoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + plugins=[checkpoint_plugin], + callbacks=ck, + max_epochs=3, + limit_train_batches=1, + limit_val_batches=0, + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model) + + assert checkpoint_plugin.save_checkpoint.call_count == 3 + assert checkpoint_plugin.remove_checkpoint.call_count == 1 + + base_ckpt_io = trainer.strategy.checkpoint_io.checkpoint_io + assert base_ckpt_io.save_checkpoint.call_count == 3 + assert base_ckpt_io.remove_checkpoint.call_count == 1 + + +def test_multi_wrapped_checkpoint_io_initialization(): + base_ckpt_io = TorchCheckpointIO() + wrap_ckpt = AsyncCheckpointIO(base_ckpt_io) + ckpt_io = AsyncCheckpointIO(wrap_ckpt) + assert ckpt_io.checkpoint_io is wrap_ckpt + assert ckpt_io.checkpoint_io.checkpoint_io is base_ckpt_io + assert ckpt_io._base_checkpoint_io_configured is True + assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True + + wrap_ckpt = AsyncCheckpointIO() + ckpt_io = AsyncCheckpointIO(wrap_ckpt) + trainer = Trainer(accelerator="cpu", plugins=[ckpt_io]) + trainer.strategy.checkpoint_io + assert ckpt_io.checkpoint_io is wrap_ckpt + assert isinstance(ckpt_io.checkpoint_io.checkpoint_io, TorchCheckpointIO) + assert ckpt_io._base_checkpoint_io_configured is True + assert ckpt_io.checkpoint_io._base_checkpoint_io_configured is True