Error messages for removed Trainer mixin methods (#15065)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
fe32b39dbc
commit
b1cc740fd6
|
@ -11,12 +11,34 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
|
||||
def _patch_sys_modules() -> None:
|
||||
self = sys.modules[__name__]
|
||||
sys.modules["pytorch_lightning.trainer.data_loading"] = self
|
||||
sys.modules["pytorch_lightning.trainer.optimizers"] = self
|
||||
|
||||
|
||||
class TrainerDataLoadingMixin:
|
||||
# TODO: Remove in v2.0.0
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"The `TrainerDataLoadingMixin` class was deprecated in v1.6 and is no longer supported as of v1.8."
|
||||
)
|
||||
|
||||
|
||||
class TrainerOptimizersMixin:
|
||||
# TODO: Remove in v2.0.0
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"The `TrainerOptimizersMixin` class was deprecated in v1.6 and is no longer supported as of v1.8."
|
||||
)
|
||||
|
||||
|
||||
def _gpus(_: Trainer) -> None:
|
||||
# Remove in v2.0.0
|
||||
raise AttributeError(
|
||||
|
@ -169,6 +191,30 @@ def _call_hook(_: Trainer, *__: Any, **___: Any) -> Any:
|
|||
raise NotImplementedError("`Trainer.call_hook` was deprecated in v1.6 and is no longer supported as of v1.8.")
|
||||
|
||||
|
||||
def _prepare_dataloader(_: Trainer, *__: Any, **___: Any) -> None:
|
||||
raise NotImplementedError(
|
||||
"`Trainer.prepare_dataloader` was deprecated in v1.6 and is no longer supported as of v1.8."
|
||||
)
|
||||
|
||||
|
||||
def _request_dataloader(_: Trainer, *__: Any, **___: Any) -> None:
|
||||
raise NotImplementedError(
|
||||
"`Trainer.request_dataloader` was deprecated in v1.6 and is no longer supported as of v1.8."
|
||||
)
|
||||
|
||||
|
||||
def _init_optimizers(_: Trainer, *__: Any, **___: Any) -> None:
|
||||
raise NotImplementedError("`Trainer.init_optimizers` was deprecated in v1.6 and is no longer supported as of v1.8.")
|
||||
|
||||
|
||||
def _convert_to_lightning_optimizers(_: Trainer) -> None:
|
||||
raise NotImplementedError(
|
||||
"`Trainer.convert_to_lightning_optimizers` was deprecated in v1.6 and is no longer supported as of v1.8."
|
||||
)
|
||||
|
||||
|
||||
_patch_sys_modules()
|
||||
|
||||
# Properties/Attributes
|
||||
Trainer.gpus = property(_gpus)
|
||||
Trainer.root_gpu = property(_root_gpu)
|
||||
|
@ -189,3 +235,7 @@ Trainer.verbose_evaluate = property(fget=_verbose_evaluate, fset=_verbose_evalua
|
|||
# Methods
|
||||
Trainer.run_stage = _run_stage
|
||||
Trainer.call_hook = _call_hook
|
||||
Trainer.prepare_dataloader = _prepare_dataloader
|
||||
Trainer.request_dataloader = _request_dataloader
|
||||
Trainer.init_optimizers = _init_optimizers
|
||||
Trainer.convert_to_lightning_optimizers = _convert_to_lightning_optimizers
|
||||
|
|
|
@ -75,3 +75,54 @@ def test_v2_0_0_unsupported_call_hook():
|
|||
NotImplementedError, match="`Trainer.call_hook` was deprecated in v1.6 and is no longer supported as of v1.8."
|
||||
):
|
||||
trainer.call_hook("test_hook")
|
||||
|
||||
|
||||
def test_v2_0_0_unsupported_data_loading_mixin():
|
||||
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
||||
|
||||
class CustomTrainerDataLoadingMixin(TrainerDataLoadingMixin):
|
||||
pass
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`TrainerDataLoadingMixin` class was deprecated in v1.6 and is no longer supported as of v1.8",
|
||||
):
|
||||
CustomTrainerDataLoadingMixin()
|
||||
|
||||
trainer = Trainer()
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`Trainer.prepare_dataloader` was deprecated in v1.6 and is no longer supported as of v1.8.",
|
||||
):
|
||||
trainer.prepare_dataloader(None)
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`Trainer.request_dataloader` was deprecated in v1.6 and is no longer supported as of v1.8.",
|
||||
):
|
||||
trainer.request_dataloader(None)
|
||||
|
||||
|
||||
def test_v2_0_0_trainer_optimizers_mixin():
|
||||
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
||||
|
||||
class CustomTrainerOptimizersMixin(TrainerOptimizersMixin):
|
||||
pass
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`TrainerOptimizersMixin` class was deprecated in v1.6 and is no longer supported as of v1.8",
|
||||
):
|
||||
CustomTrainerOptimizersMixin()
|
||||
|
||||
trainer = Trainer()
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`Trainer.init_optimizers` was deprecated in v1.6 and is no longer supported as of v1.8.",
|
||||
):
|
||||
trainer.init_optimizers(None)
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`Trainer.convert_to_lightning_optimizers` was deprecated in v1.6 and is no longer supported as of v1.8.",
|
||||
):
|
||||
trainer.convert_to_lightning_optimizers()
|
||||
|
|
Loading…
Reference in New Issue