From 164751c918da973817543295dce9a3261942a591 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 17:09:14 -0400 Subject: [PATCH] added test for no dist sampler --- pytorch_lightning/models/trainer.py | 4 ++-- tests/test_models.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index a0c1c2aaac..4f4d012de2 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -369,7 +369,7 @@ class Trainer(TrainerIO): return val_results - def __get_dataloaders(self, model): + def get_dataloaders(self, model): """ Dataloaders are provided by the model :param model: @@ -591,7 +591,7 @@ class Trainer(TrainerIO): ref_model.on_gpu = self.on_gpu # transfer data loaders from model - self.__get_dataloaders(ref_model) + self.get_dataloaders(ref_model) # init training constants self.__layout_bookeeping() diff --git a/tests/test_models.py b/tests/test_models.py index bd6606e4bc..21bee021f1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -37,7 +37,7 @@ def test_ddp_sampler_error(): hparams = get_hparams() model = LightningTestModel(hparams, force_remove_distributed_sampler=True) - trainer_options = dict( + trainer = Trainer( progress_bar=False, max_nb_epochs=1, gpus=[0, 1], @@ -46,7 +46,7 @@ def test_ddp_sampler_error(): ) with pytest.raises(MisconfigurationException): - run_gpu_model_test(trainer_options, model, hparams) + trainer.get_dataloaders(model) def test_cpu_model():