From 334675e710d0699cbb49efea3752f58d5acdd131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 28 Mar 2022 19:36:26 +0200 Subject: [PATCH] Deprecate `ModelCheckpoint.save_checkpoint` (#12456) --- CHANGELOG.md | 3 +++ pytorch_lightning/callbacks/model_checkpoint.py | 7 +++++-- tests/checkpointing/test_model_checkpoint.py | 11 ++++++----- tests/deprecated_api/test_remove_1-8.py | 10 ++++++++++ 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98681e08e7..374a34e9b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -582,6 +582,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745)) +- Deprecated `ModelCheckpoint.save_checkpoint` in favor of `Trainer.save_checkpoint` ([#12456](https://github.com/PyTorchLightning/pytorch-lightning/pull/12456)) + + - Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e2a6b59afa..223e2639d9 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -37,7 +37,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.logger import _name, _version -from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache @@ -353,7 +353,10 @@ class ModelCheckpoint(Callback): This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases. """ - # TODO: unused method. deprecate it + rank_zero_deprecation( + f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8." + " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint." + ) monitor_candidates = self._monitor_candidates(trainer) self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f8ed4eb746..fba18f326c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -22,7 +22,7 @@ from logging import INFO from pathlib import Path from typing import Union from unittest import mock -from unittest.mock import call, MagicMock, Mock, patch +from unittest.mock import call, Mock, patch import cloudpickle import pytest @@ -834,7 +834,7 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode): val_check_interval=1.0, max_epochs=len(monitor), ) - trainer.save_checkpoint = MagicMock() + trainer.save_checkpoint = Mock() trainer.fit(model) @@ -1309,9 +1309,10 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir): def test_last_global_step_saved(): # this should not save anything model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo") - trainer = MagicMock() - trainer.callback_metrics = {"foo": 123} - model_checkpoint.save_checkpoint(trainer) + trainer = Mock() + monitor_candidates = {"foo": 123} + model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates) + model_checkpoint._save_last_checkpoint(trainer, monitor_candidates) assert model_checkpoint._last_global_step_saved == 0 diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index a7cabf1794..8b1f0d6e7c 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -24,6 +24,7 @@ from torch import optim import pytorch_lightning from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -1055,6 +1056,15 @@ def test_trainer_data_parallel_device_ids(monkeypatch, trainer_kwargs, expected_ assert trainer.data_parallel_device_ids == expected_data_parallel_device_ids +def test_deprecated_mc_save_checkpoint(): + mc = ModelCheckpoint() + trainer = Trainer() + with mock.patch.object(trainer, "save_checkpoint"), pytest.deprecated_call( + match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6" + ): + mc.save_checkpoint(trainer) + + def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir): class TestCallbackLoadHook(Callback): def on_load_checkpoint(self, trainer, pl_module, callback_state):