Tests: refactor trainer dataloaders (#1690)
* refactor default model * drop redundant seeds * refactor dataloaders tests * fix multiple * fix conf * flake8 * Apply suggestions from code review Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
a6de1b8d75
commit
2a2f303ae9
|
@ -46,8 +46,8 @@ def _has_len(dataloader: DataLoader) -> bool:
|
|||
try:
|
||||
# try getting the length
|
||||
if len(dataloader) == 0:
|
||||
raise ValueError('Dataloader returned 0 length. Please make sure'
|
||||
' that your Dataloader atleast returns 1 batch')
|
||||
raise ValueError('`Dataloader` returned 0 length.'
|
||||
' Please make sure that your Dataloader at least returns 1 batch')
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
@ -186,10 +186,10 @@ class TrainerDataLoadingMixin(ABC):
|
|||
self.val_check_batch = float('inf')
|
||||
else:
|
||||
raise MisconfigurationException(
|
||||
'When using an infinite DataLoader (e.g. with an IterableDataset or when '
|
||||
'DataLoader does not implement `__len__`) for `train_dataloader`, '
|
||||
'`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
|
||||
'checking validation every k training batches.')
|
||||
'When using an infinite DataLoader (e.g. with an IterableDataset'
|
||||
' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
|
||||
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
|
||||
' checking validation every k training batches.')
|
||||
else:
|
||||
self._percent_range_check('val_check_interval')
|
||||
|
||||
|
@ -240,9 +240,9 @@ class TrainerDataLoadingMixin(ABC):
|
|||
num_batches = int(num_batches * percent_check)
|
||||
elif percent_check not in (0.0, 1.0):
|
||||
raise MisconfigurationException(
|
||||
'When using an infinite DataLoader (e.g. with an IterableDataset or when '
|
||||
f'DataLoader does not implement `__len__`) for `{mode}_dataloader`, '
|
||||
f'`Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
|
||||
'When using an infinite DataLoader (e.g. with an IterableDataset'
|
||||
f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,'
|
||||
f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
|
||||
return num_batches, dataloaders
|
||||
|
||||
def reset_val_dataloader(self, model: LightningModule) -> None:
|
||||
|
@ -252,7 +252,7 @@ class TrainerDataLoadingMixin(ABC):
|
|||
model: The current `LightningModule`
|
||||
"""
|
||||
if self.is_overriden('validation_step'):
|
||||
self.num_val_batches, self.val_dataloaders =\
|
||||
self.num_val_batches, self.val_dataloaders = \
|
||||
self._reset_eval_dataloader(model, 'val')
|
||||
|
||||
def reset_test_dataloader(self, model) -> None:
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from tests.base.eval_model_utils import CustomInfDataloader
|
||||
|
||||
|
||||
class TestDataloaderVariations(ABC):
|
||||
|
||||
|
@ -10,5 +12,11 @@ class TestDataloaderVariations(ABC):
|
|||
def test_dataloader(self):
|
||||
return self.dataloader(train=False)
|
||||
|
||||
def test_dataloader__infinite(self):
|
||||
return CustomInfDataloader(self.dataloader(train=False))
|
||||
|
||||
def test_dataloader__empty(self):
|
||||
return None
|
||||
|
||||
def test_dataloader__multiple(self):
|
||||
return [self.dataloader(train=False), self.dataloader(train=False)]
|
||||
|
|
|
@ -37,3 +37,39 @@ class TestEpochEndVariations(ABC):
|
|||
metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
|
||||
result = {'progress_bar': metrics_dict, 'log': metrics_dict}
|
||||
return result
|
||||
|
||||
def test_epoch_end__multiple_dataloaders(self, outputs):
|
||||
"""
|
||||
Called at the end of validation to aggregate outputs
|
||||
:param outputs: list of individual outputs of each validation step
|
||||
:return:
|
||||
"""
|
||||
# if returned a scalar from test_step, outputs is a list of tensor scalars
|
||||
# we return just the average in this case (if we want)
|
||||
# return torch.stack(outputs).mean()
|
||||
test_loss_mean = 0
|
||||
test_acc_mean = 0
|
||||
i = 0
|
||||
for dl_output in outputs:
|
||||
for output in dl_output:
|
||||
test_loss = output['test_loss']
|
||||
|
||||
# reduce manually when using dp
|
||||
if self.trainer.use_dp:
|
||||
test_loss = torch.mean(test_loss)
|
||||
test_loss_mean += test_loss
|
||||
|
||||
# reduce manually when using dp
|
||||
test_acc = output['test_acc']
|
||||
if self.trainer.use_dp:
|
||||
test_acc = torch.mean(test_acc)
|
||||
|
||||
test_acc_mean += test_acc
|
||||
i += 1
|
||||
|
||||
test_loss_mean /= i
|
||||
test_acc_mean /= i
|
||||
|
||||
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
|
||||
result = {'progress_bar': tqdm_dict}
|
||||
return result
|
||||
|
|
|
@ -8,6 +8,7 @@ class TestStepVariations(ABC):
|
|||
"""
|
||||
Houses all variations of test steps
|
||||
"""
|
||||
|
||||
def test_step(self, batch, batch_idx, *args, **kwargs):
|
||||
"""
|
||||
Default, baseline test_step
|
||||
|
@ -87,3 +88,6 @@ class TestStepVariations(ABC):
|
|||
f'test_acc_{dataloader_idx}': test_acc,
|
||||
})
|
||||
return output
|
||||
|
||||
def test_step__empty(self, batch, batch_idx, *args, **kwargs):
|
||||
return {}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from tests.base.eval_model_utils import CustomInfDataloader
|
||||
|
||||
|
||||
class TrainDataloaderVariations(ABC):
|
||||
|
||||
|
@ -9,3 +11,12 @@ class TrainDataloaderVariations(ABC):
|
|||
|
||||
def train_dataloader(self):
|
||||
return self.dataloader(train=True)
|
||||
|
||||
def train_dataloader__infinite(self):
|
||||
return CustomInfDataloader(self.dataloader(train=True))
|
||||
|
||||
def train_dataloader__zero_length(self):
|
||||
dataloader = self.dataloader(train=True)
|
||||
dataloader.dataset.data = dataloader.dataset.data[:0]
|
||||
dataloader.dataset.targets = dataloader.dataset.targets[:0]
|
||||
return dataloader
|
||||
|
|
|
@ -26,3 +26,25 @@ class ModelTemplateUtils:
|
|||
else: # if it is 2level deep -> per dataloader and per batch
|
||||
val = sum(out[name] for out in output) / len(output)
|
||||
return val
|
||||
|
||||
|
||||
class CustomInfDataloader:
|
||||
|
||||
def __init__(self, dataloader):
|
||||
self.dataloader = dataloader
|
||||
self.iter = iter(dataloader)
|
||||
self.count = 0
|
||||
|
||||
def __iter__(self):
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.count >= 50:
|
||||
raise StopIteration
|
||||
self.count = self.count + 1
|
||||
try:
|
||||
return next(self.iter)
|
||||
except StopIteration:
|
||||
self.iter = iter(self.dataloader)
|
||||
return next(self.iter)
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from tests.base.eval_model_utils import CustomInfDataloader
|
||||
|
||||
|
||||
class ValDataloaderVariations(ABC):
|
||||
|
||||
|
@ -13,3 +15,6 @@ class ValDataloaderVariations(ABC):
|
|||
def val_dataloader__multiple(self):
|
||||
return [self.dataloader(train=False),
|
||||
self.dataloader(train=False)]
|
||||
|
||||
def val_dataloader__infinite(self):
|
||||
return CustomInfDataloader(self.dataloader(train=False))
|
||||
|
|
|
@ -8,22 +8,7 @@ from torch.utils.data.dataset import Subset
|
|||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import (
|
||||
TestModelBase,
|
||||
LightningTestModel,
|
||||
LightEmptyTestStep,
|
||||
LightValidationMultipleDataloadersMixin,
|
||||
LightTestMultipleDataloadersMixin,
|
||||
LightTestFitSingleTestDataloadersMixin,
|
||||
LightTestFitMultipleTestDataloadersMixin,
|
||||
LightValStepFitMultipleDataloadersMixin,
|
||||
LightValStepFitSingleDataloaderMixin,
|
||||
LightTrainDataloader,
|
||||
LightInfTrainDataloader,
|
||||
LightInfValDataloader,
|
||||
LightInfTestDataloader,
|
||||
LightZeroLenDataloader
|
||||
)
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dataloader_options", [
|
||||
|
@ -34,14 +19,7 @@ from tests.base import (
|
|||
])
|
||||
def test_dataloader_config_errors(tmpdir, dataloader_options):
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
TestModelBase,
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
|
@ -57,15 +35,9 @@ def test_dataloader_config_errors(tmpdir, dataloader_options):
|
|||
def test_multiple_val_dataloader(tmpdir):
|
||||
"""Verify multiple val_dataloader."""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
LightValidationMultipleDataloadersMixin,
|
||||
TestModelBase,
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.val_dataloader = model.val_dataloader__multiple
|
||||
model.validation_step = model.validation_step__multiple_dataloaders
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
|
@ -91,16 +63,9 @@ def test_multiple_val_dataloader(tmpdir):
|
|||
def test_multiple_test_dataloader(tmpdir):
|
||||
"""Verify multiple test_dataloader."""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
LightTestMultipleDataloadersMixin,
|
||||
LightEmptyTestStep,
|
||||
TestModelBase,
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.test_dataloader = model.test_dataloader__multiple
|
||||
model.test_step = model.test_step__multiple_dataloaders
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
|
@ -127,20 +92,16 @@ def test_multiple_test_dataloader(tmpdir):
|
|||
def test_train_dataloader_passed_to_fit(tmpdir):
|
||||
"""Verify that train dataloader can be passed to fit """
|
||||
|
||||
class CurrentTestModel(LightTrainDataloader, TestModelBase):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
|
||||
# only train passed to fit
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.2
|
||||
)
|
||||
result = trainer.fit(model, train_dataloader=model._dataloader(train=True))
|
||||
fit_options = dict(train_dataloader=model.dataloader(train=True))
|
||||
result = trainer.fit(model, **fit_options)
|
||||
|
||||
assert result == 1
|
||||
|
||||
|
@ -148,26 +109,18 @@ def test_train_dataloader_passed_to_fit(tmpdir):
|
|||
def test_train_val_dataloaders_passed_to_fit(tmpdir):
|
||||
""" Verify that train & val dataloader can be passed to fit """
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
LightValStepFitSingleDataloaderMixin,
|
||||
TestModelBase,
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
|
||||
# train, val passed to fit
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.2
|
||||
)
|
||||
result = trainer.fit(model,
|
||||
train_dataloader=model._dataloader(train=True),
|
||||
val_dataloaders=model._dataloader(train=False))
|
||||
fit_options = dict(train_dataloader=model.dataloader(train=True),
|
||||
val_dataloaders=model.dataloader(train=False))
|
||||
|
||||
result = trainer.fit(model, **fit_options)
|
||||
assert result == 1
|
||||
assert len(trainer.val_dataloaders) == 1, \
|
||||
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
|
||||
|
@ -176,31 +129,21 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
|
|||
def test_all_dataloaders_passed_to_fit(tmpdir):
|
||||
"""Verify train, val & test dataloader(s) can be passed to fit and test method"""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
LightValStepFitSingleDataloaderMixin,
|
||||
LightTestFitSingleTestDataloadersMixin,
|
||||
LightEmptyTestStep,
|
||||
TestModelBase,
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# train, val and test passed to fit
|
||||
model = CurrentTestModel(hparams)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.2
|
||||
)
|
||||
fit_options = dict(train_dataloader=model.dataloader(train=True),
|
||||
val_dataloaders=model.dataloader(train=False))
|
||||
test_options = dict(test_dataloaders=model.dataloader(train=False))
|
||||
|
||||
result = trainer.fit(model,
|
||||
train_dataloader=model._dataloader(train=True),
|
||||
val_dataloaders=model._dataloader(train=False))
|
||||
|
||||
trainer.test(test_dataloaders=model._dataloader(train=False))
|
||||
result = trainer.fit(model, **fit_options)
|
||||
trainer.test(**test_options)
|
||||
|
||||
assert result == 1
|
||||
assert len(trainer.val_dataloaders) == 1, \
|
||||
|
@ -212,32 +155,25 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
|
|||
def test_multiple_dataloaders_passed_to_fit(tmpdir):
|
||||
"""Verify that multiple val & test dataloaders can be passed to fit."""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightningTestModel,
|
||||
LightValStepFitMultipleDataloadersMixin,
|
||||
LightTestFitMultipleTestDataloadersMixin,
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.validation_step = model.validation_step__multiple_dataloaders
|
||||
model.test_step = model.test_step__multiple_dataloaders
|
||||
|
||||
# train, multiple val and multiple test passed to fit
|
||||
model = CurrentTestModel(hparams)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.2
|
||||
)
|
||||
fit_options = dict(train_dataloader=model.dataloader(train=True),
|
||||
val_dataloaders=[model.dataloader(train=False),
|
||||
model.dataloader(train=False)])
|
||||
test_options = dict(test_dataloaders=[model.dataloader(train=False),
|
||||
model.dataloader(train=False)])
|
||||
|
||||
results = trainer.fit(
|
||||
model,
|
||||
train_dataloader=model._dataloader(train=True),
|
||||
val_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)],
|
||||
)
|
||||
assert results
|
||||
|
||||
trainer.test(test_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)])
|
||||
trainer.fit(model, **fit_options)
|
||||
trainer.test(**test_options)
|
||||
|
||||
assert len(trainer.val_dataloaders) == 2, \
|
||||
f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
|
||||
|
@ -248,16 +184,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
|
|||
def test_mixing_of_dataloader_options(tmpdir):
|
||||
"""Verify that dataloaders can be passed to fit"""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
LightValStepFitSingleDataloaderMixin,
|
||||
LightTestFitSingleTestDataloadersMixin,
|
||||
TestModelBase,
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -268,17 +195,14 @@ def test_mixing_of_dataloader_options(tmpdir):
|
|||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
fit_options = dict(val_dataloaders=model._dataloader(train=False))
|
||||
results = trainer.fit(model, **fit_options)
|
||||
results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
|
||||
assert results
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
fit_options = dict(val_dataloaders=model._dataloader(train=False))
|
||||
test_options = dict(test_dataloaders=model._dataloader(train=False))
|
||||
|
||||
_ = trainer.fit(model, **fit_options)
|
||||
trainer.test(**test_options)
|
||||
results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
|
||||
assert results
|
||||
trainer.test(test_dataloaders=model.dataloader(train=False))
|
||||
|
||||
assert len(trainer.val_dataloaders) == 1, \
|
||||
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
|
||||
|
@ -286,72 +210,68 @@ def test_mixing_of_dataloader_options(tmpdir):
|
|||
f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
|
||||
|
||||
|
||||
def test_inf_train_dataloader(tmpdir):
|
||||
def test_train_inf_dataloader_error(tmpdir):
|
||||
"""Test inf train data loader (e.g. IterableDataset)"""
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.train_dataloader = model.train_dataloader__infinite
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_val_inf_dataloader_error(tmpdir):
|
||||
"""Test inf train data loader (e.g. IterableDataset)"""
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.val_dataloader = model.val_dataloader__infinite
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_percent_check=0.5)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_test_inf_dataloader_error(tmpdir):
|
||||
"""Test inf train data loader (e.g. IterableDataset)"""
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.test_dataloader = model.test_dataloader__infinite
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, test_percent_check=0.5)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
|
||||
trainer.test(model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('check_interval', [50, 1.0])
|
||||
def test_inf_train_dataloader(tmpdir, check_interval):
|
||||
"""Test inf train data loader (e.g. IterableDataset)"""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightInfTrainDataloader,
|
||||
LightningTestModel
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
|
||||
# fit model
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_check_interval=0.5
|
||||
)
|
||||
trainer.fit(model)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.train_dataloader = model.train_dataloader__infinite
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_check_interval=50
|
||||
train_check_interval=check_interval,
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
|
||||
# verify training completed
|
||||
assert result == 1
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
|
||||
# verify training completed
|
||||
assert result == 1
|
||||
|
||||
|
||||
def test_inf_val_dataloader(tmpdir):
|
||||
@pytest.mark.parametrize('check_interval', [1.0])
|
||||
def test_inf_val_dataloader(tmpdir, check_interval):
|
||||
"""Test inf val data loader (e.g. IterableDataset)"""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightInfValDataloader,
|
||||
LightningTestModel
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
|
||||
# fit model
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_percent_check=0.5
|
||||
)
|
||||
trainer.fit(model)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.val_dataloader = model.val_dataloader__infinite
|
||||
|
||||
# logger file to get meta
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1
|
||||
max_epochs=1,
|
||||
val_check_interval=check_interval,
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
|
||||
|
@ -359,35 +279,20 @@ def test_inf_val_dataloader(tmpdir):
|
|||
assert result == 1
|
||||
|
||||
|
||||
def test_inf_test_dataloader(tmpdir):
|
||||
@pytest.mark.parametrize('check_interval', [50, 1.0])
|
||||
def test_inf_test_dataloader(tmpdir, check_interval):
|
||||
"""Test inf test data loader (e.g. IterableDataset)"""
|
||||
|
||||
class CurrentTestModel(
|
||||
LightInfTestDataloader,
|
||||
LightningTestModel,
|
||||
LightTestFitSingleTestDataloadersMixin
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
|
||||
# fit model
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
test_percent_check=0.5
|
||||
)
|
||||
trainer.test(model)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.test_dataloader = model.test_dataloader__infinite
|
||||
|
||||
# logger file to get meta
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1
|
||||
max_epochs=1,
|
||||
test_check_interval=check_interval,
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
trainer.test(model)
|
||||
|
||||
# verify training completed
|
||||
assert result == 1
|
||||
|
@ -396,14 +301,8 @@ def test_inf_test_dataloader(tmpdir):
|
|||
def test_error_on_zero_len_dataloader(tmpdir):
|
||||
""" Test that error is raised if a zero-length dataloader is defined """
|
||||
|
||||
class CurrentTestModel(
|
||||
LightZeroLenDataloader,
|
||||
LightningTestModel
|
||||
):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
model.train_dataloader = model.train_dataloader__zero_length
|
||||
|
||||
# fit model
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -419,29 +318,22 @@ def test_error_on_zero_len_dataloader(tmpdir):
|
|||
def test_warning_with_few_workers(tmpdir):
|
||||
""" Test that error is raised if dataloader with only a few workers is used """
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
LightValStepFitSingleDataloaderMixin,
|
||||
LightTestFitSingleTestDataloadersMixin,
|
||||
LightEmptyTestStep,
|
||||
TestModelBase,
|
||||
):
|
||||
pass
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
|
||||
fit_options = dict(train_dataloader=model._dataloader(train=True),
|
||||
val_dataloaders=model._dataloader(train=False))
|
||||
test_options = dict(test_dataloaders=model._dataloader(train=False))
|
||||
|
||||
trainer = Trainer(
|
||||
# logger file to get meta
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.2
|
||||
)
|
||||
|
||||
fit_options = dict(train_dataloader=model.dataloader(train=True),
|
||||
val_dataloaders=model.dataloader(train=False))
|
||||
test_options = dict(test_dataloaders=model.dataloader(train=False))
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
|
||||
# fit model
|
||||
with pytest.warns(UserWarning, match='train'):
|
||||
trainer.fit(model, **fit_options)
|
||||
|
@ -491,10 +383,7 @@ def test_batch_size_smaller_than_num_gpus():
|
|||
num_gpus = 3
|
||||
batch_size = 3
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
TestModelBase,
|
||||
):
|
||||
class CurrentTestModel(EvalModelTemplate):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue