From a7f3974f958bb011dcbc478468ed38c57a610616 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 6 Nov 2019 14:34:50 -0500
Subject: [PATCH] Release (#467)

* smurf ethics

* smurf ethics

* removed auto ddp fix

* removed auto ddp fix

* removed auto ddp fix

* removed auto ddp fix

* removed auto ddp fix

* removed auto ddp fix
---
 pytorch_lightning/trainer/ddp_mixin.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/ddp_mixin.py b/pytorch_lightning/trainer/ddp_mixin.py
index 0653fa58d2..d5270fe7ae 100644
--- a/pytorch_lightning/trainer/ddp_mixin.py
+++ b/pytorch_lightning/trainer/ddp_mixin.py
@@ -49,14 +49,13 @@ class TrainerDDPMixin(object):
                 'Trainer(distributed_backend=dp) (or ddp)'
             raise MisconfigurationException(m)
 
-        # use ddp automatically if nb_gpu_nodes > 1
-        if nb_gpu_nodes > 1 and self.use_dp:  # pragma: no cover
-            self.use_ddp = True
-            self.use_dp = False
+        # throw error to force user ddp or ddp2 choice
+        if nb_gpu_nodes > 1 and not (self.use_ddp2 or self.use_ddp):  # pragma: no cover
             w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
                 'Switching to DistributedDataParallel for you. ' \
-                'To silence this warning set distributed_backend=ddp'
-            warnings.warn(w)
+                'To silence this warning set distributed_backend=ddp' \
+                'or distributed_backend=ddp2'
+            raise MisconfigurationException(w)
 
         logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}')
 
@@ -173,7 +172,7 @@ class TrainerDDPMixin(object):
         if self.distributed_backend == 'ddp':
             device_ids = [gpu_nb]
         elif self.use_ddp2:
-            device_ids = None
+            device_ids = self.data_parallel_device_ids
 
         # allow user to configure ddp
         model = model.configure_ddp(model, device_ids)
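
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the behavioural change in the first hunk: with more than one node, the Trainer no longer silently switches from dp to ddp but raises a MisconfigurationException unless the user explicitly picks ddp or ddp2. The helper name check_multi_node_backend and its arguments are hypothetical stand-ins for the Trainer's internal use_ddp / use_ddp2 flags, not Lightning's actual API.

# Minimal standalone sketch of the new multi-node check (hypothetical helper,
# not Lightning's API). Mirrors the patched logic: DataParallel cannot span
# nodes, so with nb_gpu_nodes > 1 the user must explicitly choose ddp or ddp2.

class MisconfigurationException(Exception):
    """Raised when distributed settings are inconsistent."""


def check_multi_node_backend(nb_gpu_nodes, distributed_backend):
    use_ddp = distributed_backend == 'ddp'
    use_ddp2 = distributed_backend == 'ddp2'
    if nb_gpu_nodes > 1 and not (use_ddp2 or use_ddp):
        raise MisconfigurationException(
            'DataParallel does not support nb_gpu_nodes > 1. '
            'Set distributed_backend=ddp or distributed_backend=ddp2.'
        )


check_multi_node_backend(1, 'dp')      # single-node dp is still allowed
try:
    check_multi_node_backend(2, 'dp')  # multi-node dp now fails fast
except MisconfigurationException as err:
    print(f'expected error: {err}')

The second hunk is independent of this check: per the diff, a ddp2 run now passes the node's data_parallel_device_ids to configure_ddp instead of None, so the wrapped model sees the configured GPUs of that node.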