From 5693a94c320297cf007f3bfd13ce4d7deeb1954a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 6 Jan 2022 13:18:37 +0100 Subject: [PATCH] Extend the deprecation of `Trainer(resume_from_checkpoint)` (#11334) --- docs/source/common/trainer.rst | 2 +- .../connectors/checkpoint_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 9 +-- tests/deprecated_api/test_remove_1-7.py | 44 -------------- tests/deprecated_api/test_remove_2-0.py | 58 +++++++++++++++++++ 5 files changed, 66 insertions(+), 51 deletions(-) create mode 100644 tests/deprecated_api/test_remove_2-0.py diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index efa720f5eb..8ac067fc55 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1376,7 +1376,7 @@ By setting to False, you have to add your own distributed sampler: resume_from_checkpoint ^^^^^^^^^^^^^^^^^^^^^^ -.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7. +.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0. Please pass ``trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt")`` instead. diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index da63750f59..080dcd519d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -39,7 +39,7 @@ class CheckpointConnector: def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None: self.trainer = trainer self.resume_checkpoint_path: Optional[_PATH] = None - # TODO: remove resume_from_checkpoint_fit_path in v1.7 + # TODO: remove resume_from_checkpoint_fit_path in v2.0 self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint if resume_from_checkpoint is not None: rank_zero_deprecation( @@ -100,7 +100,7 @@ class CheckpointConnector: rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}") elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING): rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}") - # TODO: remove resume_from_checkpoint_fit_path in v1.7 + # TODO: remove resume_from_checkpoint_fit_path in v2.0 if ( self.trainer.state.fn == TrainerFn.FITTING and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9573f94e4e..7975981d8b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -371,7 +371,7 @@ class Trainer( training will start from the beginning of the next epoch. .. deprecated:: v1.5 - ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7. + ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0. Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead. strategy: Supports different training strategies with aliases @@ -766,7 +766,7 @@ class Trainer( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) - # TODO: ckpt_path only in v1.7 + # TODO: ckpt_path only in v2.0 ckpt_path = ckpt_path or self.resume_from_checkpoint results = self._run(model, ckpt_path=ckpt_path) @@ -2053,8 +2053,9 @@ class Trainer( resume_from_checkpoint = self.checkpoint_connector.resume_from_checkpoint_fit_path if resume_from_checkpoint is not None: rank_zero_deprecation( - "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." - " Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead." + "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v2.0." + " Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead.", + stacklevel=5, ) return resume_from_checkpoint diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 9bbfa2c056..8cfaa843a9 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -35,7 +35,6 @@ from pytorch_lightning.plugins.environments import ( TorchElasticEnvironment, ) from pytorch_lightning.strategies import SingleDeviceStrategy -from tests.callbacks.test_callbacks import OldStatefulCallback from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule @@ -414,49 +413,6 @@ def test_v1_7_0_deprecated_max_steps_none(tmpdir): trainer.fit_loop.max_steps = None -def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): - # test resume_from_checkpoint still works until v1.7 deprecation - model = BoringModel() - callback = OldStatefulCallback(state=111) - trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) - trainer.fit(model) - ckpt_path = trainer.checkpoint_callback.best_model_path - - callback = OldStatefulCallback(state=222) - with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) - with pytest.deprecated_call( - match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." - ): - _ = trainer.resume_from_checkpoint - assert trainer.checkpoint_connector.resume_checkpoint_path is None - assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path - trainer.validate(model=model, ckpt_path=ckpt_path) - assert callback.state == 222 - assert trainer.checkpoint_connector.resume_checkpoint_path is None - assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path - with pytest.deprecated_call( - match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." - ): - trainer.fit(model) - assert callback.state == 111 - assert trainer.checkpoint_connector.resume_checkpoint_path is None - assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None - trainer.predict(model=model, ckpt_path=ckpt_path) - assert trainer.checkpoint_connector.resume_checkpoint_path is None - assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None - trainer.fit(model) - assert trainer.checkpoint_connector.resume_checkpoint_path is None - assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None - - # test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path - model = BoringModel() - with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): - trainer = Trainer(resume_from_checkpoint="trainer_arg_path") - with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."): - trainer.fit(model, ckpt_path="fit_arg_ckpt_path") - - def test_v1_7_0_deprecate_lr_sch_names(tmpdir): model = BoringModel() lr_monitor = LearningRateMonitor() diff --git a/tests/deprecated_api/test_remove_2-0.py b/tests/deprecated_api/test_remove_2-0.py new file mode 100644 index 0000000000..ed0520b9b8 --- /dev/null +++ b/tests/deprecated_api/test_remove_2-0.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in v2.0.""" +import pytest + +from pytorch_lightning import Trainer +from tests.callbacks.test_callbacks import OldStatefulCallback +from tests.helpers import BoringModel + + +def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir): + # test resume_from_checkpoint still works until v2.0 deprecation + model = BoringModel() + callback = OldStatefulCallback(state=111) + trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) + trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path + + callback = OldStatefulCallback(state=222) + with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) + with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"): + _ = trainer.resume_from_checkpoint + assert trainer.checkpoint_connector.resume_checkpoint_path is None + assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path + trainer.validate(model=model, ckpt_path=ckpt_path) + assert callback.state == 222 + assert trainer.checkpoint_connector.resume_checkpoint_path is None + assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path + with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"): + trainer.fit(model) + assert callback.state == 111 + assert trainer.checkpoint_connector.resume_checkpoint_path is None + assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None + trainer.predict(model=model, ckpt_path=ckpt_path) + assert trainer.checkpoint_connector.resume_checkpoint_path is None + assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None + trainer.fit(model) + assert trainer.checkpoint_connector.resume_checkpoint_path is None + assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None + + # test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path + model = BoringModel() + with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): + trainer = Trainer(resume_from_checkpoint="trainer_arg_path") + with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."): + trainer.fit(model, ckpt_path="fit_arg_ckpt_path")