This commit is contained in:
William Falcon 2019-10-05 16:56:24 -04:00
parent a59f351ef8
commit ef98931d18
1 changed files with 10 additions and 9 deletions

View File

@ -639,7 +639,8 @@ class Trainer(TrainerIO):
# call warnings from proc zero only which triggers dataloaders
# if those have to download data it will only happen on proc 0
if self.proc_rank == 0:
if self.use_ddp or self.use_ddp2 and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
msg = """
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
@ -658,14 +659,14 @@ class Trainer(TrainerIO):
"""
warnings.warn(msg)
if self.use_ddp or self.use_ddp2 and self.get_val_dataloaders is not None:
if on_ddp and self.get_val_dataloaders is not None:
for dataloader in self.get_val_dataloaders():
if not isinstance(dataloader.sampler, DistributedSampler):
msg = """
Your val_dataloader(s) don't use DistributedSampler.
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.
You're using multiple gpus and multiple nodes without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.
ie: this:
dataset = myDataset()
@ -681,14 +682,14 @@ class Trainer(TrainerIO):
warnings.warn(msg)
break
if self.use_ddp or self.use_ddp2 and self.get_test_dataloaders is not None:
if on_ddp and self.get_test_dataloaders is not None:
for dataloader in self.get_test_dataloaders():
if not isinstance(dataloader.sampler, DistributedSampler):
msg = """
Your test_dataloader(s) don't use DistributedSampler.
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.
You're using multiple gpus and multiple nodes without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.
ie: this:
dataset = myDataset()