Deprecate `ModelCheckpoint.save_checkpoint` (#12456)
This commit is contained in:
parent
fe12bae704
commit
334675e710
|
@ -582,6 +582,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
|
- Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
|
||||||
|
|
||||||
|
|
||||||
|
- Deprecated `ModelCheckpoint.save_checkpoint` in favor of `Trainer.save_checkpoint` ([#12456](https://github.com/PyTorchLightning/pytorch-lightning/pull/12456))
|
||||||
|
|
||||||
|
|
||||||
- Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
|
- Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ from pytorch_lightning.callbacks.base import Callback
|
||||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
from pytorch_lightning.utilities.logger import _name, _version
|
from pytorch_lightning.utilities.logger import _name, _version
|
||||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
|
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
|
||||||
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
|
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
|
||||||
from pytorch_lightning.utilities.warnings import WarningCache
|
from pytorch_lightning.utilities.warnings import WarningCache
|
||||||
|
|
||||||
|
@ -353,7 +353,10 @@ class ModelCheckpoint(Callback):
|
||||||
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
|
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
|
||||||
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
|
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
|
||||||
"""
|
"""
|
||||||
# TODO: unused method. deprecate it
|
rank_zero_deprecation(
|
||||||
|
f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
|
||||||
|
" Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
|
||||||
|
)
|
||||||
monitor_candidates = self._monitor_candidates(trainer)
|
monitor_candidates = self._monitor_candidates(trainer)
|
||||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||||
|
|
|
@ -22,7 +22,7 @@ from logging import INFO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import call, MagicMock, Mock, patch
|
from unittest.mock import call, Mock, patch
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -834,7 +834,7 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode):
|
||||||
val_check_interval=1.0,
|
val_check_interval=1.0,
|
||||||
max_epochs=len(monitor),
|
max_epochs=len(monitor),
|
||||||
)
|
)
|
||||||
trainer.save_checkpoint = MagicMock()
|
trainer.save_checkpoint = Mock()
|
||||||
|
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
|
@ -1309,9 +1309,10 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
|
||||||
def test_last_global_step_saved():
|
def test_last_global_step_saved():
|
||||||
# this should not save anything
|
# this should not save anything
|
||||||
model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
|
model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
|
||||||
trainer = MagicMock()
|
trainer = Mock()
|
||||||
trainer.callback_metrics = {"foo": 123}
|
monitor_candidates = {"foo": 123}
|
||||||
model_checkpoint.save_checkpoint(trainer)
|
model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates)
|
||||||
|
model_checkpoint._save_last_checkpoint(trainer, monitor_candidates)
|
||||||
assert model_checkpoint._last_global_step_saved == 0
|
assert model_checkpoint._last_global_step_saved == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ from torch import optim
|
||||||
|
|
||||||
import pytorch_lightning
|
import pytorch_lightning
|
||||||
from pytorch_lightning import Callback, Trainer
|
from pytorch_lightning import Callback, Trainer
|
||||||
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
|
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
|
||||||
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
|
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
|
||||||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||||
|
@ -1055,6 +1056,15 @@ def test_trainer_data_parallel_device_ids(monkeypatch, trainer_kwargs, expected_
|
||||||
assert trainer.data_parallel_device_ids == expected_data_parallel_device_ids
|
assert trainer.data_parallel_device_ids == expected_data_parallel_device_ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_deprecated_mc_save_checkpoint():
|
||||||
|
mc = ModelCheckpoint()
|
||||||
|
trainer = Trainer()
|
||||||
|
with mock.patch.object(trainer, "save_checkpoint"), pytest.deprecated_call(
|
||||||
|
match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
|
||||||
|
):
|
||||||
|
mc.save_checkpoint(trainer)
|
||||||
|
|
||||||
|
|
||||||
def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
|
def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
|
||||||
class TestCallbackLoadHook(Callback):
|
class TestCallbackLoadHook(Callback):
|
||||||
def on_load_checkpoint(self, trainer, pl_module, callback_state):
|
def on_load_checkpoint(self, trainer, pl_module, callback_state):
|
||||||
|
|
Loading…
Reference in New Issue