Remove error when test dataloader used in test (#1495)
* remove error when test dataloader used in test
* fix lost model reference
* moved optimizer types
* added tests for warning
* refactoring
* fix imports
* fix tests
* fix mnist
* flake8
* review

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
parent 8322f1b039
commit 3431c62d41
@@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Removed

+- Removed test for no test dataloader in .fit ([#1495](https://github.com/PyTorchLightning/pytorch-lightning/pull/1495))
 - Removed duplicated module `pytorch_lightning.utilities.arg_parse` for loading CLI arguments ([#1167](https://github.com/PyTorchLightning/pytorch-lightning/issues/1167))
 - Removed wandb logger's `finalize` method ([#1193](https://github.com/PyTorchLightning/pytorch-lightning/pull/1193))
 - Dropped `torchvision` dependency in tests and added own MNIST dataset class instead ([#986](https://github.com/PyTorchLightning/pytorch-lightning/issues/986))

@@ -11,7 +11,7 @@ from torchvision import transforms
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer, LightningModule
-from tests.base.datasets import TestingMNIST
+from tests.base.datasets import TrialMNIST


 class ParityMNIST(LightningModule):

@@ -42,10 +42,10 @@ class ParityMNIST(LightningModule):
         return torch.optim.Adam(self.parameters(), lr=0.02)

     def train_dataloader(self):
-        return DataLoader(TestingMNIST(train=True,
-                                       download=True,
-                                       num_samples=500,
-                                       digits=list(range(5))),
+        return DataLoader(TrialMNIST(train=True,
+                                     download=True,
+                                     num_samples=500,
+                                     digits=list(range(5))),
                           batch_size=128)

@@ -65,10 +65,11 @@ def test_pytorch_parity(tmpdir):
     for pl_out, pt_out in zip(lightning_outs, manual_outs):
         np.testing.assert_almost_equal(pl_out, pt_out, 5)

-    tutils.assert_speed_parity(pl_times, pt_times, num_epochs)
+    # the first run initializes the dataset (download & filter)
+    tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs)


-def set_seed(seed):
+def _set_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():

@@ -88,7 +89,7 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):

         # set seed
         seed = i
-        set_seed(seed)
+        _set_seed(seed)

         # init model parts
         model = MODEL()

@@ -134,7 +135,7 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):

         # set seed
         seed = i
-        set_seed(seed)
+        _set_seed(seed)

         # init model parts
         model = MODEL()

@@ -418,10 +418,8 @@ class TrainerEvaluationLoopMixin(ABC):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]

-        if test_mode and len(self.test_dataloaders) > 1:
-            args.append(dataloader_idx)
-
-        elif not test_mode and len(self.val_dataloaders) > 1:
+        if (test_mode and len(self.test_dataloaders) > 1) \
+                or (not test_mode and len(self.val_dataloaders) > 1):
             args.append(dataloader_idx)

         # handle DP, DDP forward

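The merged condition preserves the contract that `dataloader_idx` is only passed when it is ambiguous, so user hooks keep the two-argument signature in the common single-loader case. A standalone mirror of that dispatch rule (a sketch, not the Lightning source):

```python
def call_step(step_fn, batch, batch_idx, dataloader_idx, num_loaders):
    # pass dataloader_idx only when there is more than one loader to disambiguate
    args = [batch, batch_idx]
    if num_loaders > 1:
        args.append(dataloader_idx)
    return step_fn(*args)


print(call_step(lambda b, i: 'short signature', 'batch0', 0, 0, num_loaders=1))
print(call_step(lambda b, i, d: f'loader {d}', 'batch0', 0, 1, num_loaders=2))
```
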
@@ -21,6 +21,8 @@ class TrainerModelHooksMixin(ABC):
             return False

         instance_attr = getattr(model, method_name)
+        if not instance_attr:
+            return False
         super_attr = getattr(super_object, method_name)

         # when code pointers are different, it was implemented

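The new `None` guard is what lets the rewritten tests further down knock out a hook with a plain assignment such as `model.test_step = None`: the attribute then exists but is falsy, so the method no longer counts as overridden. A minimal sketch of the whole override check under that assumption (`is_overridden` here is illustrative, not the Lightning implementation):

```python
class Base:
    def test_step(self, batch, batch_idx):
        pass


class Child(Base):
    def test_step(self, batch, batch_idx):  # overrides the hook
        return batch


def is_overridden(model, method_name, super_object=Base):
    instance_attr = getattr(model, method_name, None)
    if not instance_attr:
        return False  # hook absent, or knocked out via `model.test_step = None`
    super_attr = getattr(super_object, method_name)
    # when the code pointers differ, the subclass implemented its own version
    return instance_attr.__code__ is not super_attr.__code__


child = Child()
print(is_overridden(child, 'test_step'))   # True
child.test_step = None
print(is_overridden(child, 'test_step'))   # False after the knockout
print(is_overridden(Base(), 'test_step'))  # False, never overridden
```
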
@@ -939,11 +939,14 @@ class Trainer(
         self.testing = True

         if test_dataloaders is not None:
-            if model is not None:
+            if model:
                 self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
             else:
                 self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)

+        # give proper warnings if user only passed in loader without hooks
+        self.check_testing_model_configuration(model if model else self.model)
+
         if model is not None:
             self.model = model
             self.fit(model)

@@ -1012,10 +1015,25 @@ class Trainer(
                 'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
                 ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
             )
-        else:
-            if self.is_overriden('test_step', model):
-                raise MisconfigurationException('You have defined `test_step()`,'
-                                                ' but have not passed in a `test_dataloader()`.')
+
+    def check_testing_model_configuration(self, model: LightningModule):
+
+        has_test_step = self.is_overriden('test_step', model)
+        has_test_epoch_end = self.is_overriden('test_epoch_end', model)
+        gave_test_loader = hasattr(model, 'test_dataloader') and model.test_dataloader()
+
+        if gave_test_loader and not has_test_step:
+            raise MisconfigurationException('You passed in a `test_dataloader` but did not implement `test_step()`')
+
+        if has_test_step and not gave_test_loader:
+            raise MisconfigurationException('You defined `test_step()` but did not implement'
+                                            ' `test_dataloader` nor passed in `.fit(test_dataloaders`.')
+
+        if has_test_step and gave_test_loader and not has_test_epoch_end:
+            rank_zero_warn(
+                'You passed in a `test_dataloader` and have defined a `test_step()`, you may also want to'
+                ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
+            )


 class _PatchDataLoader(object):

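Taken together, `Trainer.test()` now routes every call through `check_testing_model_configuration`, so a missing hook fails fast with a named error instead of silently skipping the test loop. A runnable sketch of the failure mode against this 0.7-era API (the model below is illustrative, not part of the PR):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LoaderOnlyModel(LightningModule):
    """Defines a test_dataloader but no test_step."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

    def test_dataloader(self):
        data = TensorDataset(torch.rand(8, 4), torch.zeros(8, dtype=torch.long))
        return DataLoader(data, batch_size=4)


trainer = Trainer(max_epochs=1)
try:
    trainer.test(LoaderOnlyModel())
except MisconfigurationException as err:
    print(err)  # asks the user to implement test_step()
```
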
@@ -3,6 +3,7 @@
 import torch

 from tests.base.models import TestModelBase, DictHparamsModel
+from tests.base.eval_model_template import EvalModelTemplate
 from tests.base.mixins import (
     LightEmptyTestStep,
     LightValidationStepMixin,

@@ -49,7 +49,7 @@ class MNIST(Dataset):
     cache_folder_name = 'complete'

     def __init__(self, root: str = PATH_DATASETS, train: bool = True,
-                 normalize: tuple = (0.5, 1.0), download: bool = False):
+                 normalize: tuple = (0.5, 1.0), download: bool = True):
         super().__init__()
         self.root = root
         self.train = train  # training set or test set

@@ -111,7 +111,7 @@ def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
     return tensor


-class TestingMNIST(MNIST):
+class TrialMNIST(MNIST):
     """Constrain image dataset

     Args:

@@ -127,7 +127,7 @@ class TestingMNIST(MNIST):
     digits: list selected MNIST digits/classes

     Examples:
-        >>> dataset = TestingMNIST(download=True)
+        >>> dataset = TrialMNIST(download=True)
         >>> len(dataset)
         300
         >>> sorted(set([d.item() for d in dataset.targets]))

@@ -179,6 +179,8 @@ class TestingMNIST(MNIST):
         self._download(super().cached_folder_path)

         for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
-            data, targets = torch.load(os.path.join(super().cached_folder_path, fname))
+            path_fname = os.path.join(super().cached_folder_path, fname)
+            assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname
+            data, targets = torch.load(path_fname)
             data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
             torch.save((data, targets), os.path.join(self.cached_folder_path, fname))

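The explicit `assert` trades `torch.load`'s generic file-not-found traceback for a message that names the missing cache file. The same pattern in isolation (the path is illustrative):

```python
import os

import torch

path = '/tmp/mnist_cache/training.pt'  # illustrative cache location
assert os.path.isfile(path), 'Missing cached file: %s' % path  # fails with a named file
data, targets = torch.load(path)
```
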
@@ -3,7 +3,7 @@ from torch.nn import functional as F
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
-from tests.base.datasets import TestingMNIST
+from tests.base.datasets import TrialMNIST


 # from test_models import assert_ok_test_acc, load_model, \

@@ -42,10 +42,10 @@ class CoolModel(pl.LightningModule):
         return [torch.optim.Adam(self.parameters(), lr=0.02)]

     def train_dataloader(self):
-        return DataLoader(TestingMNIST(train=True, num_samples=100), batch_size=16)
+        return DataLoader(TrialMNIST(train=True, num_samples=100), batch_size=16)

     def val_dataloader(self):
-        return DataLoader(TestingMNIST(train=False, num_samples=50), batch_size=16)
+        return DataLoader(TrialMNIST(train=False, num_samples=50), batch_size=16)

     def test_dataloader(self):
-        return DataLoader(TestingMNIST(train=False, num_samples=50), batch_size=16)
+        return DataLoader(TrialMNIST(train=False, num_samples=50), batch_size=16)

@@ -0,0 +1,61 @@
+from abc import ABC
+
+from torch import optim
+
+
+class ConfigureOptimizersPool(ABC):
+    def configure_optimizers(self):
+        """
+        return whatever optimizers we want here.
+        :return: list of optimizers
+        """
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        return optimizer
+
+    def configure_optimizers_empty(self):
+        return None
+
+    def configure_optimizers_lbfgs(self):
+        """
+        return whatever optimizers we want here.
+        :return: list of optimizers
+        """
+        optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+        return optimizer
+
+    def configure_optimizers_multiple_optimizers(self):
+        """
+        return whatever optimizers we want here.
+        :return: list of optimizers
+        """
+        # try no scheduler for this model (testing purposes)
+        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        return optimizer1, optimizer2
+
+    def configure_optimizers_single_scheduler(self):
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
+        return [optimizer], [lr_scheduler]
+
+    def configure_optimizers_multiple_schedulers(self):
+        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
+        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
+
+        return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]
+
+    def configure_optimizers_mixed_scheduling(self):
+        optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1)
+        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
+
+        return [optimizer1, optimizer2], \
+            [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]
+
+    def configure_optimizers_reduce_lr_on_plateau(self):
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+        return [optimizer], [lr_scheduler]

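The pool pattern lets a test pick an optimizer recipe per test case by rebinding the method, instead of defining a fresh mixin subclass for every combination. A self-contained sketch of the same idea (class and method names here are illustrative, not from the test suite):

```python
from abc import ABC

import torch.nn as nn
from torch import optim


class OptimizerPool(ABC):
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

    def configure_optimizers_sgd(self):
        return optim.SGD(self.parameters(), lr=1e-3)


class TinyModel(OptimizerPool, nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)


model = TinyModel()
# rebind the hook for this test case only; no new subclass needed
model.configure_optimizers = model.configure_optimizers_sgd
print(type(model.configure_optimizers()).__name__)  # SGD
```
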
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tests.base.datasets import TrialMNIST
+from pytorch_lightning.core.lightning import LightningModule
+from tests.base.eval_model_optimizers import ConfigureOptimizersPool
+from tests.base.eval_model_test_dataloaders import TestDataloaderVariations
+from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations
+from tests.base.eval_model_test_steps import TestStepVariations
+from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
+from tests.base.eval_model_train_steps import TrainingStepVariations
+from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
+from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
+from tests.base.eval_model_valid_steps import ValidationStepVariations
+from tests.base.eval_model_utils import ModelTemplateUtils
+
+
+class EvalModelTemplate(
+    ModelTemplateUtils,
+    TrainingStepVariations,
+    ValidationStepVariations,
+    ValidationEpochEndVariations,
+    TestStepVariations,
+    TestEpochEndVariations,
+    TrainDataloaderVariations,
+    ValDataloaderVariations,
+    TestDataloaderVariations,
+    ConfigureOptimizersPool,
+    LightningModule
+):
+    """
+    This template houses all combinations of model configurations we want to test
+    """
+    def __init__(self, hparams):
+        """Pass in parsed HyperOptArgumentParser to the model."""
+        # init superclass
+        super().__init__()
+        self.hparams = hparams
+
+        # if you specify an example input, the summary will show input/output for each layer
+        self.example_input_array = torch.rand(5, 28 * 28)
+
+        # build model
+        self.__build_model()
+
+    def __build_model(self):
+        """
+        Simple model for testing
+        :return:
+        """
+        self.c_d1 = nn.Linear(
+            in_features=self.hparams.in_features,
+            out_features=self.hparams.hidden_dim
+        )
+        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
+        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
+
+        self.c_d2 = nn.Linear(
+            in_features=self.hparams.hidden_dim,
+            out_features=self.hparams.out_features
+        )
+
+    def forward(self, x):
+        x = self.c_d1(x)
+        x = torch.tanh(x)
+        x = self.c_d1_bn(x)
+        x = self.c_d1_drop(x)
+
+        x = self.c_d2(x)
+        logits = F.log_softmax(x, dim=1)
+
+        return logits
+
+    def loss(self, labels, logits):
+        nll = F.nll_loss(logits, labels)
+        return nll
+
+    def prepare_data(self):
+        _ = TrialMNIST(root=self.hparams.data_root, train=True, download=True)

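Everything the template needs arrives through `hparams`, so a test can build one from a plain namespace. A hedged sketch of instantiation (the field values are illustrative, not the ones `tutils.get_default_hparams()` actually returns):

```python
from argparse import Namespace

from tests.base import EvalModelTemplate

hparams = Namespace(
    in_features=28 * 28,   # flattened MNIST image
    hidden_dim=1000,
    out_features=10,       # digit classes
    drop_prob=0.2,
    learning_rate=0.001,
    batch_size=32,
    data_root='./datasets',
)
model = EvalModelTemplate(hparams)
```
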
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+
+class TestDataloaderVariations(ABC):
+
+    @abstractmethod
+    def dataloader(self, train: bool):
+        """placeholder"""
+
+    def test_dataloader(self):
+        return self.dataloader(train=False)

@@ -0,0 +1,39 @@
+from abc import ABC
+
+import torch
+
+
+class TestEpochEndVariations(ABC):
+
+    def test_epoch_end(self, outputs):
+        """
+        Called at the end of validation to aggregate outputs
+        :param outputs: list of individual outputs of each validation step
+        :return:
+        """
+        # if returned a scalar from test_step, outputs is a list of tensor scalars
+        # we return just the average in this case (if we want)
+        # return torch.stack(outputs).mean()
+        test_loss_mean = 0
+        test_acc_mean = 0
+        for output in outputs:
+            test_loss = self.get_output_metric(output, 'test_loss')
+
+            # reduce manually when using dp
+            if self.trainer.use_dp:
+                test_loss = torch.mean(test_loss)
+            test_loss_mean += test_loss
+
+            # reduce manually when using dp
+            test_acc = self.get_output_metric(output, 'test_acc')
+            if self.trainer.use_dp:
+                test_acc = torch.mean(test_acc)
+
+            test_acc_mean += test_acc
+
+        test_loss_mean /= len(outputs)
+        test_acc_mean /= len(outputs)
+
+        metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
+        result = {'progress_bar': metrics_dict, 'log': metrics_dict}
+        return result

@@ -0,0 +1,89 @@
+from abc import ABC
+from collections import OrderedDict
+
+import torch
+
+
+class TestStepVariations(ABC):
+    """
+    Houses all variations of test steps
+    """
+    def test_step(self, batch, batch_idx, *args, **kwargs):
+        """
+        Default, baseline test_step
+        :param batch:
+        :return:
+        """
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        loss_test = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        test_acc = torch.tensor(test_acc)
+
+        test_acc = test_acc.type_as(x)
+
+        # alternate possible outputs to test
+        if batch_idx % 1 == 0:
+            output = OrderedDict({
+                'test_loss': loss_test,
+                'test_acc': test_acc,
+            })
+            return output
+        if batch_idx % 2 == 0:
+            return test_acc
+
+        if batch_idx % 3 == 0:
+            output = OrderedDict({
+                'test_loss': loss_test,
+                'test_acc': test_acc,
+                'test_dic': {'test_loss_a': loss_test}
+            })
+            return output
+
+    def test_step_multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
+        """
+        Default, baseline test_step
+        :param batch:
+        :return:
+        """
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        loss_test = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        test_acc = torch.tensor(test_acc)
+
+        test_acc = test_acc.type_as(x)
+
+        # alternate possible outputs to test
+        if batch_idx % 1 == 0:
+            output = OrderedDict({
+                'test_loss': loss_test,
+                'test_acc': test_acc,
+            })
+            return output
+        if batch_idx % 2 == 0:
+            return test_acc
+
+        if batch_idx % 3 == 0:
+            output = OrderedDict({
+                'test_loss': loss_test,
+                'test_acc': test_acc,
+                'test_dic': {'test_loss_a': loss_test}
+            })
+            return output
+        if batch_idx % 5 == 0:
+            output = OrderedDict({
+                f'test_loss_{dataloader_idx}': loss_test,
+                f'test_acc_{dataloader_idx}': test_acc,
+            })
+            return output

@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+
+class TrainDataloaderVariations(ABC):
+
+    @abstractmethod
+    def dataloader(self, train: bool):
+        """placeholder"""
+
+    def train_dataloader(self):
+        return self.dataloader(train=True)

@@ -0,0 +1,30 @@
+from abc import ABC
+from collections import OrderedDict
+
+
+class TrainingStepVariations(ABC):
+    """
+    Houses all variations of training steps
+    """
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        """Lightning calls this inside the training loop"""
+        # forward pass
+        x, y = batch
+        x = x.view(x.size(0), -1)
+
+        y_hat = self(x)
+
+        # calculate loss
+        loss_val = self.loss(y, y_hat)
+
+        # alternate possible outputs to test
+        if self.trainer.batch_idx % 1 == 0:
+            output = OrderedDict({
+                'loss': loss_val,
+                'progress_bar': {'some_val': loss_val * loss_val},
+                'log': {'train_some_val': loss_val * loss_val},
+            })
+            return output
+
+        if self.trainer.batch_idx % 2 == 0:
+            return loss_val

@@ -0,0 +1,22 @@
+from torch.utils.data import DataLoader
+from tests.base.datasets import TrialMNIST
+
+
+class ModelTemplateUtils:
+
+    def dataloader(self, train):
+        dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)
+
+        loader = DataLoader(
+            dataset=dataset,
+            batch_size=self.hparams.batch_size,
+            shuffle=True
+        )
+        return loader
+
+    def get_output_metric(self, output, name):
+        if isinstance(output, dict):
+            val = output[name]
+        else:  # if it is 2level deep -> per dataloader and per batch
+            val = sum(out[name] for out in output) / len(output)
+        return val

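`get_output_metric` papers over the two shapes a step output can take: a flat dict from a single-dataloader run, or a list of per-dataloader dicts that must be averaged. The same logic in isolation (plain data stands in for the step outputs):

```python
def get_output_metric(output, name):
    if isinstance(output, dict):
        return output[name]
    # two levels deep: one dict per dataloader, averaged into a scalar
    return sum(out[name] for out in output) / len(output)


single_loader = {'val_loss': 1.0}
multi_loader = [{'val_loss': 1.0}, {'val_loss': 3.0}]
print(get_output_metric(single_loader, 'val_loss'))  # 1.0
print(get_output_metric(multi_loader, 'val_loss'))   # 2.0
```
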
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+
+class ValDataloaderVariations(ABC):
+
+    @abstractmethod
+    def dataloader(self, train: bool):
+        """placeholder"""
+
+    def val_dataloader(self):
+        return self.dataloader(train=False)

@@ -0,0 +1,42 @@
+from abc import ABC
+
+import torch
+
+
+class ValidationEpochEndVariations(ABC):
+    """
+    Houses all variations of validation_epoch_end steps
+    """
+    def validation_epoch_end(self, outputs):
+        """
+        Called at the end of validation to aggregate outputs
+
+        Args:
+            outputs: list of individual outputs of each validation step
+        """
+        # if returned a scalar from validation_step, outputs is a list of tensor scalars
+        # we return just the average in this case (if we want)
+        # return torch.stack(outputs).mean()
+        val_loss_mean = 0
+        val_acc_mean = 0
+        for output in outputs:
+            val_loss = self.get_output_metric(output, 'val_loss')
+
+            # reduce manually when using dp
+            if self.trainer.use_dp or self.trainer.use_ddp2:
+                val_loss = torch.mean(val_loss)
+            val_loss_mean += val_loss
+
+            # reduce manually when using dp
+            val_acc = self.get_output_metric(output, 'val_acc')
+            if self.trainer.use_dp or self.trainer.use_ddp2:
+                val_acc = torch.mean(val_acc)
+
+            val_acc_mean += val_acc
+
+        val_loss_mean /= len(outputs)
+        val_acc_mean /= len(outputs)
+
+        metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
+        results = {'progress_bar': metrics_dict, 'log': metrics_dict}
+        return results

@@ -0,0 +1,100 @@
+from abc import ABC
+from collections import OrderedDict
+
+import torch
+
+
+class ValidationStepVariations(ABC):
+    """
+    Houses all variations of validation steps
+    """
+    def validation_step(self, batch, batch_idx, *args, **kwargs):
+        """
+        Lightning calls this inside the validation loop
+        :param batch:
+        :return:
+        """
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        loss_val = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        val_acc = torch.tensor(val_acc)
+
+        if self.on_gpu:
+            val_acc = val_acc.cuda(loss_val.device.index)
+
+        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
+        if self.trainer.use_dp:
+            loss_val = loss_val.unsqueeze(0)
+            val_acc = val_acc.unsqueeze(0)
+
+        # alternate possible outputs to test
+        if batch_idx % 1 == 0:
+            output = OrderedDict({
+                'val_loss': loss_val,
+                'val_acc': val_acc,
+            })
+            return output
+        if batch_idx % 2 == 0:
+            return val_acc
+
+        if batch_idx % 3 == 0:
+            output = OrderedDict({
+                'val_loss': loss_val,
+                'val_acc': val_acc,
+                'test_dic': {'val_loss_a': loss_val}
+            })
+            return output
+
+    def validation_step_multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
+        """
+        Lightning calls this inside the validation loop
+        :param batch:
+        :return:
+        """
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        loss_val = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        val_acc = torch.tensor(val_acc)
+
+        if self.on_gpu:
+            val_acc = val_acc.cuda(loss_val.device.index)
+
+        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
+        if self.trainer.use_dp:
+            loss_val = loss_val.unsqueeze(0)
+            val_acc = val_acc.unsqueeze(0)
+
+        # alternate possible outputs to test
+        if batch_idx % 1 == 0:
+            output = OrderedDict({
+                'val_loss': loss_val,
+                'val_acc': val_acc,
+            })
+            return output
+        if batch_idx % 2 == 0:
+            return val_acc
+
+        if batch_idx % 3 == 0:
+            output = OrderedDict({
+                'val_loss': loss_val,
+                'val_acc': val_acc,
+                'test_dic': {'val_loss_a': loss_val}
+            })
+            return output
+        if batch_idx % 5 == 0:
+            output = OrderedDict({
+                f'val_loss_{dataloader_idx}': loss_val,
+                f'val_acc_{dataloader_idx}': val_acc,
+            })
+            return output

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from torch import optim
 from torch.utils.data import DataLoader

-from tests.base.datasets import TestingMNIST
+from tests.base.datasets import TrialMNIST

 try:
     from test_tube import HyperOptArgumentParser

@@ -38,7 +38,7 @@ class DictHparamsModel(LightningModule):
         return torch.optim.Adam(self.parameters(), lr=0.02)

     def train_dataloader(self):
-        return DataLoader(TestingMNIST(train=True, download=True), batch_size=16)
+        return DataLoader(TrialMNIST(train=True, download=True), batch_size=16)


 class TestModelBase(LightningModule):

@@ -137,11 +137,11 @@ class TestModelBase(LightningModule):
         return [optimizer], [scheduler]

     def prepare_data(self):
-        _ = TestingMNIST(root=self.hparams.data_root, train=True, download=True)
+        _ = TrialMNIST(root=self.hparams.data_root, train=True, download=True)

     def _dataloader(self, train):
         # init data generators
-        dataset = TestingMNIST(root=self.hparams.data_root, train=train, download=False)
+        dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)

         # when using multi-node we need to add the datasampler
         batch_size = self.hparams.batch_size

@@ -8,7 +8,7 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TestTubeLogger, TensorBoardLogger
-from tests.base import LightningTestModel
+from tests.base import LightningTestModel, EvalModelTemplate
 from tests.base.datasets import PATH_DATASETS

 # generate a list of random seeds for each test

@@ -37,8 +37,7 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
     import atexit
     monkeypatch.setattr(atexit, 'register', lambda _: None)

-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model, _ = tutils.get_default_model()

     class StoreHistoryLogger(logger_class):
         def __init__(self, *args, **kwargs):

@@ -3,6 +3,7 @@ import pytest
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.base import EvalModelTemplate
 from tests.base import (
     TestModelBase,
     LightValidationDataloader,

@@ -119,36 +120,46 @@ def test_warning_on_wrong_test_settigs(tmpdir):
     """
     tutils.reset_seed()
     hparams = tutils.get_default_hparams()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
+    trainer = Trainer(**trainer_options)

-    class CurrentTestModel(LightTrainDataloader,
-                           LightTestDataloader,
-                           TestModelBase):
-        pass
-
-    # check test_dataloader -> test_step
+    # ----------------
+    # if have test_dataloader should have test_step
+    # ----------------
     with pytest.raises(MisconfigurationException):
-        model = CurrentTestModel(hparams)
+        model = EvalModelTemplate(hparams)
+        model.test_step = None
         trainer.fit(model)

-    class CurrentTestModel(LightTrainDataloader,
-                           LightTestStepMixin,
-                           TestModelBase):
-        pass
-
-    # check test_dataloader + test_step -> test_epoch_end
+    # ----------------
+    # if have test_dataloader and test_step recommend test_epoch_end
+    # ----------------
     with pytest.warns(RuntimeWarning):
-        model = CurrentTestModel(hparams)
-        trainer.fit(model)
+        model = EvalModelTemplate(hparams)
+        model.test_epoch_end = None
+        trainer.test(model)

-    class CurrentTestModel(LightTrainDataloader,
-                           LightTestFitMultipleTestDataloadersMixin,
-                           TestModelBase):
-        pass
-
-    # check test_step -> test_dataloader
+    # ----------------
+    # if have test_step and NO test_dataloader passed in tell user to pass test_dataloader
+    # ----------------
     with pytest.raises(MisconfigurationException):
-        model = CurrentTestModel(hparams)
-        trainer.fit(model)
+        model = EvalModelTemplate(hparams)
+        model.test_dataloader = lambda: None
+        trainer.test(model)
+
+    # ----------------
+    # if have test_dataloader and NO test_step tell user to implement test_step
+    # ----------------
+    with pytest.raises(MisconfigurationException):
+        model = EvalModelTemplate(hparams)
+        model.test_dataloader = lambda: None
+        model.test_step = None
+        trainer.test(model, test_dataloaders=model.dataloader(train=False))
+
+    # ----------------
+    # if have test_dataloader and test_step but no test_epoch_end warn user
+    # ----------------
+    with pytest.warns(RuntimeWarning):
+        model = EvalModelTemplate(hparams)
+        model.test_dataloader = lambda: None
+        model.test_epoch_end = None
+        trainer.test(model, test_dataloaders=model.dataloader(train=False))