lightning/tests/base/__init__.py

62 lines
1.9 KiB
Python

"""Models for testing."""
import torch
from tests.base.models import TestModelBase, DictHparamsModel
from tests.base.eval_model_template import EvalModelTemplate
from tests.base.mixins import (
LightEmptyTestStep,
LightValidationStepMixin,
LightValidationMixin,
LightValidationStepMultipleDataloadersMixin,
LightValidationMultipleDataloadersMixin,
LightTestStepMixin,
LightTestMixin,
LightTestStepMultipleDataloadersMixin,
LightTestMultipleDataloadersMixin,
LightTestFitSingleTestDataloadersMixin,
LightTestFitMultipleTestDataloadersMixin,
LightValStepFitSingleDataloaderMixin,
LightValStepFitMultipleDataloadersMixin,
LightTrainDataloader,
LightValidationDataloader,
LightTestDataloader,
LightInfTrainDataloader,
LightInfValDataloader,
LightInfTestDataloader,
LightTestOptimizerWithSchedulingMixin,
LightTestMultipleOptimizersWithSchedulingMixin,
LightTestOptimizersWithMixedSchedulingMixin,
LightTestReduceLROnPlateauMixin,
LightTestNoneOptimizerMixin,
LightZeroLenDataloader
)
class LightningTestModel(
    LightTrainDataloader,
    LightValidationMixin,
    LightTestMixin,
    TestModelBase,
):
    """Default test model: train dataloader plus validation and test mixins on ``TestModelBase``.

    The order of the base classes fixes the MRO, so it must not be rearranged.
    """

    def on_training_metrics(self, logs):
        """Store a freshly drawn random tensor in ``logs`` under ``'some_tensor_to_test'``.

        NOTE(review): ``logs`` is presumably the mutable training-metrics mapping
        handed in by the trainer — confirm against the caller.
        """
        sample = torch.rand(1)  # 1-element tensor, new random value on every call
        logs['some_tensor_to_test'] = sample
class LightningTestModelWithoutHyperparametersArg(LightningTestModel):
    """Test model whose constructor takes no ``hparams`` argument.

    Instead of receiving hyperparameters from the caller, it builds them from
    ``tests.base.utils.get_default_hparams()`` and forwards them to
    ``LightningTestModel``.
    """

    def __init__(self):
        # Imported inside the constructor rather than at module level;
        # NOTE(review): presumably to avoid an import cycle with tests.base — confirm.
        import tests.base.utils as tutils

        super().__init__(tutils.get_default_hparams())
class LightningTestModelWithUnusedHyperparametersArg(LightningTestModelWithoutHyperparametersArg):
    """Test model that accepts an ``hparams`` constructor argument but never uses it."""

    def __init__(self, hparams):
        # The argument is deliberately thrown away: the parameterless parent
        # constructor obtains its own hyperparameters.
        del hparams
        super().__init__()