From dd5ec75e4827a690b02f5929a9fce7ce41763add Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 26 Apr 2021 09:55:26 -0700 Subject: [PATCH] Deprecate save_function from model checkpoint callback (#7201) * Update model_checkpoint.py * Update CHANGELOG.md * fix-tests * deprecate not remove * Update model_checkpoint.py * Update test_remove_1-5.py --- CHANGELOG.md | 4 +++ .../callbacks/model_checkpoint.py | 33 +++++++++++++------ tests/deprecated_api/test_remove_1-5.py | 13 ++++++-- tests/trainer/test_trainer.py | 2 +- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ffdd8085f..23e7ced49d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201)) + + - Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066)) @@ -190,6 +193,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed + - Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 29b6d8681d..5d1f6cedb5 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -23,7 +23,7 @@ import os import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np import torch @@ -201,19 +201,19 @@ class ModelCheckpoint(Callback): self.best_model_score = None self.best_model_path = "" self.last_model_path = "" - self.save_function = None self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) self.__validate_init_configuration() + self._save_function = None def on_pretrain_routine_start(self, trainer, pl_module): """ When pretrain routine starts we build the ckpt dir on the fly """ self.__resolve_ckpt_dir(trainer) - self.save_function = trainer.save_checkpoint + self._save_function = trainer.save_checkpoint def on_train_batch_end( self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int @@ -254,9 +254,9 @@ class ModelCheckpoint(Callback): def save_checkpoint(self, trainer, unused: Optional = None): """ - Performs the main logic around saving a checkpoint. - This method runs on all ranks, it is the responsibility of `self.save_function` - to handle correct behaviour in distributed training, i.e., saving only on rank 0. + Performs the main logic around saving a checkpoint. This method runs on all ranks. + It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training, + i.e., saving only on rank 0 for data parallel use cases. """ if unused is not None: rank_zero_deprecation( @@ -396,6 +396,22 @@ class ModelCheckpoint(Callback): ) self._period = value + @property + def save_function(self) -> Optional[Callable]: + rank_zero_deprecation( + 'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `trainer.save_checkpoint` instead.' + ) + return self._save_function + + @save_function.setter + def save_function(self, value: Optional[Callable]) -> None: + rank_zero_deprecation( + 'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `trainer.save_checkpoint` instead.' + ) + self._save_function = value + @rank_zero_only def _del_model(self, filepath: str): if self._fs.exists(filepath): @@ -420,10 +436,7 @@ class ModelCheckpoint(Callback): self._fs.makedirs(os.path.dirname(filepath), exist_ok=True) # delegate the saving to the trainer - if self.save_function is not None: - self.save_function(filepath, self.save_weights_only) - else: - raise ValueError(".save_function() not set") + trainer.save_checkpoint(filepath, self.save_weights_only) def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool: if current is None: diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index b8f398b9e1..6516fbcc18 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -32,9 +32,18 @@ from tests.helpers.utils import no_warning_call def test_v1_5_0_model_checkpoint_save_checkpoint(): model_ckpt = ModelCheckpoint() - model_ckpt.save_function = lambda *_, **__: None + trainer = Trainer() + trainer.save_checkpoint = lambda *_, **__: None with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"): - model_ckpt.save_checkpoint(Trainer(), object()) + model_ckpt.save_checkpoint(trainer, object()) + + +def test_v1_5_0_model_checkpoint_save_function(): + model_ckpt = ModelCheckpoint() + with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"): + model_ckpt.save_function = lambda *_, **__: None + with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"): + _ = model_ckpt.save_function @mock.patch('pytorch_lightning.loggers.wandb.wandb') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b988dd3bec..8d84748497 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -330,9 +330,9 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files) save_last=save_last, verbose=True ) - checkpoint_callback.save_function = mock_save_function trainer = Trainer() trainer.state = TrainerState.FITTING + trainer.save_checkpoint = mock_save_function # emulate callback's calls during the training for i, loss in enumerate(losses):