Tests: refactor trainer dataloaders (#1690)

* refactor default model

* drop redundant seeds

* refactor dataloaders tests

* fix multiple

* fix conf

* flake8

* Apply suggestions from code review

Co-authored-by: William Falcon <waf2107@columbia.edu>

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Jirka Borovec 2020-05-05 18:31:15 +02:00 committed by GitHub
parent a6de1b8d75
commit 2a2f303ae9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 195 additions and 220 deletions

View File

@ -46,8 +46,8 @@ def _has_len(dataloader: DataLoader) -> bool:
try:
# try getting the length
if len(dataloader) == 0:
raise ValueError('Dataloader returned 0 length. Please make sure'
' that your Dataloader atleast returns 1 batch')
raise ValueError('`Dataloader` returned 0 length.'
' Please make sure that your Dataloader at least returns 1 batch')
return True
except TypeError:
return False
@ -186,10 +186,10 @@ class TrainerDataLoadingMixin(ABC):
self.val_check_batch = float('inf')
else:
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset or when '
'DataLoader does not implement `__len__`) for `train_dataloader`, '
'`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
'checking validation every k training batches.')
'When using an infinite DataLoader (e.g. with an IterableDataset'
' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
' checking validation every k training batches.')
else:
self._percent_range_check('val_check_interval')
@ -240,9 +240,9 @@ class TrainerDataLoadingMixin(ABC):
num_batches = int(num_batches * percent_check)
elif percent_check not in (0.0, 1.0):
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset or when '
f'DataLoader does not implement `__len__`) for `{mode}_dataloader`, '
f'`Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
'When using an infinite DataLoader (e.g. with an IterableDataset'
f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,'
f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
return num_batches, dataloaders
def reset_val_dataloader(self, model: LightningModule) -> None:
@ -252,7 +252,7 @@ class TrainerDataLoadingMixin(ABC):
model: The current `LightningModule`
"""
if self.is_overriden('validation_step'):
self.num_val_batches, self.val_dataloaders =\
self.num_val_batches, self.val_dataloaders = \
self._reset_eval_dataloader(model, 'val')
def reset_test_dataloader(self, model) -> None:

View File

@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from tests.base.eval_model_utils import CustomInfDataloader
class TestDataloaderVariations(ABC):
@ -10,5 +12,11 @@ class TestDataloaderVariations(ABC):
def test_dataloader(self):
return self.dataloader(train=False)
def test_dataloader__infinite(self):
    """Wrap the eval dataloader so it exposes no usable ``__len__``."""
    base_loader = self.dataloader(train=False)
    return CustomInfDataloader(base_loader)
def test_dataloader__empty(self):
return None
def test_dataloader__multiple(self):
return [self.dataloader(train=False), self.dataloader(train=False)]

View File

@ -37,3 +37,39 @@ class TestEpochEndVariations(ABC):
metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
result = {'progress_bar': metrics_dict, 'log': metrics_dict}
return result
def test_epoch_end__multiple_dataloaders(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
# if returned a scalar from test_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
# return torch.stack(outputs).mean()
test_loss_mean = 0
test_acc_mean = 0
i = 0
for dl_output in outputs:
for output in dl_output:
test_loss = output['test_loss']
# reduce manually when using dp
if self.trainer.use_dp:
test_loss = torch.mean(test_loss)
test_loss_mean += test_loss
# reduce manually when using dp
test_acc = output['test_acc']
if self.trainer.use_dp:
test_acc = torch.mean(test_acc)
test_acc_mean += test_acc
i += 1
test_loss_mean /= i
test_acc_mean /= i
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
result = {'progress_bar': tqdm_dict}
return result

View File

@ -8,6 +8,7 @@ class TestStepVariations(ABC):
"""
Houses all variations of test steps
"""
def test_step(self, batch, batch_idx, *args, **kwargs):
"""
Default, baseline test_step
@ -87,3 +88,6 @@ class TestStepVariations(ABC):
f'test_acc_{dataloader_idx}': test_acc,
})
return output
def test_step__empty(self, batch, batch_idx, *args, **kwargs):
return {}

View File

@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from tests.base.eval_model_utils import CustomInfDataloader
class TrainDataloaderVariations(ABC):
@ -9,3 +11,12 @@ class TrainDataloaderVariations(ABC):
def train_dataloader(self):
return self.dataloader(train=True)
def train_dataloader__infinite(self):
    """Wrap the training dataloader so it exposes no usable ``__len__``."""
    base_loader = self.dataloader(train=True)
    return CustomInfDataloader(base_loader)
def train_dataloader__zero_length(self):
    """Return a training dataloader whose dataset was emptied in place."""
    loader = self.dataloader(train=True)
    # slice both tensors down to zero elements so the loader has length 0
    loader.dataset.data = loader.dataset.data[:0]
    loader.dataset.targets = loader.dataset.targets[:0]
    return loader

View File

@ -26,3 +26,25 @@ class ModelTemplateUtils:
else: # if it is 2level deep -> per dataloader and per batch
val = sum(out[name] for out in output) / len(output)
return val
class CustomInfDataloader:
    """Wrap a dataloader so it behaves like an infinite one (no ``__len__``).

    Whenever the wrapped loader is exhausted it is restarted, and a hard
    cap of 50 batches per iteration keeps tests terminating.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iter = iter(dataloader)
        self.count = 0

    def __iter__(self):
        # reset the batch budget on every fresh iteration
        self.count = 0
        return self

    def __next__(self):
        # hard cap so the "infinite" stream still lets tests finish
        if self.count >= 50:
            raise StopIteration
        self.count += 1
        try:
            return next(self.iter)
        except StopIteration:
            # wrapped loader exhausted: start it over from the beginning
            self.iter = iter(self.dataloader)
            return next(self.iter)

View File

@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from tests.base.eval_model_utils import CustomInfDataloader
class ValDataloaderVariations(ABC):
@ -13,3 +15,6 @@ class ValDataloaderVariations(ABC):
def val_dataloader__multiple(self):
    """Provide two independent eval dataloaders for validation."""
    return [self.dataloader(train=False) for _ in range(2)]
def val_dataloader__infinite(self):
    """Wrap the validation dataloader so it exposes no usable ``__len__``."""
    base_loader = self.dataloader(train=False)
    return CustomInfDataloader(base_loader)

View File

@ -8,22 +8,7 @@ from torch.utils.data.dataset import Subset
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
TestModelBase,
LightningTestModel,
LightEmptyTestStep,
LightValidationMultipleDataloadersMixin,
LightTestMultipleDataloadersMixin,
LightTestFitSingleTestDataloadersMixin,
LightTestFitMultipleTestDataloadersMixin,
LightValStepFitMultipleDataloadersMixin,
LightValStepFitSingleDataloaderMixin,
LightTrainDataloader,
LightInfTrainDataloader,
LightInfValDataloader,
LightInfTestDataloader,
LightZeroLenDataloader
)
from tests.base import EvalModelTemplate
@pytest.mark.parametrize("dataloader_options", [
@ -34,14 +19,7 @@ from tests.base import (
])
def test_dataloader_config_errors(tmpdir, dataloader_options):
class CurrentTestModel(
LightTrainDataloader,
TestModelBase,
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# fit model
trainer = Trainer(
@ -57,15 +35,9 @@ def test_dataloader_config_errors(tmpdir, dataloader_options):
def test_multiple_val_dataloader(tmpdir):
"""Verify multiple val_dataloader."""
class CurrentTestModel(
LightTrainDataloader,
LightValidationMultipleDataloadersMixin,
TestModelBase,
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
model.val_dataloader = model.val_dataloader__multiple
model.validation_step = model.validation_step__multiple_dataloaders
# fit model
trainer = Trainer(
@ -91,16 +63,9 @@ def test_multiple_val_dataloader(tmpdir):
def test_multiple_test_dataloader(tmpdir):
"""Verify multiple test_dataloader."""
class CurrentTestModel(
LightTrainDataloader,
LightTestMultipleDataloadersMixin,
LightEmptyTestStep,
TestModelBase,
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
model.test_dataloader = model.test_dataloader__multiple
model.test_step = model.test_step__multiple_dataloaders
# fit model
trainer = Trainer(
@ -127,20 +92,16 @@ def test_multiple_test_dataloader(tmpdir):
def test_train_dataloader_passed_to_fit(tmpdir):
"""Verify that train dataloader can be passed to fit """
class CurrentTestModel(LightTrainDataloader, TestModelBase):
pass
hparams = tutils.get_default_hparams()
# only train passed to fit
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)
result = trainer.fit(model, train_dataloader=model._dataloader(train=True))
fit_options = dict(train_dataloader=model.dataloader(train=True))
result = trainer.fit(model, **fit_options)
assert result == 1
@ -148,26 +109,18 @@ def test_train_dataloader_passed_to_fit(tmpdir):
def test_train_val_dataloaders_passed_to_fit(tmpdir):
""" Verify that train & val dataloader can be passed to fit """
class CurrentTestModel(
LightTrainDataloader,
LightValStepFitSingleDataloaderMixin,
TestModelBase,
):
pass
hparams = tutils.get_default_hparams()
# train, val passed to fit
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)
result = trainer.fit(model,
train_dataloader=model._dataloader(train=True),
val_dataloaders=model._dataloader(train=False))
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=model.dataloader(train=False))
result = trainer.fit(model, **fit_options)
assert result == 1
assert len(trainer.val_dataloaders) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@ -176,31 +129,21 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
def test_all_dataloaders_passed_to_fit(tmpdir):
"""Verify train, val & test dataloader(s) can be passed to fit and test method"""
class CurrentTestModel(
LightTrainDataloader,
LightValStepFitSingleDataloaderMixin,
LightTestFitSingleTestDataloadersMixin,
LightEmptyTestStep,
TestModelBase,
):
pass
hparams = tutils.get_default_hparams()
model = EvalModelTemplate(tutils.get_default_hparams())
# train, val and test passed to fit
model = CurrentTestModel(hparams)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=model.dataloader(train=False))
test_options = dict(test_dataloaders=model.dataloader(train=False))
result = trainer.fit(model,
train_dataloader=model._dataloader(train=True),
val_dataloaders=model._dataloader(train=False))
trainer.test(test_dataloaders=model._dataloader(train=False))
result = trainer.fit(model, **fit_options)
trainer.test(**test_options)
assert result == 1
assert len(trainer.val_dataloaders) == 1, \
@ -212,32 +155,25 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
def test_multiple_dataloaders_passed_to_fit(tmpdir):
"""Verify that multiple val & test dataloaders can be passed to fit."""
class CurrentTestModel(
LightningTestModel,
LightValStepFitMultipleDataloadersMixin,
LightTestFitMultipleTestDataloadersMixin,
):
pass
hparams = tutils.get_default_hparams()
model = EvalModelTemplate(tutils.get_default_hparams())
model.validation_step = model.validation_step__multiple_dataloaders
model.test_step = model.test_step__multiple_dataloaders
# train, multiple val and multiple test passed to fit
model = CurrentTestModel(hparams)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)])
test_options = dict(test_dataloaders=[model.dataloader(train=False),
model.dataloader(train=False)])
results = trainer.fit(
model,
train_dataloader=model._dataloader(train=True),
val_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)],
)
assert results
trainer.test(test_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)])
trainer.fit(model, **fit_options)
trainer.test(**test_options)
assert len(trainer.val_dataloaders) == 2, \
f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@ -248,16 +184,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
def test_mixing_of_dataloader_options(tmpdir):
"""Verify that dataloaders can be passed to fit"""
class CurrentTestModel(
LightTrainDataloader,
LightValStepFitSingleDataloaderMixin,
LightTestFitSingleTestDataloadersMixin,
TestModelBase,
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
trainer_options = dict(
default_root_dir=tmpdir,
@ -268,17 +195,14 @@ def test_mixing_of_dataloader_options(tmpdir):
# fit model
trainer = Trainer(**trainer_options)
fit_options = dict(val_dataloaders=model._dataloader(train=False))
results = trainer.fit(model, **fit_options)
results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
assert results
# fit model
trainer = Trainer(**trainer_options)
fit_options = dict(val_dataloaders=model._dataloader(train=False))
test_options = dict(test_dataloaders=model._dataloader(train=False))
_ = trainer.fit(model, **fit_options)
trainer.test(**test_options)
results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
assert results
trainer.test(test_dataloaders=model.dataloader(train=False))
assert len(trainer.val_dataloaders) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@ -286,72 +210,68 @@ def test_mixing_of_dataloader_options(tmpdir):
f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_inf_train_dataloader(tmpdir):
def test_train_inf_dataloader_error(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)
    model.train_dataloader = model.train_dataloader__infinite

    # a fractional val_check_interval is invalid with an infinite train loader
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_check_interval=0.5,
    )

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model)
def test_val_inf_dataloader_error(tmpdir):
    """Test inf val data loader (e.g. IterableDataset)"""
    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)
    model.val_dataloader = model.val_dataloader__infinite

    # a fractional val_percent_check is invalid with an infinite val loader
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.5,
    )

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model)
def test_test_inf_dataloader_error(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)
    model.test_dataloader = model.test_dataloader__infinite

    # a fractional test_percent_check is invalid with an infinite test loader
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        test_percent_check=0.5,
    )

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.test(model)
@pytest.mark.parametrize('check_interval', [50, 1.0])
def test_inf_train_dataloader(tmpdir, check_interval):
"""Test inf train data loader (e.g. IterableDataset)"""
class CurrentTestModel(
LightInfTrainDataloader,
LightningTestModel
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=0.5
)
trainer.fit(model)
model = EvalModelTemplate(tutils.get_default_hparams())
model.train_dataloader = model.train_dataloader__infinite
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=50
train_check_interval=check_interval,
)
result = trainer.fit(model)
# verify training completed
assert result == 1
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1
)
result = trainer.fit(model)
# verify training completed
assert result == 1
def test_inf_val_dataloader(tmpdir):
@pytest.mark.parametrize('check_interval', [1.0])
def test_inf_val_dataloader(tmpdir, check_interval):
"""Test inf val data loader (e.g. IterableDataset)"""
class CurrentTestModel(
LightInfValDataloader,
LightningTestModel
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.5
)
trainer.fit(model)
model = EvalModelTemplate(tutils.get_default_hparams())
model.val_dataloader = model.val_dataloader__infinite
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1
max_epochs=1,
val_check_interval=check_interval,
)
result = trainer.fit(model)
@ -359,35 +279,20 @@ def test_inf_val_dataloader(tmpdir):
assert result == 1
def test_inf_test_dataloader(tmpdir):
@pytest.mark.parametrize('check_interval', [50, 1.0])
def test_inf_test_dataloader(tmpdir, check_interval):
"""Test inf test data loader (e.g. IterableDataset)"""
class CurrentTestModel(
LightInfTestDataloader,
LightningTestModel,
LightTestFitSingleTestDataloadersMixin
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
test_percent_check=0.5
)
trainer.test(model)
model = EvalModelTemplate(tutils.get_default_hparams())
model.test_dataloader = model.test_dataloader__infinite
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1
max_epochs=1,
test_check_interval=check_interval,
)
result = trainer.fit(model)
trainer.test(model)
# verify training completed
assert result == 1
@ -396,14 +301,8 @@ def test_inf_test_dataloader(tmpdir):
def test_error_on_zero_len_dataloader(tmpdir):
""" Test that error is raised if a zero-length dataloader is defined """
class CurrentTestModel(
LightZeroLenDataloader,
LightningTestModel
):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
model.train_dataloader = model.train_dataloader__zero_length
# fit model
with pytest.raises(ValueError):
@ -419,29 +318,22 @@ def test_error_on_zero_len_dataloader(tmpdir):
def test_warning_with_few_workers(tmpdir):
""" Test that error is raised if dataloader with only a few workers is used """
class CurrentTestModel(
LightTrainDataloader,
LightValStepFitSingleDataloaderMixin,
LightTestFitSingleTestDataloadersMixin,
LightEmptyTestStep,
TestModelBase,
):
pass
model = EvalModelTemplate(tutils.get_default_hparams())
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloaders=model._dataloader(train=False))
test_options = dict(test_dataloaders=model._dataloader(train=False))
trainer = Trainer(
# logger file to get meta
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=model.dataloader(train=False))
test_options = dict(test_dataloaders=model.dataloader(train=False))
trainer = Trainer(**trainer_options)
# fit model
with pytest.warns(UserWarning, match='train'):
trainer.fit(model, **fit_options)
@ -491,10 +383,7 @@ def test_batch_size_smaller_than_num_gpus():
num_gpus = 3
batch_size = 3
class CurrentTestModel(
LightTrainDataloader,
TestModelBase,
):
class CurrentTestModel(EvalModelTemplate):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)