Deprecate save_function from model checkpoint callback (#7201)
* Update model_checkpoint.py
* Update CHANGELOG.md
* fix-tests
* deprecate not remove
* Update model_checkpoint.py
* Update test_remove_1-5.py
This commit is contained in:
parent ac7d6a35c3
commit dd5ec75e48
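This change replaces the callback's public `save_function` attribute with a deprecated property backed by a private `_save_function`, and delegates saving to `trainer.save_checkpoint`. A minimal migration sketch, assuming a `LightningModule` subclass `MyModel` and a completed `trainer.fit` (both hypothetical, not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/")
trainer = Trainer(callbacks=[checkpoint_cb], max_epochs=1)
trainer.fit(MyModel())  # hypothetical model, stands in for any LightningModule

# Deprecated in v1.3, removed in v1.5 (emits a deprecation warning):
# checkpoint_cb.save_function("checkpoints/manual.ckpt", False)

# Preferred replacement; distributed rank handling is done by the trainer:
trainer.save_checkpoint("checkpoints/manual.ckpt")
trainer.save_checkpoint("checkpoints/weights.ckpt", weights_only=True)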
CHANGELOG.md
@@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Deprecated

+- Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201))
+
 - Deprecated `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))

@@ -190,6 +193,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Removed

 - Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130))

pytorch_lightning/callbacks/model_checkpoint.py
@@ -23,7 +23,7 @@ import os
 import re
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
 import torch
@@ -201,19 +201,19 @@ class ModelCheckpoint(Callback):
         self.best_model_score = None
         self.best_model_path = ""
         self.last_model_path = ""
-        self.save_function = None

         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
         self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
         self.__validate_init_configuration()
+        self._save_function = None

     def on_pretrain_routine_start(self, trainer, pl_module):
         """
         When pretrain routine starts we build the ckpt dir on the fly
         """
         self.__resolve_ckpt_dir(trainer)
-        self.save_function = trainer.save_checkpoint
+        self._save_function = trainer.save_checkpoint

     def on_train_batch_end(
         self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
@@ -254,9 +254,9 @@ class ModelCheckpoint(Callback):

     def save_checkpoint(self, trainer, unused: Optional = None):
         """
-        Performs the main logic around saving a checkpoint.
-        This method runs on all ranks, it is the responsibility of `self.save_function`
-        to handle correct behaviour in distributed training, i.e., saving only on rank 0.
+        Performs the main logic around saving a checkpoint. This method runs on all ranks.
+        It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training,
+        i.e., saving only on rank 0 for data parallel use cases.
         """
         if unused is not None:
             rank_zero_deprecation(
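The rewritten docstring describes a contract worth spelling out: the hook runs on every process, and `trainer.save_checkpoint` is expected to write only on global rank 0. A simplified sketch of that pattern (illustration only, not the actual Trainer implementation):

import torch

def save_on_rank_zero(state: dict, filepath: str, global_rank: int) -> None:
    # Every rank reaches this call site in distributed training,
    # but only global rank 0 touches the filesystem.
    if global_rank == 0:
        torch.save(state, filepath)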
@@ -396,6 +396,22 @@ class ModelCheckpoint(Callback):
             )
         self._period = value

+    @property
+    def save_function(self) -> Optional[Callable]:
+        rank_zero_deprecation(
+            'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+            ' Please use `trainer.save_checkpoint` instead.'
+        )
+        return self._save_function
+
+    @save_function.setter
+    def save_function(self, value: Optional[Callable]) -> None:
+        rank_zero_deprecation(
+            'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+            ' Please use `trainer.save_checkpoint` instead.'
+        )
+        self._save_function = value
+
     @rank_zero_only
     def _del_model(self, filepath: str):
         if self._fs.exists(filepath):
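The hunk above is the standard deprecate-a-public-attribute pattern: storage moves to the private `_save_function`, while a property/setter pair warns on every read or write. The same pattern in isolation, as a minimal sketch with `warnings.warn` standing in for Lightning's `rank_zero_deprecation`:

import warnings
from typing import Callable, Optional

class Example:
    def __init__(self) -> None:
        self._save_function: Optional[Callable] = None  # real storage is private

    @property
    def save_function(self) -> Optional[Callable]:
        warnings.warn("`save_function` is deprecated", DeprecationWarning)
        return self._save_function

    @save_function.setter
    def save_function(self, value: Optional[Callable]) -> None:
        warnings.warn("`save_function` is deprecated", DeprecationWarning)
        self._save_function = value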
@@ -420,10 +436,7 @@ class ModelCheckpoint(Callback):
         self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

         # delegate the saving to the trainer
-        if self.save_function is not None:
-            self.save_function(filepath, self.save_weights_only)
-        else:
-            raise ValueError(".save_function() not set")
+        trainer.save_checkpoint(filepath, self.save_weights_only)

     def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
         if current is None:
tests/deprecated_api/test_remove_1-5.py
@@ -32,9 +32,18 @@ from tests.helpers.utils import no_warning_call

 def test_v1_5_0_model_checkpoint_save_checkpoint():
     model_ckpt = ModelCheckpoint()
-    model_ckpt.save_function = lambda *_, **__: None
+    trainer = Trainer()
+    trainer.save_checkpoint = lambda *_, **__: None
     with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"):
-        model_ckpt.save_checkpoint(Trainer(), object())
+        model_ckpt.save_checkpoint(trainer, object())
+
+
+def test_v1_5_0_model_checkpoint_save_function():
+    model_ckpt = ModelCheckpoint()
+    with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
+        model_ckpt.save_function = lambda *_, **__: None
+    with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
+        _ = model_ckpt.save_function


 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
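These tests lean on `pytest.deprecated_call`, which fails the test unless the block raises a matching `DeprecationWarning` (Lightning's `rank_zero_deprecation` ultimately issues a `DeprecationWarning` subclass). A self-contained usage sketch:

import warnings
import pytest

def old_api() -> None:
    warnings.warn("`old_api` is deprecated", DeprecationWarning)

def test_old_api_warns() -> None:
    # Passes only if the block emits a DeprecationWarning whose message
    # matches the given regular expression.
    with pytest.deprecated_call(match="`old_api` is deprecated"):
        old_api()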
tests/checkpointing/test_model_checkpoint.py
@@ -330,9 +330,9 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
         save_last=save_last,
         verbose=True
     )
-    checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()
     trainer.state = TrainerState.FITTING
+    trainer.save_checkpoint = mock_save_function

     # emulate callback's calls during the training
     for i, loss in enumerate(losses):