Remove error when test dataloader used in test (#1495)

* remove error when test dataloader used in test

* remove error when test dataloader used in test

* remove error when test dataloader used in test

* remove error when test dataloader used in test

* remove error when test dataloader used in test

* remove error when test dataloader used in test

* fix lost model reference

* remove error when test dataloader used in test

* fix lost model reference

* moved optimizer types

* moved optimizer types

* moved optimizer types

* moved optimizer types

* moved optimizer types

* moved optimizer types

* moved optimizer types

* moved optimizer types

* added tests for warning

* fix lost model reference

* fix lost model reference

* added tests for warning

* added tests for warning

* refactoring

* refactoring

* fix imports

* refactoring

* fix imports

* refactoring

* fix tests

* fix mnist

* flake8

* review

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
William Falcon 2020-04-15 22:16:40 -04:00 committed by GitHub
parent 8322f1b039
commit 3431c62d41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 588 additions and 59 deletions

View File

@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
- Removed test for no test dataloader in .fit ([#1495](https://github.com/PyTorchLightning/pytorch-lightning/pull/1495))
- Removed duplicated module `pytorch_lightning.utilities.arg_parse` for loading CLI arguments ([#1167](https://github.com/PyTorchLightning/pytorch-lightning/issues/1167))
- Removed wandb logger's `finalize` method ([#1193](https://github.com/PyTorchLightning/pytorch-lightning/pull/1193))
- Dropped `torchvision` dependency in tests and added own MNIST dataset class instead ([#986](https://github.com/PyTorchLightning/pytorch-lightning/issues/986))

View File

@ -11,7 +11,7 @@ from torchvision import transforms
import tests.base.utils as tutils
from pytorch_lightning import Trainer, LightningModule
from tests.base.datasets import TestingMNIST
from tests.base.datasets import TrialMNIST
class ParityMNIST(LightningModule):
@ -42,10 +42,10 @@ class ParityMNIST(LightningModule):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(TestingMNIST(train=True,
download=True,
num_samples=500,
digits=list(range(5))),
return DataLoader(TrialMNIST(train=True,
download=True,
num_samples=500,
digits=list(range(5))),
batch_size=128)
@ -65,10 +65,11 @@ def test_pytorch_parity(tmpdir):
for pl_out, pt_out in zip(lightning_outs, manual_outs):
np.testing.assert_almost_equal(pl_out, pt_out, 5)
tutils.assert_speed_parity(pl_times, pt_times, num_epochs)
# the fist run initialize dataset (download & filter)
tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs)
def set_seed(seed):
def _set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
@ -88,7 +89,7 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
# set seed
seed = i
set_seed(seed)
_set_seed(seed)
# init model parts
model = MODEL()
@ -134,7 +135,7 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
# set seed
seed = i
set_seed(seed)
_set_seed(seed)
# init model parts
model = MODEL()

View File

@ -418,10 +418,8 @@ class TrainerEvaluationLoopMixin(ABC):
# make dataloader_idx arg in validation_step optional
args = [batch, batch_idx]
if test_mode and len(self.test_dataloaders) > 1:
args.append(dataloader_idx)
elif not test_mode and len(self.val_dataloaders) > 1:
if (test_mode and len(self.test_dataloaders) > 1) \
or (not test_mode and len(self.val_dataloaders) > 1):
args.append(dataloader_idx)
# handle DP, DDP forward

View File

@ -21,6 +21,8 @@ class TrainerModelHooksMixin(ABC):
return False
instance_attr = getattr(model, method_name)
if not instance_attr:
return False
super_attr = getattr(super_object, method_name)
# when code pointers are different, it was implemented

View File

@ -939,11 +939,14 @@ class Trainer(
self.testing = True
if test_dataloaders is not None:
if model is not None:
if model:
self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
else:
self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)
# give proper warnings if user only passed in loader without hooks
self.check_testing_model_configuration(model if model else self.model)
if model is not None:
self.model = model
self.fit(model)
@ -1012,10 +1015,25 @@ class Trainer(
'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
)
else:
if self.is_overriden('test_step', model):
raise MisconfigurationException('You have defined `test_step()`,'
' but have not passed in a `test_dataloader()`.')
def check_testing_model_configuration(self, model: LightningModule):
has_test_step = self.is_overriden('test_step', model)
has_test_epoch_end = self.is_overriden('test_epoch_end', model)
gave_test_loader = hasattr(model, 'test_dataloader') and model.test_dataloader()
if gave_test_loader and not has_test_step:
raise MisconfigurationException('You passed in a `test_dataloader` but did not implement `test_step()`')
if has_test_step and not gave_test_loader:
raise MisconfigurationException('You defined `test_step()` but did not implement'
' `test_dataloader` nor passed in `.fit(test_dataloaders`.')
if has_test_step and gave_test_loader and not has_test_epoch_end:
rank_zero_warn(
'You passed in a `test_dataloader` and have defined a `test_step()`, you may also want to'
' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
)
class _PatchDataLoader(object):

View File

@ -3,6 +3,7 @@
import torch
from tests.base.models import TestModelBase, DictHparamsModel
from tests.base.eval_model_template import EvalModelTemplate
from tests.base.mixins import (
LightEmptyTestStep,
LightValidationStepMixin,

View File

@ -49,7 +49,7 @@ class MNIST(Dataset):
cache_folder_name = 'complete'
def __init__(self, root: str = PATH_DATASETS, train: bool = True,
normalize: tuple = (0.5, 1.0), download: bool = False):
normalize: tuple = (0.5, 1.0), download: bool = True):
super().__init__()
self.root = root
self.train = train # training set or test set
@ -111,7 +111,7 @@ def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Ten
return tensor
class TestingMNIST(MNIST):
class TrialMNIST(MNIST):
"""Constrain image dataset
Args:
@ -127,7 +127,7 @@ class TestingMNIST(MNIST):
digits: list selected MNIST digits/classes
Examples:
>>> dataset = TestingMNIST(download=True)
>>> dataset = TrialMNIST(download=True)
>>> len(dataset)
300
>>> sorted(set([d.item() for d in dataset.targets]))
@ -179,6 +179,8 @@ class TestingMNIST(MNIST):
self._download(super().cached_folder_path)
for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
data, targets = torch.load(os.path.join(super().cached_folder_path, fname))
path_fname = os.path.join(super().cached_folder_path, fname)
assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname
data, targets = torch.load(path_fname)
data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
torch.save((data, targets), os.path.join(self.cached_folder_path, fname))

View File

@ -3,7 +3,7 @@ from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from tests.base.datasets import TestingMNIST
from tests.base.datasets import TrialMNIST
# from test_models import assert_ok_test_acc, load_model, \
@ -42,10 +42,10 @@ class CoolModel(pl.LightningModule):
return [torch.optim.Adam(self.parameters(), lr=0.02)]
def train_dataloader(self):
return DataLoader(TestingMNIST(train=True, num_samples=100), batch_size=16)
return DataLoader(TrialMNIST(train=True, num_samples=100), batch_size=16)
def val_dataloader(self):
return DataLoader(TestingMNIST(train=False, num_samples=50), batch_size=16)
return DataLoader(TrialMNIST(train=False, num_samples=50), batch_size=16)
def test_dataloader(self):
return DataLoader(TestingMNIST(train=False, num_samples=50), batch_size=16)
return DataLoader(TrialMNIST(train=False, num_samples=50), batch_size=16)

View File

@ -0,0 +1,61 @@
from abc import ABC
from torch import optim
class ConfigureOptimizersPool(ABC):
def configure_optimizers(self):
"""
return whatever optimizers we want here.
:return: list of optimizers
"""
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return optimizer
def configure_optimizers_empty(self):
return None
def configure_optimizers_lbfgs(self):
"""
return whatever optimizers we want here.
:return: list of optimizers
"""
optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
return optimizer
def configure_optimizers_multiple_optimizers(self):
"""
return whatever optimizers we want here.
:return: list of optimizers
"""
# try no scheduler for this model (testing purposes)
optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return optimizer1, optimizer2
def configure_optimizers_single_scheduler(self):
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
return [optimizer], [lr_scheduler]
def configure_optimizers_multiple_schedulers(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]
def configure_optimizers_mixed_scheduling(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
return [optimizer1, optimizer2], \
[{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]
def configure_optimizers_reduce_lr_on_plateau(self):
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return [optimizer], [lr_scheduler]

View File

@ -0,0 +1,80 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tests.base.datasets import TrialMNIST
from pytorch_lightning.core.lightning import LightningModule
from tests.base.eval_model_optimizers import ConfigureOptimizersPool
from tests.base.eval_model_test_dataloaders import TestDataloaderVariations
from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations
from tests.base.eval_model_test_steps import TestStepVariations
from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
from tests.base.eval_model_train_steps import TrainingStepVariations
from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
from tests.base.eval_model_valid_steps import ValidationStepVariations
from tests.base.eval_model_utils import ModelTemplateUtils
class EvalModelTemplate(
ModelTemplateUtils,
TrainingStepVariations,
ValidationStepVariations,
ValidationEpochEndVariations,
TestStepVariations,
TestEpochEndVariations,
TrainDataloaderVariations,
ValDataloaderVariations,
TestDataloaderVariations,
ConfigureOptimizersPool,
LightningModule
):
"""
This template houses all combinations of model configurations we want to test
"""
def __init__(self, hparams):
"""Pass in parsed HyperOptArgumentParser to the model."""
# init superclass
super().__init__()
self.hparams = hparams
# if you specify an example input, the summary will show input/output for each layer
self.example_input_array = torch.rand(5, 28 * 28)
# build model
self.__build_model()
def __build_model(self):
"""
Simple model for testing
:return:
"""
self.c_d1 = nn.Linear(
in_features=self.hparams.in_features,
out_features=self.hparams.hidden_dim
)
self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
self.c_d2 = nn.Linear(
in_features=self.hparams.hidden_dim,
out_features=self.hparams.out_features
)
def forward(self, x):
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
logits = F.log_softmax(x, dim=1)
return logits
def loss(self, labels, logits):
nll = F.nll_loss(logits, labels)
return nll
def prepare_data(self):
_ = TrialMNIST(root=self.hparams.data_root, train=True, download=True)

View File

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
class TestDataloaderVariations(ABC):
@abstractmethod
def dataloader(self, train: bool):
"""placeholder"""
def test_dataloader(self):
return self.dataloader(train=False)

View File

@ -0,0 +1,39 @@
from abc import ABC
import torch
class TestEpochEndVariations(ABC):
def test_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
# if returned a scalar from test_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
# return torch.stack(outputs).mean()
test_loss_mean = 0
test_acc_mean = 0
for output in outputs:
test_loss = self.get_output_metric(output, 'test_loss')
# reduce manually when using dp
if self.trainer.use_dp:
test_loss = torch.mean(test_loss)
test_loss_mean += test_loss
# reduce manually when using dp
test_acc = self.get_output_metric(output, 'test_acc')
if self.trainer.use_dp:
test_acc = torch.mean(test_acc)
test_acc_mean += test_acc
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
result = {'progress_bar': metrics_dict, 'log': metrics_dict}
return result

View File

@ -0,0 +1,89 @@
from abc import ABC
from collections import OrderedDict
import torch
class TestStepVariations(ABC):
"""
Houses all variations of test steps
"""
def test_step(self, batch, batch_idx, *args, **kwargs):
"""
Default, baseline test_step
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
})
return output
if batch_idx % 2 == 0:
return test_acc
if batch_idx % 3 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {'test_loss_a': loss_test}
})
return output
def test_step_multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
"""
Default, baseline test_step
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
})
return output
if batch_idx % 2 == 0:
return test_acc
if batch_idx % 3 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {'test_loss_a': loss_test}
})
return output
if batch_idx % 5 == 0:
output = OrderedDict({
f'test_loss_{dataloader_idx}': loss_test,
f'test_acc_{dataloader_idx}': test_acc,
})
return output

View File

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
class TrainDataloaderVariations(ABC):
@abstractmethod
def dataloader(self, train: bool):
"""placeholder"""
def train_dataloader(self):
return self.dataloader(train=True)

View File

@ -0,0 +1,30 @@
from abc import ABC
from collections import OrderedDict
class TrainingStepVariations(ABC):
"""
Houses all variations of training steps
"""
def training_step(self, batch, batch_idx, optimizer_idx=None):
"""Lightning calls this inside the training loop"""
# forward pass
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
# calculate loss
loss_val = self.loss(y, y_hat)
# alternate possible outputs to test
if self.trainer.batch_idx % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'progress_bar': {'some_val': loss_val * loss_val},
'log': {'train_some_val': loss_val * loss_val},
})
return output
if self.trainer.batch_idx % 2 == 0:
return loss_val

View File

@ -0,0 +1,22 @@
from torch.utils.data import DataLoader
from tests.base.datasets import TrialMNIST
class ModelTemplateUtils:
def dataloader(self, train):
dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)
loader = DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
)
return loader
def get_output_metric(self, output, name):
if isinstance(output, dict):
val = output[name]
else: # if it is 2level deep -> per dataloader and per batch
val = sum(out[name] for out in output) / len(output)
return val

View File

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
class ValDataloaderVariations(ABC):
@abstractmethod
def dataloader(self, train: bool):
"""placeholder"""
def val_dataloader(self):
return self.dataloader(train=False)

View File

@ -0,0 +1,42 @@
from abc import ABC
import torch
class ValidationEpochEndVariations(ABC):
"""
Houses all variations of validation_epoch_end steps
"""
def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
Args:
outputs: list of individual outputs of each validation step
"""
# if returned a scalar from validation_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
# return torch.stack(outputs).mean()
val_loss_mean = 0
val_acc_mean = 0
for output in outputs:
val_loss = self.get_output_metric(output, 'val_loss')
# reduce manually when using dp
if self.trainer.use_dp or self.trainer.use_ddp2:
val_loss = torch.mean(val_loss)
val_loss_mean += val_loss
# reduce manually when using dp
val_acc = self.get_output_metric(output, 'val_acc')
if self.trainer.use_dp or self.trainer.use_ddp2:
val_acc = torch.mean(val_acc)
val_acc_mean += val_acc
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results

View File

@ -0,0 +1,100 @@
from abc import ABC
from collections import OrderedDict
import torch
class ValidationStepVariations(ABC):
"""
Houses all variations of validation steps
"""
def validation_step(self, batch, batch_idx, *args, **kwargs):
"""
Lightning calls this inside the validation loop
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_val = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
if self.on_gpu:
val_acc = val_acc.cuda(loss_val.device.index)
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp:
loss_val = loss_val.unsqueeze(0)
val_acc = val_acc.unsqueeze(0)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
})
return output
if batch_idx % 2 == 0:
return val_acc
if batch_idx % 3 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
'test_dic': {'val_loss_a': loss_val}
})
return output
def validation_step_multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
"""
Lightning calls this inside the validation loop
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_val = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
if self.on_gpu:
val_acc = val_acc.cuda(loss_val.device.index)
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp:
loss_val = loss_val.unsqueeze(0)
val_acc = val_acc.unsqueeze(0)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
})
return output
if batch_idx % 2 == 0:
return val_acc
if batch_idx % 3 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
'test_dic': {'val_loss_a': loss_val}
})
return output
if batch_idx % 5 == 0:
output = OrderedDict({
f'val_loss_{dataloader_idx}': loss_val,
f'val_acc_{dataloader_idx}': val_acc,
})
return output

View File

@ -8,7 +8,7 @@ import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from tests.base.datasets import TestingMNIST
from tests.base.datasets import TrialMNIST
try:
from test_tube import HyperOptArgumentParser
@ -38,7 +38,7 @@ class DictHparamsModel(LightningModule):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(TestingMNIST(train=True, download=True), batch_size=16)
return DataLoader(TrialMNIST(train=True, download=True), batch_size=16)
class TestModelBase(LightningModule):
@ -137,11 +137,11 @@ class TestModelBase(LightningModule):
return [optimizer], [scheduler]
def prepare_data(self):
_ = TestingMNIST(root=self.hparams.data_root, train=True, download=True)
_ = TrialMNIST(root=self.hparams.data_root, train=True, download=True)
def _dataloader(self, train):
# init data generators
dataset = TestingMNIST(root=self.hparams.data_root, train=train, download=False)
dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)
# when using multi-node we need to add the datasampler
batch_size = self.hparams.batch_size

View File

@ -8,7 +8,7 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TestTubeLogger, TensorBoardLogger
from tests.base import LightningTestModel
from tests.base import LightningTestModel, EvalModelTemplate
from tests.base.datasets import PATH_DATASETS
# generate a list of random seeds for each test

View File

@ -37,8 +37,7 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model, _ = tutils.get_default_model()
class StoreHistoryLogger(logger_class):
def __init__(self, *args, **kwargs):

View File

@ -3,6 +3,7 @@ import pytest
import tests.base.utils as tutils
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base import (
TestModelBase,
LightValidationDataloader,
@ -119,36 +120,46 @@ def test_warning_on_wrong_test_settigs(tmpdir):
"""
tutils.reset_seed()
hparams = tutils.get_default_hparams()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(**trainer_options)
class CurrentTestModel(LightTrainDataloader,
LightTestDataloader,
TestModelBase):
pass
# check test_dataloader -> test_step
# ----------------
# if have test_dataloader should have test_step
# ----------------
with pytest.raises(MisconfigurationException):
model = CurrentTestModel(hparams)
model = EvalModelTemplate(hparams)
model.test_step = None
trainer.fit(model)
class CurrentTestModel(LightTrainDataloader,
LightTestStepMixin,
TestModelBase):
pass
# check test_dataloader + test_step -> test_epoch_end
# ----------------
# if have test_dataloader and test_step recommend test_epoch_end
# ----------------
with pytest.warns(RuntimeWarning):
model = CurrentTestModel(hparams)
trainer.fit(model)
model = EvalModelTemplate(hparams)
model.test_epoch_end = None
trainer.test(model)
class CurrentTestModel(LightTrainDataloader,
LightTestFitMultipleTestDataloadersMixin,
TestModelBase):
pass
# check test_step -> test_dataloader
# ----------------
# if have test_step and NO test_dataloader passed in tell user to pass test_dataloader
# ----------------
with pytest.raises(MisconfigurationException):
model = CurrentTestModel(hparams)
trainer.fit(model)
model = EvalModelTemplate(hparams)
model.test_dataloader = lambda: None
trainer.test(model)
# ----------------
# if have test_dataloader and NO test_step tell user to implement test_step
# ----------------
with pytest.raises(MisconfigurationException):
model = EvalModelTemplate(hparams)
model.test_dataloader = lambda: None
model.test_step = None
trainer.test(model, test_dataloaders=model.dataloader(train=False))
# ----------------
# if have test_dataloader and test_step but no test_epoch_end warn user
# ----------------
with pytest.warns(RuntimeWarning):
model = EvalModelTemplate(hparams)
model.test_dataloader = lambda: None
model.test_epoch_end = None
trainer.test(model, test_dataloaders=model.dataloader(train=False))