added test for no dist sampler

William Falcon 2019-07-24 17:09:14 -04:00
parent 096132b389
commit 164751c918
2 changed files with 4 additions and 4 deletions

@@ -369,7 +369,7 @@ class Trainer(TrainerIO):
         return val_results
 
-    def __get_dataloaders(self, model):
+    def get_dataloaders(self, model):
         """
         Dataloaders are provided by the model
         :param model:
@@ -591,7 +591,7 @@ class Trainer(TrainerIO):
         ref_model.on_gpu = self.on_gpu
 
         # transfer data loaders from model
-        self.__get_dataloaders(ref_model)
+        self.get_dataloaders(ref_model)
 
         # init training constants
         self.__layout_bookeeping()
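
The only substantive change in the trainer is the rename from __get_dataloaders to get_dataloaders: the double leading underscore triggers Python's name mangling, which rewrites the attribute to _Trainer__get_dataloaders and makes the method awkward to call from a test. A minimal sketch of the mechanism, using a stand-in class rather than the real Trainer:

    class Widget:
        def __hidden(self):
            # double leading underscore: mangled to _Widget__hidden
            # everywhere outside this class body
            return "mangled"

        def visible(self):
            # plain name: directly callable from tests
            return "public"

    w = Widget()
    print(w.visible())           # "public"
    print(w._Widget__hidden())   # "mangled" -- only reachable via the mangled name
    # w.__hidden() would raise AttributeError: name mangling hides the plain name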

@@ -37,7 +37,7 @@ def test_ddp_sampler_error():
     hparams = get_hparams()
     model = LightningTestModel(hparams, force_remove_distributed_sampler=True)
 
-    trainer_options = dict(
+    trainer = Trainer(
         progress_bar=False,
         max_nb_epochs=1,
         gpus=[0, 1],
@@ -46,7 +46,7 @@ def test_ddp_sampler_error():
     )
 
     with pytest.raises(MisconfigurationException):
-        run_gpu_model_test(trainer_options, model, hparams)
+        trainer.get_dataloaders(model)
 
 
 def test_cpu_model():
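
Reading the two test hunks together, the test now builds a Trainer directly and asserts that get_dataloaders raises on the misconfigured model, instead of driving a full GPU run through run_gpu_model_test. A reconstruction from the diff context alone (the Trainer options elided between the two hunks are left out; get_hparams, LightningTestModel, Trainer, and MisconfigurationException come from the surrounding repo):

    import pytest

    def test_ddp_sampler_error():
        # model is built with the distributed sampler force-removed
        hparams = get_hparams()
        model = LightningTestModel(hparams, force_remove_distributed_sampler=True)

        trainer = Trainer(
            progress_bar=False,
            max_nb_epochs=1,
            gpus=[0, 1],
            # ... options elided by the diff context ...
        )

        # fetching dataloaders should fail fast with a configuration error
        with pytest.raises(MisconfigurationException):
            trainer.get_dataloaders(model)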