Update trainer.py
This commit is contained in:
parent
a83588b14e
commit
a41abad5b2
|
@ -277,7 +277,7 @@ class Trainer(TrainerIO):
|
||||||
self.test_dataloader = model.test_dataloader
|
self.test_dataloader = model.test_dataloader
|
||||||
self.val_dataloader = model.val_dataloader
|
self.val_dataloader = model.val_dataloader
|
||||||
|
|
||||||
if self.data_parallel and not issubclass(self.tng_dataloader.sampler, DistributedSampler):
|
if self.data_parallel and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
|
||||||
msg = '''
|
msg = '''
|
||||||
when using multiple gpus and multiple nodes you must pass a DistributedSampler to DataLoader(sampler).
|
when using multiple gpus and multiple nodes you must pass a DistributedSampler to DataLoader(sampler).
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue