Extend the deprecation of `Trainer(resume_from_checkpoint)` (#11334)
parent f6a6a25810
commit 5693a94c32
@@ -1376,7 +1376,7 @@ By setting to False, you have to add your own distributed sampler:
 resume_from_checkpoint
 ^^^^^^^^^^^^^^^^^^^^^^
 
-.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
+.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0.
     Please pass ``trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt")`` instead.
 
 
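For reference, a minimal migration sketch matching the updated docs above. The checkpoint path is the placeholder from the docs, and `BoringModel` is the test helper used elsewhere in this change, standing in for any `LightningModule`:

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # stand-in for any LightningModule

model = BoringModel()

# Deprecated since v1.5, now scheduled for removal in v2.0:
trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
trainer.fit(model)

# Recommended replacement: pass the checkpoint path to fit() instead.
trainer = Trainer()
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
```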
@@ -39,7 +39,7 @@ class CheckpointConnector:
     def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
         self.trainer = trainer
         self.resume_checkpoint_path: Optional[_PATH] = None
-        # TODO: remove resume_from_checkpoint_fit_path in v1.7
+        # TODO: remove resume_from_checkpoint_fit_path in v2.0
         self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint
         if resume_from_checkpoint is not None:
             rank_zero_deprecation(
@@ -100,7 +100,7 @@ class CheckpointConnector:
                 rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
             elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
                 rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
-        # TODO: remove resume_from_checkpoint_fit_path in v1.7
+        # TODO: remove resume_from_checkpoint_fit_path in v2.0
         if (
             self.trainer.state.fn == TrainerFn.FITTING
             and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path
@@ -371,7 +371,7 @@ class Trainer(
                 training will start from the beginning of the next epoch.
 
                 .. deprecated:: v1.5
-                    ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
+                    ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0.
                     Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead.
 
             strategy: Supports different training strategies with aliases
@@ -766,7 +766,7 @@ class Trainer(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )
 
-        # TODO: ckpt_path only in v1.7
+        # TODO: ckpt_path only in v2.0
        ckpt_path = ckpt_path or self.resume_from_checkpoint
         results = self._run(model, ckpt_path=ckpt_path)
 
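The `ckpt_path = ckpt_path or self.resume_from_checkpoint` line above is what gives the `fit()` argument precedence over the deprecated constructor argument. A small sketch of that behavior, using the placeholder paths from the precedence test further down:

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel

model = BoringModel()
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")  # deprecated source of the path
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")  # the explicit fit() argument wins

# As the test below asserts, this raises FileNotFoundError for "fit_arg_ckpt_path",
# confirming the fit() argument was the one selected.
```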
@@ -2053,8 +2053,9 @@ class Trainer(
         resume_from_checkpoint = self.checkpoint_connector.resume_from_checkpoint_fit_path
         if resume_from_checkpoint is not None:
             rank_zero_deprecation(
-                "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
-                " Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead."
+                "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v2.0."
+                " Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead.",
+                stacklevel=5,
             )
 
         return resume_from_checkpoint
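The new `stacklevel=5` only changes which stack frame the emitted deprecation warning is attributed to (`rank_zero_deprecation` is assumed here to forward it to `warnings.warn`). A standalone sketch of the mechanism using just the standard library:

```python
import warnings


def deprecation(message: str, stacklevel: int = 2) -> None:
    # stacklevel=2 attributes the warning to the caller of this helper;
    # larger values walk further up the stack, past intermediate wrappers.
    warnings.warn(message, DeprecationWarning, stacklevel=stacklevel)


def user_code() -> None:
    deprecation("`trainer.resume_from_checkpoint` is deprecated in v1.5")


user_code()  # the reported location is the call inside user_code(), not warnings.warn itself
```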
@@ -35,7 +35,6 @@ from pytorch_lightning.plugins.environments import (
     TorchElasticEnvironment,
 )
 from pytorch_lightning.strategies import SingleDeviceStrategy
-from tests.callbacks.test_callbacks import OldStatefulCallback
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import MNISTDataModule
@@ -414,49 +413,6 @@ def test_v1_7_0_deprecated_max_steps_none(tmpdir):
         trainer.fit_loop.max_steps = None
 
 
-def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
-    # test resume_from_checkpoint still works until v1.7 deprecation
-    model = BoringModel()
-    callback = OldStatefulCallback(state=111)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
-    trainer.fit(model)
-    ckpt_path = trainer.checkpoint_callback.best_model_path
-
-    callback = OldStatefulCallback(state=222)
-    with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
-        trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
-    with pytest.deprecated_call(
-        match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
-    ):
-        _ = trainer.resume_from_checkpoint
-    assert trainer.checkpoint_connector.resume_checkpoint_path is None
-    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
-    trainer.validate(model=model, ckpt_path=ckpt_path)
-    assert callback.state == 222
-    assert trainer.checkpoint_connector.resume_checkpoint_path is None
-    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
-    with pytest.deprecated_call(
-        match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
-    ):
-        trainer.fit(model)
-    assert callback.state == 111
-    assert trainer.checkpoint_connector.resume_checkpoint_path is None
-    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
-    trainer.predict(model=model, ckpt_path=ckpt_path)
-    assert trainer.checkpoint_connector.resume_checkpoint_path is None
-    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
-    trainer.fit(model)
-    assert trainer.checkpoint_connector.resume_checkpoint_path is None
-    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
-
-    # test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
-    model = BoringModel()
-    with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
-        trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
-    with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
-        trainer.fit(model, ckpt_path="fit_arg_ckpt_path")
-
-
 def test_v1_7_0_deprecate_lr_sch_names(tmpdir):
     model = BoringModel()
     lr_monitor = LearningRateMonitor()
@@ -0,0 +1,58 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test deprecated functionality which will be removed in v2.0."""
+import pytest
+
+from pytorch_lightning import Trainer
+from tests.callbacks.test_callbacks import OldStatefulCallback
+from tests.helpers import BoringModel
+
+
+def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
+    # test resume_from_checkpoint still works until v2.0 deprecation
+    model = BoringModel()
+    callback = OldStatefulCallback(state=111)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+
+    callback = OldStatefulCallback(state=222)
+    with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
+        trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
+    with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
+        _ = trainer.resume_from_checkpoint
+    assert trainer.checkpoint_connector.resume_checkpoint_path is None
+    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
+    trainer.validate(model=model, ckpt_path=ckpt_path)
+    assert callback.state == 222
+    assert trainer.checkpoint_connector.resume_checkpoint_path is None
+    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
+    with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
+        trainer.fit(model)
+    assert callback.state == 111
+    assert trainer.checkpoint_connector.resume_checkpoint_path is None
+    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
+    trainer.predict(model=model, ckpt_path=ckpt_path)
+    assert trainer.checkpoint_connector.resume_checkpoint_path is None
+    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
+    trainer.fit(model)
+    assert trainer.checkpoint_connector.resume_checkpoint_path is None
+    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
+
+    # test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
+    model = BoringModel()
+    with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
+        trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
+    with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
+        trainer.fit(model, ckpt_path="fit_arg_ckpt_path")