Deprecate save_function from model checkpoint callback (#7201)

* Update model_checkpoint.py

* Update CHANGELOG.md

* fix-tests

* deprecate not remove

* Update model_checkpoint.py

* Update test_remove_1-5.py
This commit is contained in:
ananthsub 2021-04-26 09:55:26 -07:00 committed by GitHub
parent ac7d6a35c3
commit dd5ec75e48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 39 additions and 13 deletions

View File

@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
- Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201))
- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))
@ -190,6 +193,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
- Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130))

View File

@ -23,7 +23,7 @@ import os
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import torch
@ -201,19 +201,19 @@ class ModelCheckpoint(Callback):
self.best_model_score = None
self.best_model_path = ""
self.last_model_path = ""
self.save_function = None
self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
self.__validate_init_configuration()
self._save_function = None
def on_pretrain_routine_start(self, trainer, pl_module):
"""
When pretrain routine starts we build the ckpt dir on the fly
"""
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint
self._save_function = trainer.save_checkpoint
def on_train_batch_end(
self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
@ -254,9 +254,9 @@ class ModelCheckpoint(Callback):
def save_checkpoint(self, trainer, unused: Optional = None):
"""
Performs the main logic around saving a checkpoint.
This method runs on all ranks, it is the responsibility of `self.save_function`
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
Performs the main logic around saving a checkpoint. This method runs on all ranks.
It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training,
i.e., saving only on rank 0 for data parallel use cases.
"""
if unused is not None:
rank_zero_deprecation(
@ -396,6 +396,22 @@ class ModelCheckpoint(Callback):
)
self._period = value
@property
def save_function(self) -> Optional[Callable]:
rank_zero_deprecation(
'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `trainer.save_checkpoint` instead.'
)
return self._save_function
@save_function.setter
def save_function(self, value: Optional[Callable]) -> None:
rank_zero_deprecation(
'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `trainer.save_checkpoint` instead.'
)
self._save_function = value
@rank_zero_only
def _del_model(self, filepath: str):
if self._fs.exists(filepath):
@ -420,10 +436,7 @@ class ModelCheckpoint(Callback):
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
# delegate the saving to the trainer
if self.save_function is not None:
self.save_function(filepath, self.save_weights_only)
else:
raise ValueError(".save_function() not set")
trainer.save_checkpoint(filepath, self.save_weights_only)
def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
if current is None:

View File

@ -32,9 +32,18 @@ from tests.helpers.utils import no_warning_call
def test_v1_5_0_model_checkpoint_save_checkpoint():
model_ckpt = ModelCheckpoint()
model_ckpt.save_function = lambda *_, **__: None
trainer = Trainer()
trainer.save_checkpoint = lambda *_, **__: None
with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"):
model_ckpt.save_checkpoint(Trainer(), object())
model_ckpt.save_checkpoint(trainer, object())
def test_v1_5_0_model_checkpoint_save_function():
model_ckpt = ModelCheckpoint()
with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
model_ckpt.save_function = lambda *_, **__: None
with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
_ = model_ckpt.save_function
@mock.patch('pytorch_lightning.loggers.wandb.wandb')

View File

@ -330,9 +330,9 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
save_last=save_last,
verbose=True
)
checkpoint_callback.save_function = mock_save_function
trainer = Trainer()
trainer.state = TrainerState.FITTING
trainer.save_checkpoint = mock_save_function
# emulate callback's calls during the training
for i, loss in enumerate(losses):