Validate the model input of trainer methods (#13892)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
dcb4dd55d9
commit
4ce97f37a2
|
@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Changed
|
||||
|
||||
-
|
||||
- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -696,6 +696,8 @@ class Trainer(
|
|||
|
||||
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
|
||||
"""
|
||||
if not isinstance(model, pl.LightningModule):
|
||||
raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
|
||||
self.strategy.model = model
|
||||
self._call_and_handle_interrupt(
|
||||
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
|
||||
|
@ -776,6 +778,8 @@ class Trainer(
|
|||
:meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end`, etc.
|
||||
The length of the list corresponds to the number of validation dataloaders used.
|
||||
"""
|
||||
if model is not None and not isinstance(model, pl.LightningModule):
|
||||
raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
|
||||
self.strategy.model = model or self.lightning_module
|
||||
return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
|
||||
|
||||
|
@ -864,6 +868,8 @@ class Trainer(
|
|||
:meth:`~pytorch_lightning.core.module.LightningModule.test_epoch_end`, etc.
|
||||
The length of the list corresponds to the number of test dataloaders used.
|
||||
"""
|
||||
if model is not None and not isinstance(model, pl.LightningModule):
|
||||
raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
|
||||
self.strategy.model = model or self.lightning_module
|
||||
return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
|
||||
|
||||
|
@ -951,6 +957,8 @@ class Trainer(
|
|||
Returns:
|
||||
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
|
||||
"""
|
||||
if model is not None and not isinstance(model, pl.LightningModule):
|
||||
raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
|
||||
self.strategy.model = model or self.lightning_module
|
||||
return self._call_and_handle_interrupt(
|
||||
self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
|
||||
|
@ -1033,6 +1041,9 @@ class Trainer(
|
|||
|
||||
lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
|
||||
"""
|
||||
if not isinstance(model, pl.LightningModule):
|
||||
raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
|
||||
|
||||
Trainer._log_api_event("tune")
|
||||
|
||||
self.state.fn = TrainerFn.TUNING
|
||||
|
|
|
@ -20,12 +20,14 @@ from argparse import Namespace
|
|||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from re import escape
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, call, patch
|
||||
|
||||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.multiprocessing import ProcessRaisedException
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
from torch.optim import SGD
|
||||
|
@ -71,6 +73,21 @@ else:
|
|||
torch_test_assert_close = torch.testing.assert_allclose
|
||||
|
||||
|
||||
def test_trainer_error_when_input_not_lightning_module():
|
||||
"""Test that a useful error gets raised when the Trainer methods receive something other than a
|
||||
LightningModule."""
|
||||
trainer = Trainer()
|
||||
|
||||
for method in ("fit", "validate", "test", "predict"):
|
||||
with pytest.raises(TypeError, match=escape(f"`Trainer.{method}()` requires a `LightningModule`, got: Linear")):
|
||||
run_method = getattr(trainer, method)
|
||||
run_method(nn.Linear(2, 2))
|
||||
|
||||
trainer = Trainer(auto_lr_find=True, auto_scale_batch_size=True)
|
||||
with pytest.raises(TypeError, match=escape("`Trainer.tune()` requires a `LightningModule`, got: Linear")):
|
||||
trainer.tune(nn.Linear(2, 2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url_ckpt", [True, False])
|
||||
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
|
||||
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
|
||||
|
|
Loading…
Reference in New Issue