Do not mark LightningModule methods as abstract (#12381)
* do not mark LightningModule methods as abstract * add concrete test
This commit is contained in:
parent
ea7f444167
commit
94fe322533
|
@ -19,6 +19,7 @@ import torch
|
||||||
from torch.optim.optimizer import Optimizer
|
from torch.optim.optimizer import Optimizer
|
||||||
|
|
||||||
from pytorch_lightning.utilities import move_data_to_device
|
from pytorch_lightning.utilities import move_data_to_device
|
||||||
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
|
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
|
||||||
|
|
||||||
|
|
||||||
|
@ -490,7 +491,7 @@ class DataHooks:
|
||||||
# each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
|
# each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
|
||||||
return {'mnist': mnist_loader, 'cifar': cifar_loader}
|
return {'mnist': mnist_loader, 'cifar': cifar_loader}
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("`train_dataloader` must be implemented to be used with the Lightning Trainer")
|
raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||||
|
|
||||||
def test_dataloader(self) -> EVAL_DATALOADERS:
|
def test_dataloader(self) -> EVAL_DATALOADERS:
|
||||||
r"""
|
r"""
|
||||||
|
@ -544,7 +545,7 @@ class DataHooks:
|
||||||
In the case where you return multiple test dataloaders, the :meth:`test_step`
|
In the case where you return multiple test dataloaders, the :meth:`test_step`
|
||||||
will have an argument ``dataloader_idx`` which matches the order here.
|
will have an argument ``dataloader_idx`` which matches the order here.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("`test_dataloader` must be implemented to be used with the Lightning Trainer")
|
raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||||
|
|
||||||
def val_dataloader(self) -> EVAL_DATALOADERS:
|
def val_dataloader(self) -> EVAL_DATALOADERS:
|
||||||
r"""
|
r"""
|
||||||
|
@ -595,7 +596,7 @@ class DataHooks:
|
||||||
In the case where you return multiple validation dataloaders, the :meth:`validation_step`
|
In the case where you return multiple validation dataloaders, the :meth:`validation_step`
|
||||||
will have an argument ``dataloader_idx`` which matches the order here.
|
will have an argument ``dataloader_idx`` which matches the order here.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("`val_dataloader` must be implemented to be used with the Lightning Trainer")
|
raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer")
|
||||||
|
|
||||||
def predict_dataloader(self) -> EVAL_DATALOADERS:
|
def predict_dataloader(self) -> EVAL_DATALOADERS:
|
||||||
r"""
|
r"""
|
||||||
|
@ -618,7 +619,9 @@ class DataHooks:
|
||||||
In the case where you return multiple prediction dataloaders, the :meth:`predict_step`
|
In the case where you return multiple prediction dataloaders, the :meth:`predict_step`
|
||||||
will have an argument ``dataloader_idx`` which matches the order here.
|
will have an argument ``dataloader_idx`` which matches the order here.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer")
|
raise MisconfigurationException(
|
||||||
|
"`predict_dataloader` must be implemented to be used with the Lightning Trainer"
|
||||||
|
)
|
||||||
|
|
||||||
def on_train_dataloader(self) -> None:
|
def on_train_dataloader(self) -> None:
|
||||||
"""Called before requesting the train dataloader.
|
"""Called before requesting the train dataloader.
|
||||||
|
|
|
@ -377,9 +377,9 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
||||||
with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
|
with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
|
||||||
dm.train_dataloader()
|
dm.train_dataloader()
|
||||||
dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
|
dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(MisconfigurationException):
|
||||||
_ = dm.val_dataloader()
|
_ = dm.val_dataloader()
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(MisconfigurationException):
|
||||||
_ = dm.test_dataloader()
|
_ = dm.test_dataloader()
|
||||||
|
|
||||||
train_ds_sequence = [ds(), ds()]
|
train_ds_sequence = [ds(), ds()]
|
||||||
|
@ -392,9 +392,9 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
||||||
call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
|
call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(MisconfigurationException):
|
||||||
_ = dm.val_dataloader()
|
_ = dm.val_dataloader()
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(MisconfigurationException):
|
||||||
_ = dm.test_dataloader()
|
_ = dm.test_dataloader()
|
||||||
|
|
||||||
valid_ds = ds()
|
valid_ds = ds()
|
||||||
|
@ -405,7 +405,7 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
||||||
dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
|
dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
|
||||||
dm.test_dataloader()
|
dm.test_dataloader()
|
||||||
dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
|
dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(MisconfigurationException):
|
||||||
_ = dm.train_dataloader()
|
_ = dm.train_dataloader()
|
||||||
|
|
||||||
valid_dss = [ds(), ds()]
|
valid_dss = [ds(), ds()]
|
||||||
|
|
|
@ -18,7 +18,7 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim import Adam, SGD
|
from torch.optim import Adam, SGD
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import LightningModule, Trainer
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
|
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
|
||||||
|
@ -26,6 +26,11 @@ from tests.helpers import BoringModel
|
||||||
from tests.helpers.runif import RunIf
|
from tests.helpers.runif import RunIf
|
||||||
|
|
||||||
|
|
||||||
|
def test_lightning_module_not_abstract():
|
||||||
|
"""Test that the LightningModule can be instantiated (it is not an abstract class)."""
|
||||||
|
_ = LightningModule()
|
||||||
|
|
||||||
|
|
||||||
def test_property_current_epoch():
|
def test_property_current_epoch():
|
||||||
"""Test that the current_epoch in LightningModule is accessible via the Trainer."""
|
"""Test that the current_epoch in LightningModule is accessible via the Trainer."""
|
||||||
model = BoringModel()
|
model = BoringModel()
|
||||||
|
|
Loading…
Reference in New Issue