2020-02-27 01:31:40 +00:00
|
|
|
import pytest
|
2020-04-03 21:55:08 +00:00
|
|
|
import torch
|
2020-02-27 01:31:40 +00:00
|
|
|
|
2020-03-25 11:46:27 +00:00
|
|
|
import tests.base.utils as tutils
|
2020-02-27 01:31:40 +00:00
|
|
|
from pytorch_lightning import Trainer
|
2020-03-31 12:57:48 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2020-03-25 11:46:27 +00:00
|
|
|
from tests.base import (
|
2020-02-27 01:31:40 +00:00
|
|
|
TestModelBase,
|
|
|
|
LightningTestModel,
|
|
|
|
LightEmptyTestStep,
|
|
|
|
LightValidationMultipleDataloadersMixin,
|
|
|
|
LightTestMultipleDataloadersMixin,
|
|
|
|
LightTestFitSingleTestDataloadersMixin,
|
|
|
|
LightTestFitMultipleTestDataloadersMixin,
|
|
|
|
LightValStepFitMultipleDataloadersMixin,
|
|
|
|
LightValStepFitSingleDataloaderMixin,
|
|
|
|
LightTrainDataloader,
|
2020-04-05 15:07:16 +00:00
|
|
|
LightValidationDataloader,
|
2020-03-12 16:46:02 +00:00
|
|
|
LightInfTrainDataloader,
|
|
|
|
LightInfValDataloader,
|
2020-03-30 22:43:53 +00:00
|
|
|
LightInfTestDataloader,
|
|
|
|
LightZeroLenDataloader
|
2020-02-27 01:31:40 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
|
2020-03-12 16:46:02 +00:00
|
|
|
def test_dataloader_config_errors(tmpdir):
    """Check that invalid percent-check / val_check_interval settings make fit raise ValueError."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # each entry is one invalid configuration Trainer.fit must reject
    invalid_options = [
        dict(train_percent_check=-0.1),   # percent check < 0
        dict(train_percent_check=1.1),    # percent check > 1
        dict(val_check_interval=10000),   # int val_check_interval > num batches
        dict(val_check_interval=1.1),     # float val_check_interval > 1
    ]

    for extra in invalid_options:
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            **extra,
        )
        with pytest.raises(ValueError):
            trainer.fit(model)


def test_multiple_val_dataloader(tmpdir):
    """Verify that a model exposing two val dataloaders has both wired up by fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValidationMultipleDataloadersMixin,
        TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # fit model
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )
    result = trainer.fit(model)

    # verify training completed
    assert result == 1

    # verify there are 2 val loaders
    assert len(trainer.val_dataloaders) == 2, \
        'Multiple val_dataloaders not initiated properly'

    # make sure predictions are good for each val set
    for loader in trainer.val_dataloaders:
        tutils.run_prediction(loader, trainer.model)


def test_multiple_test_dataloader(tmpdir):
    """Verify multiple test_dataloader.

    Fits briefly, runs the test loop once, and checks that both test
    dataloaders are registered and yield usable predictions.
    """
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightTestMultipleDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model, then run the test method once
    # (the original called trainer.test() twice, repeating identical work)
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
    trainer.test()

    # verify there are 2 test loaders
    assert len(trainer.test_dataloaders) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # make sure predictions are good for each test set
    for dataloader in trainer.test_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)


def test_train_dataloaders_passed_to_fit(tmpdir):
    """Verify that train dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # only train passed to fit
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    result = trainer.fit(model, train_dataloader=model._dataloader(train=True))

    assert result == 1


def test_train_val_dataloaders_passed_to_fit(tmpdir):
    """ Verify that train & val dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # train, val passed to fit
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    result = trainer.fit(
        model,
        train_dataloader=model._dataloader(train=True),
        val_dataloaders=model._dataloader(train=False),
    )

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'


def test_all_dataloaders_passed_to_fit(tmpdir):
    """Verify train, val & test dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # train, val and test passed to fit
    dataloaders = dict(
        train_dataloader=model._dataloader(train=True),
        val_dataloaders=model._dataloader(train=False),
        test_dataloaders=model._dataloader(train=False),
    )
    result = trainer.fit(model, **dataloaders)
    trainer.test()

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Verify that multiple val & test dataloaders can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModel,
        LightValStepFitMultipleDataloadersMixin,
        LightTestFitMultipleTestDataloadersMixin,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # train, multiple val and multiple test passed to fit
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=[model._dataloader(train=False),
                                        model._dataloader(train=False)],
                       test_dataloaders=[model._dataloader(train=False),
                                         model._dataloader(train=False)])
    results = trainer.fit(model, **fit_options)
    # the fit result used to be assigned but never checked; verify success
    # like the sibling tests do
    assert results == 1
    trainer.test()

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit"""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model: only a val loader passed in
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model, val_dataloaders=model._dataloader(train=False))
    # fit results used to be assigned/discarded without checking; verify both
    assert results == 1

    # fit model: val and test loaders passed in
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model,
                          val_dataloaders=model._dataloader(train=False),
                          test_dataloaders=model._dataloader(train=False))
    assert results == 1
    trainer.test()

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_inf_train_dataloader(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    tutils.reset_seed()

    class CurrentTestModel(
        LightInfTrainDataloader,
        LightningTestModel
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # a fractional val_check_interval is rejected for an infinite train loader
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_check_interval=0.5,
        )
        trainer.fit(model)

    # an integer val_check_interval works with an infinite loader
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=50,
    )
    # verify training completed
    assert trainer.fit(model) == 1

    # so does the default configuration
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
    )
    # verify training completed
    assert trainer.fit(model) == 1


def test_inf_val_dataloader(tmpdir):
    """Test inf val data loader (e.g. IterableDataset)"""
    tutils.reset_seed()

    class CurrentTestModel(
        LightInfValDataloader,
        LightningTestModel
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # a fractional val_percent_check is rejected for an infinite val loader
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_percent_check=0.5,
        )
        trainer.fit(model)

    # the default configuration trains fine
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
    )
    # verify training completed
    assert trainer.fit(model) == 1


def test_inf_test_dataloader(tmpdir):
    """Test inf test data loader (e.g. IterableDataset)"""
    tutils.reset_seed()

    class CurrentTestModel(
        LightInfTestDataloader,
        LightningTestModel,
        LightTestFitSingleTestDataloadersMixin
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # a fractional test_percent_check is rejected for an infinite test loader
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            test_percent_check=0.5,
        )
        trainer.test(model)

    # the default configuration fits and tests fine
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
    )
    result = trainer.fit(model)
    trainer.test(model)

    # verify training completed
    assert result == 1


def test_error_on_zero_len_dataloader(tmpdir):
    """ Test that error is raised if a zero-length dataloader is defined """
    tutils.reset_seed()

    class CurrentTestModel(
        LightZeroLenDataloader,
        LightningTestModel
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # fitting with a zero-length dataloader must raise
    # NOTE(review): test_percent_check looks copy-pasted from the test-loader
    # tests — it should be irrelevant to fit(); confirm before cleaning up
    with pytest.raises(ValueError):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            test_percent_check=0.5
        )
        trainer.fit(model)


def test_warning_with_few_workers(tmpdir):
    """ Test that error is raised if dataloader with only a few workers is used """
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    loaders = dict(
        train_dataloader=model._dataloader(train=True),
        val_dataloaders=model._dataloader(train=False),
        test_dataloaders=model._dataloader(train=False),
    )

    # each stage should emit a UserWarning mentioning the dataloader kind
    with pytest.warns(UserWarning, match='train'):
        trainer.fit(model, **loaders)

    with pytest.warns(UserWarning, match='val'):
        trainer.fit(model, **loaders)

    with pytest.warns(UserWarning, match='test'):
        trainer.test()


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_reinit_for_subclass():
    """auto_add_sampler must keep DataLoader subclasses (and their extra
    attributes) intact, and pass non-DataLoader objects through untouched."""

    class CustomDataLoader(torch.utils.data.DataLoader):
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=None,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None, dummy_kwarg=None):
            super().__init__(dataset, batch_size, shuffle, sampler,
                             batch_sampler, num_workers, collate_fn,
                             pin_memory, drop_last, timeout, worker_init_fn)
            # extra kwarg unknown to the base class — must survive re-init
            self.dummy_kwarg = dummy_kwarg

    trainer = Trainer(gpus=[0, 1],
                      num_nodes=1,
                      distributed_backend='ddp')

    class CustomDummyObj:
        sampler = None

    # a duck-typed non-DataLoader object is returned as-is
    result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

    # a real DataLoader subclass keeps its type and its custom attribute
    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True)
    assert isinstance(result, torch.utils.data.DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert hasattr(result, 'dummy_kwarg')