added test for no dist sampler
This commit is contained in:
parent
096132b389
commit
164751c918
|
@ -369,7 +369,7 @@ class Trainer(TrainerIO):
|
|||
|
||||
return val_results
|
||||
|
||||
def __get_dataloaders(self, model):
|
||||
def get_dataloaders(self, model):
|
||||
"""
|
||||
Dataloaders are provided by the model
|
||||
:param model:
|
||||
|
@ -591,7 +591,7 @@ class Trainer(TrainerIO):
|
|||
ref_model.on_gpu = self.on_gpu
|
||||
|
||||
# transfer data loaders from model
|
||||
self.__get_dataloaders(ref_model)
|
||||
self.get_dataloaders(ref_model)
|
||||
|
||||
# init training constants
|
||||
self.__layout_bookeeping()
|
||||
|
|
|
@ -37,7 +37,7 @@ def test_ddp_sampler_error():
|
|||
hparams = get_hparams()
|
||||
model = LightningTestModel(hparams, force_remove_distributed_sampler=True)
|
||||
|
||||
trainer_options = dict(
|
||||
trainer = Trainer(
|
||||
progress_bar=False,
|
||||
max_nb_epochs=1,
|
||||
gpus=[0, 1],
|
||||
|
@ -46,7 +46,7 @@ def test_ddp_sampler_error():
|
|||
)
|
||||
|
||||
with pytest.raises(MisconfigurationException):
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
trainer.get_dataloaders(model)
|
||||
|
||||
|
||||
def test_cpu_model():
|
||||
|
|
Loading…
Reference in New Issue