diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8b9a20c8cf..9ea5c0a919 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -181,7 +181,7 @@ class Trainer(TrainerIO): self.proc_rank = 0 self.world_size = 1 self.node_rank = 0 - self.__configure_slurm_ddp(self.data_parallel_device_ids, nb_gpu_nodes) + self.__configure_slurm_ddp(nb_gpu_nodes) # nvidia setup self.__set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) @@ -284,51 +284,59 @@ class Trainer(TrainerIO): raise MisconfigurationException(m) def __set_distributed_mode(self, distributed_backend, nb_gpu_nodes): - # make DP and DDP mutually exclusive - # single GPU will also use DP with devices=[0] - requested_gpus = self.data_parallel_device_ids is not None + # skip for CPU + if self.num_gpus == 0: + return - num_gpus = self.num_gpus - if num_gpus > 0: - # single GPU case - if num_gpus == 1: - self.single_gpu = True + # single GPU case + if self.num_gpus == 1: + self.single_gpu = True - elif num_gpus > 1 and distributed_backend is not None: - # DP, DDP case + if distributed_backend is not None: self.use_dp = distributed_backend == 'dp' self.use_ddp = distributed_backend == 'ddp' - # use ddp automatically if nb_gpu_nodes > 1 - if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover - self.use_ddp = True - self.use_dp = False - w = 'DataParallel does not support nb_gpu_nodes > 1. ' \ - 'Switching to DistributedDataParallel for you. ' \ - 'To silence this warning set distributed_backend=ddp' - warnings.warn(w) + # multiple GPU case + elif self.num_gpus > 1: + if distributed_backend is not None: + # DP, DDP case + self.use_dp = distributed_backend == 'dp' + self.use_ddp = distributed_backend == 'ddp' elif distributed_backend is None: m = 'When using multiple GPUs set ' \ 'Trainer(distributed_backend=dp) (or ddp)' raise MisconfigurationException(m) + # use ddp automatically if nb_gpu_nodes > 1 + if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover + self.use_ddp = True + self.use_dp = False + w = 'DataParallel does not support nb_gpu_nodes > 1. ' \ + 'Switching to DistributedDataParallel for you. ' \ + 'To silence this warning set distributed_backend=ddp' + warnings.warn(w) + print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu)) - def __configure_slurm_ddp(self, gpu_ids, nb_gpu_nodes): + def __configure_slurm_ddp(self, nb_gpu_nodes): self.is_slurm_managing_tasks = False - nb_gpus = len(gpu_ids) if type(gpu_ids) is list else gpu_ids - # extract SLURM flag vars # whenever we have the correct number of tasks, we let slurm manage processes # otherwise we launch the required number of processes if self.use_ddp: - self.nb_requested_gpus = nb_gpus * nb_gpu_nodes + self.nb_requested_gpus = self.num_gpus * nb_gpu_nodes self.nb_slurm_tasks = 0 try: self.nb_slurm_tasks = int(os.environ['SLURM_NTASKS']) self.is_slurm_managing_tasks = self.nb_slurm_tasks == self.nb_requested_gpus + + # in interactive mode we don't manage tasks + job_name = os.environ['SLURM_JOB_NAME'] + if job_name == 'bash': + self.is_slurm_managing_tasks = False + except Exception: # likely not on slurm, so set the slurm managed flag to false self.is_slurm_managing_tasks = False diff --git a/tests/test_models.py b/tests/test_models.py index eefc70ba39..d0214ea321 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -586,13 +586,42 @@ def test_amp_single_gpu(): show_progress_bar=True, max_nb_epochs=1, gpus=1, - distributed_backend='dp', + distributed_backend='ddp', use_amp=True ) run_gpu_model_test(trainer_options, model, hparams) +def test_no_amp_single_gpu(): + """ + Make sure DDP + AMP work + :return: + """ + if not torch.cuda.is_available(): + warnings.warn('test_amp_gpu_ddp cannot run.' + 'Rerun on a GPU node to run this test') + return + if not torch.cuda.device_count() > 1: + warnings.warn('test_amp_gpu_ddp cannot run.' + 'Rerun on a node with 2+ GPUs to run this test') + return + + hparams = get_hparams() + model = LightningTestModel(hparams) + + trainer_options = dict( + show_progress_bar=True, + max_nb_epochs=1, + gpus=1, + distributed_backend='dp', + use_amp=True + ) + + with pytest.raises((MisconfigurationException, ModuleNotFoundError)): + run_gpu_model_test(trainer_options, model, hparams) + + def test_cpu_restore_training(): """ Verify continue training session on CPU diff --git a/tox.ini b/tox.ini index f0ae2f5456..007b0a4783 100644 --- a/tox.ini +++ b/tox.ini @@ -34,8 +34,8 @@ deps = commands = check-manifest --ignore tox.ini python setup.py check -m -s - coverage run --source pytorch_lightning -m py.test pytorch_lightning tests examples -v --doctest-modules flake8 . + coverage run --source pytorch_lightning -m py.test pytorch_lightning tests examples -v --doctest-modules [flake8] exclude = .tox,*.egg,build,temp,examples/*