Deprecate `ModelCheckpoint.save_checkpoint` (#12456)

Carlos Mocholí 2022-03-28 19:36:26 +02:00 committed by GitHub
parent fe12bae704
commit 334675e710
4 changed files with 24 additions and 7 deletions

@@ -582,6 +582,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
- Deprecated `ModelCheckpoint.save_checkpoint` in favor of `Trainer.save_checkpoint` ([#12456](https://github.com/PyTorchLightning/pytorch-lightning/pull/12456))
- Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
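
For downstream code that previously called the callback method directly, the migration looks roughly like the sketch below. The tiny model, directory names, and checkpoint path are illustrative placeholders, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=8)
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/")
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=1)
trainer.fit(TinyModel(), train_loader)

# Before (deprecated in v1.6, scheduled for removal in v1.8):
# checkpoint_callback.save_checkpoint(trainer)

# After: ask the Trainer to save; it already handles writing only on rank 0
# in distributed (data parallel) runs.
trainer.save_checkpoint("checkpoints/manual.ckpt")
```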

@@ -37,7 +37,7 @@ from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
@@ -353,7 +353,10 @@ class ModelCheckpoint(Callback):
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
"""
# TODO: unused method. deprecate it
rank_zero_deprecation(
f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
" Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
)
monitor_candidates = self._monitor_candidates(trainer)
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
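
The newly imported `rank_zero_deprecation` follows the same pattern as the existing `rank_zero_info` / `rank_zero_warn` helpers: the message is emitted only by the process with global rank 0, so multi-process runs do not repeat it once per GPU. A rough, simplified sketch of that idea, not the actual implementation in `pytorch_lightning.utilities.rank_zero`:

```python
import os
import warnings


def rank_zero_deprecation_sketch(message: str) -> None:
    """Illustrative stand-in: warn about a deprecation only on global rank 0."""
    # Fall back to rank 0 for single-process runs where no rank env var is set.
    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
    if rank == 0:
        warnings.warn(message, DeprecationWarning, stacklevel=2)
```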

@@ -22,7 +22,7 @@ from logging import INFO
from pathlib import Path
from typing import Union
from unittest import mock
from unittest.mock import call, MagicMock, Mock, patch
from unittest.mock import call, Mock, patch
import cloudpickle
import pytest
@@ -834,7 +834,7 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode):
val_check_interval=1.0,
max_epochs=len(monitor),
)
trainer.save_checkpoint = MagicMock()
trainer.save_checkpoint = Mock()
trainer.fit(model)
@@ -1309,9 +1309,10 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
def test_last_global_step_saved():
# this should not save anything
model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
trainer = MagicMock()
trainer.callback_metrics = {"foo": 123}
model_checkpoint.save_checkpoint(trainer)
trainer = Mock()
monitor_candidates = {"foo": 123}
model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates)
model_checkpoint._save_last_checkpoint(trainer, monitor_candidates)
assert model_checkpoint._last_global_step_saved == 0
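
The test updates also replace `MagicMock` with the plain `Mock` where only attribute access and call recording are needed; `MagicMock` additionally pre-configures magic methods such as `__len__`, which these tests never exercise. A small standalone illustration of the difference, separate from the test suite itself:

```python
from unittest.mock import MagicMock, Mock

trainer = Mock()
trainer.save_checkpoint("a.ckpt")  # attribute calls are recorded on a plain Mock
trainer.save_checkpoint.assert_called_once_with("a.ckpt")

magic = MagicMock()
print(len(magic))  # MagicMock also stubs magic methods: __len__ returns 0 by default
# len(Mock())      # a plain Mock would raise TypeError here
```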

@@ -24,6 +24,7 @@ from torch import optim
import pytorch_lightning
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -1055,6 +1056,15 @@ def test_trainer_data_parallel_device_ids(monkeypatch, trainer_kwargs, expected_
assert trainer.data_parallel_device_ids == expected_data_parallel_device_ids
def test_deprecated_mc_save_checkpoint():
mc = ModelCheckpoint()
trainer = Trainer()
with mock.patch.object(trainer, "save_checkpoint"), pytest.deprecated_call(
match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
):
mc.save_checkpoint(trainer)
def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
class TestCallbackLoadHook(Callback):
def on_load_checkpoint(self, trainer, pl_module, callback_state):