From 10d190e04552702ace047c06693a6e57765f3c59 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 8 Sep 2019 15:36:58 -0400
Subject: [PATCH] Simplified gpu api. No NVIDIA flag managing by lightning for cluster (#213)

* added nvidia flag set
* added simple cluster template
* sets correct backend for possible combinations of gpu inputs
---
 docs/Trainer/Distributed training.md |  42 ++++++---
 pytorch_lightning/trainer/trainer.py | 136 +++++++++++++++++++--------
 tests/test_models.py                 |  29 +++++-
 3 files changed, 152 insertions(+), 55 deletions(-)

diff --git a/docs/Trainer/Distributed training.md b/docs/Trainer/Distributed training.md
index 1f24c3c908..16a9ac3ae9 100644
--- a/docs/Trainer/Distributed training.md
+++ b/docs/Trainer/Distributed training.md
@@ -10,10 +10,13 @@ For multi-node training you must use DistributedDataParallel.
You can toggle between each mode by setting this flag.
``` {.python}
-# DEFAULT uses DataParallel
+# DEFAULT (when using single GPU or no GPUs)
+trainer = Trainer(distributed_backend=None)
+
+# Change to DataParallel (gpus > 1)
trainer = Trainer(distributed_backend='dp')

-# change to distributed data parallel
+# change to distributed data parallel (gpus > 1)
trainer = Trainer(distributed_backend='ddp')
```

@@ -32,12 +35,24 @@ Below are the possible configurations we support.

| 1 GPU | 1+ GPUs | DP | DDP | 16-bit | command |
|---|---|---|---|---|---|
-| Y | | | | | ```Trainer(gpus=[0])``` |
-| Y | | | | Y | ```Trainer(gpus=[0], use_amp=True)``` |
-| | Y | Y | | | ```Trainer(gpus=[0, ...])``` |
-| | Y | | Y | | ```Trainer(gpus=[0, ...], distributed_backend='ddp')``` |
-| | Y | | Y | Y | ```Trainer(gpus=[0, ...], distributed_backend='ddp', use_amp=True)``` |
+| Y | | | | | ```Trainer(gpus=1)``` |
+| Y | | | | Y | ```Trainer(gpus=1, use_amp=True)``` |
+| | Y | Y | | | ```Trainer(gpus=k)``` |
+| | Y | | Y | | ```Trainer(gpus=k, distributed_backend='ddp')``` |
+| | Y | | Y | Y | ```Trainer(gpus=k, distributed_backend='ddp', use_amp=True)``` |
+You also have the option of specifying which GPUs to use by passing a list:
+
+```python
+# DEFAULT (int)
+Trainer(gpus=k)
+
+# You specify which GPUs (don't use if running on cluster)
+Trainer(gpus=[0, 1])
+
+# can also be a string
+Trainer(gpus='0, 1')
+```
---
#### CUDA flags
@@ -49,6 +64,9 @@ Lightning sets these for you automatically, there's NO NEED to do this yourself.
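# Illustrative examples (not in the original docs): these values are inferred from
# the __set_nvidia_flags logic added later in this patch, assuming SLURM is not managing tasks:
#   Trainer(gpus=2)      ->  CUDA_VISIBLE_DEVICES="0,1"
#   Trainer(gpus=[1, 3]) ->  CUDA_VISIBLE_DEVICES="1,3"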
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" ``` +However, when using a cluster, Lightning will NOT set these flags (and you should not either). +SLURM will set these for you. + --- #### 16-bit mixed precision 16 bit precision can cut your memory footprint by half. If using volta architecture GPUs it can give a dramatic training speed-up as well. @@ -70,7 +88,7 @@ trainer = Trainer(amp_level='O2', use_amp=False) Make sure you're on a GPU machine. ```python # DEFAULT -trainer = Trainer(gpus=[0]) +trainer = Trainer(gpus=1) ``` --- @@ -78,11 +96,11 @@ trainer = Trainer(gpus=[0]) Make sure you're on a GPU machine. You can set as many GPUs as you want. In this setting, the model will run on all 8 GPUs at once using DataParallel under the hood. ```python -# to use DataParallel (default) -trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='dp') +# to use DataParallel +trainer = Trainer(gpus=8, distributed_backend='dp') # RECOMMENDED use DistributedDataParallel -trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='ddp') +trainer = Trainer(gpus=8, distributed_backend='ddp') ``` --- @@ -90,7 +108,7 @@ trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='ddp') Multi-node training is easily done by specifying these flags. ```python # train on 12*8 GPUs -trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], nb_gpu_nodes=12) +trainer = Trainer(gpus=8, nb_gpu_nodes=12) ``` In addition, make sure to set up your SLURM job correctly via the [SlurmClusterObject](https://williamfalcon.github.io/test-tube/hpc/SlurmCluster/). In particular, specify the number of tasks per node correctly. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 70a413680a..8b9a20c8cf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -76,7 +76,7 @@ class Trainer(TrainerIO): val_check_interval=1.0, log_save_interval=100, add_log_row_interval=10, - distributed_backend='dp', + distributed_backend=None, use_amp=False, print_nan_grads=False, print_weights_summary=True, @@ -91,7 +91,7 @@ class Trainer(TrainerIO): :param gradient_clip: int. 0 means don't clip. :param process_position: shown in the tqdm bar :param nb_gpu_nodes: number of GPU nodes - :param gpus: list or string of gpu ids [0, 1] or '0,1' + :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1' :param log_gpu_memory: Bool. If true, adds memory logs :param show_progress_bar: Bool. If true shows tqdm bar :param overfit_pct: float. 
uses this much of all datasets @@ -168,7 +168,7 @@ class Trainer(TrainerIO): # accumulated grads self.__configure_accumulated_gradients(accumulate_grad_batches) - # allow string and gpu list + # allow int, string and gpu list self.data_parallel_device_ids = self.__parse_gpu_ids(gpus) # distributed backend choice @@ -181,7 +181,10 @@ class Trainer(TrainerIO): self.proc_rank = 0 self.world_size = 1 self.node_rank = 0 - self.__configure_slurm_ddp() + self.__configure_slurm_ddp(self.data_parallel_device_ids, nb_gpu_nodes) + + # nvidia setup + self.__set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) # can't init progress bar here because starting a new process # means the prog_bar won't survive pickling @@ -218,7 +221,8 @@ class Trainer(TrainerIO): self.weights_save_path = self.checkpoint_callback.filepath # if weights_save_path is still none here, set to current workingdir - self.weights_save_path = os.getcwd() + if self.weights_save_path is None: + self.weights_save_path = os.getcwd() def __init_amp(self, use_amp): self.use_amp = use_amp and APEX_AVAILABLE @@ -245,7 +249,10 @@ class Trainer(TrainerIO): raise TypeError("Gradient accumulation supports only int and dict types") def __parse_gpu_ids(self, gpus): - # gpus come in as a string. + """ + :param gpus: Int, string or list of ids + :return: + """ # if gpus = -1 then use all available devices # otherwise, split the string using commas if gpus is not None: @@ -256,47 +263,68 @@ class Trainer(TrainerIO): gpus = list(range(0, torch.cuda.device_count())) else: gpus = [int(x.strip()) for x in gpus.split(',')] + elif type(gpus) is int: + gpus = gpus else: - raise Exception('gpus has to be a string or list of ids') - - # set the correct cuda visible devices (using pci order) - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in gpus]) - print('VISIBLE GPUS: %r' % os.environ["CUDA_VISIBLE_DEVICES"]) + raise Exception('gpus has to be a string, int or list of ints') return gpus + @property + def num_gpus(self): + gpus = self.data_parallel_device_ids + if gpus is None: + return 0 + if type(gpus) is list: + return len(gpus) + if type(gpus) is int: + return gpus + + m = 'gpus must be int, none or list of ints' + raise MisconfigurationException(m) + def __set_distributed_mode(self, distributed_backend, nb_gpu_nodes): # make DP and DDP mutually exclusive # single GPU will also use DP with devices=[0] requested_gpus = self.data_parallel_device_ids is not None - if requested_gpus and len(self.data_parallel_device_ids) > 0: - self.use_dp = distributed_backend == 'dp' - self.use_ddp = distributed_backend == 'ddp' - # use ddp automatically if nb_gpu_nodes > 1 - if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover - self.use_ddp = True - self.use_dp = False - w = 'DataParallel does not support nb_gpu_nodes > 1. ' \ - 'Switching to DistributedDataParallel for you. 
' \ - 'To silence this warning set distributed_backend=ddp' - warnings.warn(w) + num_gpus = self.num_gpus + if num_gpus > 0: + # single GPU case + if num_gpus == 1: + self.single_gpu = True - # remove dp and ddp when requesting single gpu - if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1: - self.use_ddp = False - self.use_dp = False - self.single_gpu = True + elif num_gpus > 1 and distributed_backend is not None: + # DP, DDP case + self.use_dp = distributed_backend == 'dp' + self.use_ddp = distributed_backend == 'ddp' + + # use ddp automatically if nb_gpu_nodes > 1 + if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover + self.use_ddp = True + self.use_dp = False + w = 'DataParallel does not support nb_gpu_nodes > 1. ' \ + 'Switching to DistributedDataParallel for you. ' \ + 'To silence this warning set distributed_backend=ddp' + warnings.warn(w) + + elif distributed_backend is None: + m = 'When using multiple GPUs set ' \ + 'Trainer(distributed_backend=dp) (or ddp)' + raise MisconfigurationException(m) print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu)) - def __configure_slurm_ddp(self): + def __configure_slurm_ddp(self, gpu_ids, nb_gpu_nodes): + self.is_slurm_managing_tasks = False + + nb_gpus = len(gpu_ids) if type(gpu_ids) is list else gpu_ids + # extract SLURM flag vars # whenever we have the correct number of tasks, we let slurm manage processes # otherwise we launch the required number of processes if self.use_ddp: - self.nb_requested_gpus = len(self.data_parallel_device_ids) * self.nb_gpu_nodes + self.nb_requested_gpus = nb_gpus * nb_gpu_nodes self.nb_slurm_tasks = 0 try: self.nb_slurm_tasks = int(os.environ['SLURM_NTASKS']) @@ -305,6 +333,24 @@ class Trainer(TrainerIO): # likely not on slurm, so set the slurm managed flag to false self.is_slurm_managing_tasks = False + def __set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids): + if data_parallel_device_ids is None: + return + + # set the correct cuda visible devices (using pci order) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + + # when slurm is managing the task it sets the visible devices + if not is_slurm_managing_tasks: + if type(data_parallel_device_ids) is int: + id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids))) + os.environ["CUDA_VISIBLE_DEVICES"] = id_str + else: + gpu_str = ','.join([str(x) for x in data_parallel_device_ids]) + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str + + print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}') + @property def data_parallel(self): return self.use_dp or self.use_ddp @@ -412,8 +458,10 @@ class Trainer(TrainerIO): # CPU, single GPU if self.single_gpu: # for single GPU put inputs on gpu manually - gpu_id = self.data_parallel_device_ids[0] - data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id) + root_gpu = 0 + if type(self.data_parallel_device_ids) is list: + root_gpu = self.data_parallel_device_ids[0] + data_batch = self.transfer_batch_to_gpu(data_batch, root_gpu) args[0] = data_batch if test: @@ -598,7 +646,7 @@ class Trainer(TrainerIO): If you're not using SLURM, ignore this message! 
""" warnings.warn(msg) - mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, )) + mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, )) # 1 gpu or dp option triggers training using DP module # easier to avoid NCCL issues @@ -645,7 +693,10 @@ class Trainer(TrainerIO): # allow for lr schedulers as well self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) - model.cuda(self.data_parallel_device_ids[0]) + root_gpu = 0 + if type(self.data_parallel_device_ids) is list: + root_gpu = self.data_parallel_device_ids[0] + model.cuda(root_gpu) if self.use_amp: # An example @@ -662,7 +713,10 @@ class Trainer(TrainerIO): # allow for lr schedulers as well self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) - model.cuda(self.data_parallel_device_ids[0]) + root_gpu = 0 + if type(self.data_parallel_device_ids) is list: + root_gpu = self.data_parallel_device_ids[0] + model.cuda(root_gpu) # check for this bug (amp + dp + !01 doesn't work) # https://github.com/NVIDIA/apex/issues/227 @@ -704,8 +758,8 @@ class Trainer(TrainerIO): self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0 # determine which process we are and world size - self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb - self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids) + self.proc_rank = self.node_rank * self.num_gpus + gpu_nb + self.world_size = self.nb_gpu_nodes * self.num_gpus # let the exp know the rank to avoid overwriting logs if self.experiment is not None: @@ -1045,7 +1099,9 @@ class Trainer(TrainerIO): elif self.use_dp: output = self.model(*args) elif self.single_gpu: - gpu_id = self.data_parallel_device_ids[0] + gpu_id = 0 + if type(self.data_parallel_device_ids) is list: + gpu_id = self.data_parallel_device_ids[0] data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id) args[0] = data_batch output = self.model.training_step(*args) @@ -1061,7 +1117,7 @@ class Trainer(TrainerIO): # reduce prog metrics for tqdm when using dp if self.use_dp: - nb_gpus = len(self.data_parallel_device_ids) + nb_gpus = self.num_gpus prog_output = reduce_distributed_output(prog_output, nb_gpus) model_specific_tqdm_metrics_dic = prog_output @@ -1081,7 +1137,7 @@ class Trainer(TrainerIO): # when using dp need to reduce the loss if self.use_dp: - loss = reduce_distributed_output(loss, len(self.data_parallel_device_ids)) + loss = reduce_distributed_output(loss, self.num_gpus) return loss, model_specific_tqdm_metrics_dic diff --git a/tests/test_models.py b/tests/test_models.py index c94681a45d..eefc70ba39 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -585,7 +585,7 @@ def test_amp_single_gpu(): trainer_options = dict( show_progress_bar=True, max_nb_epochs=1, - gpus=[0], + gpus=1, distributed_backend='dp', use_amp=True ) @@ -675,7 +675,7 @@ def test_amp_gpu_ddp(): trainer_options = dict( show_progress_bar=True, max_nb_epochs=1, - gpus=[0, 1], + gpus=2, distributed_backend='ddp', use_amp=True ) @@ -1019,12 +1019,34 @@ def test_single_gpu_model(): max_nb_epochs=1, train_percent_check=0.1, val_percent_check=0.1, - gpus=[0] + gpus=1 ) run_gpu_model_test(trainer_options, model, hparams) +def test_multi_gpu_none_backend(): + """ + Make sure when using multiple GPUs the user can't use + distributed_backend = None + :return: + """ + if not can_run_gpu_test(): + return + + model, hparams = get_model() + trainer_options = dict( + show_progress_bar=False, + 
max_nb_epochs=1, + train_percent_check=0.1, + val_percent_check=0.1, + gpus='-1' + ) + + with pytest.raises(MisconfigurationException): + run_gpu_model_test(trainer_options, model, hparams) + + def test_multi_gpu_model_dp(): """ Make sure DP works @@ -1036,6 +1058,7 @@ def test_multi_gpu_model_dp(): model, hparams = get_model() trainer_options = dict( show_progress_bar=False, + distributed_backend='dp', max_nb_epochs=1, train_percent_check=0.1, val_percent_check=0.1,
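+        # note: distributed_backend='dp' (added above) is now required explicitly,
+        # since multiple GPUs with distributed_backend=None raise MisconfigurationException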