diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index a95a88a420..05c7cda5d7 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- +- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892)) - diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 561fe799f1..01d13849f8 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -696,6 +696,8 @@ class Trainer( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ + if not isinstance(model, pl.LightningModule): + raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}") self.strategy.model = model self._call_and_handle_interrupt( self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path @@ -776,6 +778,8 @@ class Trainer( :meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end`, etc. The length of the list corresponds to the number of validation dataloaders used. """ + if model is not None and not isinstance(model, pl.LightningModule): + raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}") self.strategy.model = model or self.lightning_module return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) @@ -864,6 +868,8 @@ class Trainer( :meth:`~pytorch_lightning.core.module.LightningModule.test_epoch_end`, etc. The length of the list corresponds to the number of test dataloaders used. """ + if model is not None and not isinstance(model, pl.LightningModule): + raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}") self.strategy.model = model or self.lightning_module return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) @@ -951,6 +957,8 @@ class Trainer( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. """ + if model is not None and not isinstance(model, pl.LightningModule): + raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}") self.strategy.model = model or self.lightning_module return self._call_and_handle_interrupt( self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path @@ -1033,6 +1041,9 @@ class Trainer( lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` """ + if not isinstance(model, pl.LightningModule): + raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}") + Trainer._log_api_event("tune") self.state.fn = TrainerFn.TUNING diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index ecc0ad724e..f868dcc353 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -20,12 +20,14 @@ from argparse import Namespace from contextlib import nullcontext from copy import deepcopy from pathlib import Path +from re import escape from unittest import mock from unittest.mock import ANY, call, patch import cloudpickle import pytest import torch +import torch.nn as nn from torch.multiprocessing import ProcessRaisedException from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import SGD @@ -71,6 +73,21 @@ else: torch_test_assert_close = torch.testing.assert_allclose +def test_trainer_error_when_input_not_lightning_module(): + """Test that a useful error gets raised when the Trainer methods receive something other than a + LightningModule.""" + trainer = Trainer() + + for method in ("fit", "validate", "test", "predict"): + with pytest.raises(TypeError, match=escape(f"`Trainer.{method}()` requires a `LightningModule`, got: Linear")): + run_method = getattr(trainer, method) + run_method(nn.Linear(2, 2)) + + trainer = Trainer(auto_lr_find=True, auto_scale_batch_size=True) + with pytest.raises(TypeError, match=escape("`Trainer.tune()` requires a `LightningModule`, got: Linear")): + trainer.tune(nn.Linear(2, 2)) + + @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently."""