Extend the deprecation of `Trainer(resume_from_checkpoint)` (#11334)

This commit is contained in:
Carlos Mocholí 2022-01-06 13:18:37 +01:00 committed by GitHub
parent f6a6a25810
commit 5693a94c32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 51 deletions

View File

@ -1376,7 +1376,7 @@ By setting to False, you have to add your own distributed sampler:
resume_from_checkpoint
^^^^^^^^^^^^^^^^^^^^^^
.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0.
Please pass ``trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt")`` instead.

View File

@ -39,7 +39,7 @@ class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
self.trainer = trainer
self.resume_checkpoint_path: Optional[_PATH] = None
# TODO: remove resume_from_checkpoint_fit_path in v1.7
# TODO: remove resume_from_checkpoint_fit_path in v2.0
self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint
if resume_from_checkpoint is not None:
rank_zero_deprecation(
@ -100,7 +100,7 @@ class CheckpointConnector:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
# TODO: remove resume_from_checkpoint_fit_path in v1.7
# TODO: remove resume_from_checkpoint_fit_path in v2.0
if (
self.trainer.state.fn == TrainerFn.FITTING
and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path

View File

@ -371,7 +371,7 @@ class Trainer(
training will start from the beginning of the next epoch.
.. deprecated:: v1.5
``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0.
Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead.
strategy: Supports different training strategies with aliases
@ -766,7 +766,7 @@ class Trainer(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)
# TODO: ckpt_path only in v1.7
# TODO: ckpt_path only in v2.0
ckpt_path = ckpt_path or self.resume_from_checkpoint
results = self._run(model, ckpt_path=ckpt_path)
@ -2053,8 +2053,9 @@ class Trainer(
resume_from_checkpoint = self.checkpoint_connector.resume_from_checkpoint_fit_path
if resume_from_checkpoint is not None:
rank_zero_deprecation(
"`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
" Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead."
"`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v2.0."
" Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead.",
stacklevel=5,
)
return resume_from_checkpoint

View File

@ -35,7 +35,6 @@ from pytorch_lightning.plugins.environments import (
TorchElasticEnvironment,
)
from pytorch_lightning.strategies import SingleDeviceStrategy
from tests.callbacks.test_callbacks import OldStatefulCallback
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
@ -414,49 +413,6 @@ def test_v1_7_0_deprecated_max_steps_none(tmpdir):
trainer.fit_loop.max_steps = None
def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
# test resume_from_checkpoint still works until v1.7 deprecation
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
callback = OldStatefulCallback(state=222)
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
with pytest.deprecated_call(
match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
):
_ = trainer.resume_from_checkpoint
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
trainer.validate(model=model, ckpt_path=ckpt_path)
assert callback.state == 222
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
with pytest.deprecated_call(
match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
):
trainer.fit(model)
assert callback.state == 111
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.predict(model=model, ckpt_path=ckpt_path)
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.fit(model)
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
# test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
model = BoringModel()
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")
def test_v1_7_0_deprecate_lr_sch_names(tmpdir):
model = BoringModel()
lr_monitor = LearningRateMonitor()

View File

@ -0,0 +1,58 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v2.0."""
import pytest
from pytorch_lightning import Trainer
from tests.callbacks.test_callbacks import OldStatefulCallback
from tests.helpers import BoringModel
def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
# test resume_from_checkpoint still works until v2.0 deprecation
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
callback = OldStatefulCallback(state=222)
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
_ = trainer.resume_from_checkpoint
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
trainer.validate(model=model, ckpt_path=ckpt_path)
assert callback.state == 222
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
trainer.fit(model)
assert callback.state == 111
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.predict(model=model, ckpt_path=ckpt_path)
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.fit(model)
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
# test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
model = BoringModel()
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")