Validate the model input of trainer methods (#13892)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Adrian Wälchli 2022-08-03 15:38:42 +02:00 committed by GitHub
parent dcb4dd55d9
commit 4ce97f37a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 1 deletions

View File

@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
-
- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
-

View File

@@ -696,6 +696,8 @@ class Trainer(
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy.model = model
self._call_and_handle_interrupt(
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
@@ -776,6 +778,8 @@ class Trainer(
:meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end`, etc.
The length of the list corresponds to the number of validation dataloaders used.
"""
if model is not None and not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy.model = model or self.lightning_module
return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
@@ -864,6 +868,8 @@ class Trainer(
:meth:`~pytorch_lightning.core.module.LightningModule.test_epoch_end`, etc.
The length of the list corresponds to the number of test dataloaders used.
"""
if model is not None and not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy.model = model or self.lightning_module
return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
@@ -951,6 +957,8 @@ class Trainer(
Returns:
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
"""
if model is not None and not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy.model = model or self.lightning_module
return self._call_and_handle_interrupt(
self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
@@ -1033,6 +1041,9 @@ class Trainer(
lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
"""
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
Trainer._log_api_event("tune")
self.state.fn = TrainerFn.TUNING

View File

@@ -20,12 +20,14 @@ from argparse import Namespace
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from re import escape
from unittest import mock
from unittest.mock import ANY, call, patch
import cloudpickle
import pytest
import torch
import torch.nn as nn
from torch.multiprocessing import ProcessRaisedException
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD
@@ -71,6 +73,21 @@ else:
torch_test_assert_close = torch.testing.assert_allclose
def test_trainer_error_when_input_not_lightning_module():
    """Each entry-point method of ``Trainer`` must reject inputs that are not a ``LightningModule``
    with a descriptive ``TypeError``."""
    not_a_lightning_module = nn.Linear(2, 2)

    trainer = Trainer()
    for method_name in ("fit", "validate", "test", "predict"):
        expected_message = escape(f"`Trainer.{method_name}()` requires a `LightningModule`, got: Linear")
        with pytest.raises(TypeError, match=expected_message):
            getattr(trainer, method_name)(not_a_lightning_module)

    # `tune()` only runs when tuning is enabled, so it gets its own Trainer instance.
    trainer = Trainer(auto_lr_find=True, auto_scale_batch_size=True)
    with pytest.raises(TypeError, match=escape("`Trainer.tune()` requires a `LightningModule`, got: Linear")):
        trainer.tune(not_a_lightning_module)
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""