Simplified GPU API; Lightning no longer manages NVIDIA flags on clusters (#213)

* added nvidia flag set

* added simple cluster template

* sets correct backend for possible combinations of gpu inputs
William Falcon, 2019-09-08 15:36:58 -04:00, committed by GitHub
parent b3434943c7
commit 10d190e045
3 changed files with 152 additions and 55 deletions


@@ -10,10 +10,13 @@ For multi-node training you must use DistributedDataParallel.
You can toggle between each mode by setting this flag.
``` {.python}
# DEFAULT uses DataParallel
# DEFAULT (when using single GPU or no GPUs)
trainer = Trainer(distributed_backend=None)
# Change to DataParallel (gpus > 1)
trainer = Trainer(distributed_backend='dp')
# change to distributed data parallel
# change to distributed data parallel (gpus > 1)
trainer = Trainer(distributed_backend='ddp')
```
@@ -32,12 +35,24 @@ Below are the possible configurations we support.
| 1 GPU | 1+ GPUs | DP | DDP | 16-bit | command |
|---|---|---|---|---|---|
| Y | | | | | ```Trainer(gpus=[0])``` |
| Y | | | | Y | ```Trainer(gpus=[0], use_amp=True)``` |
| | Y | Y | | | ```Trainer(gpus=[0, ...])``` |
| | Y | | Y | | ```Trainer(gpus=[0, ...], distributed_backend='ddp')``` |
| | Y | | Y | Y | ```Trainer(gpus=[0, ...], distributed_backend='ddp', use_amp=True)``` |
| Y | | | | | ```Trainer(gpus=1)``` |
| Y | | | | Y | ```Trainer(gpus=1, use_amp=True)``` |
| | Y | Y | | | ```Trainer(gpus=k)``` |
| | Y | | Y | | ```Trainer(gpus=k, distributed_backend='ddp')``` |
| | Y | | Y | Y | ```Trainer(gpus=k, distributed_backend='ddp', use_amp=True)``` |
You also have the option of specifying which GPUs to use by passing a list (or a comma-separated string) of device ids:
```python
# DEFAULT (int)
Trainer(gpus=k)
# You specify which GPUs (don't use if running on cluster)
Trainer(gpus=[0, 1])
# can also be a string
Trainer(gpus='0, 1')
```
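To make the table above concrete, here is a simplified, illustrative sketch of how these combinations are resolved internally (loosely mirroring the `__set_distributed_mode` logic added in this change; the function name and structure are simplified for illustration):
```python
import warnings

def resolve_backend(num_gpus, distributed_backend=None, nb_gpu_nodes=1):
    """Illustrative sketch of the backend-selection rules described above."""
    use_dp = use_ddp = single_gpu = False
    if num_gpus == 1:
        # a single GPU never needs DP or DDP
        single_gpu = True
    elif num_gpus > 1:
        if distributed_backend is None:
            raise ValueError('When using multiple GPUs set '
                             'Trainer(distributed_backend=dp) (or ddp)')
        use_dp = distributed_backend == 'dp'
        use_ddp = distributed_backend == 'ddp'
        if nb_gpu_nodes > 1 and use_dp:
            # DataParallel cannot span nodes, so fall back to DDP
            warnings.warn('DataParallel does not support nb_gpu_nodes > 1. '
                          'Switching to DistributedDataParallel for you.')
            use_dp, use_ddp = False, True
    return use_dp, use_ddp, single_gpu
```
For example, `resolve_backend(8, 'dp', nb_gpu_nodes=12)` falls back to DDP, while `resolve_backend(8)` raises because a backend must be chosen when using multiple GPUs.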
---
#### CUDA flags
@@ -49,6 +64,9 @@ Lightning sets these for you automatically, there's NO NEED to do this yourself.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
```
However, when using a cluster, Lightning will NOT set these flags (and you should not either).
SLURM will set these for you.
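In practice this behaviour boils down to something like the following sketch (illustrative only, condensed from the `__set_nvidia_flags` logic added in this change; you do not need to run this yourself):
```python
import os

def set_nvidia_flags(is_slurm_managing_tasks, gpu_ids):
    # sketch of what Lightning does for you
    if gpu_ids is None:
        return
    # enumerate GPUs in PCI bus order so ids are stable
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # on a SLURM-managed cluster the scheduler controls device visibility
    if is_slurm_managing_tasks:
        return
    if isinstance(gpu_ids, int):
        visible = ','.join(str(i) for i in range(gpu_ids))  # gpus=2     -> "0,1"
    else:
        visible = ','.join(str(i) for i in gpu_ids)         # gpus=[0, 2] -> "0,2"
    os.environ["CUDA_VISIBLE_DEVICES"] = visible
```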
---
#### 16-bit mixed precision
16-bit precision can cut your memory footprint in half. If you are using Volta-architecture GPUs it can also give a dramatic training speed-up.
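For example, following the table above (assuming a CUDA-capable GPU and NVIDIA apex are available):
```python
# enable 16-bit mixed precision on a single GPU
trainer = Trainer(gpus=1, use_amp=True)
```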
@@ -70,7 +88,7 @@ trainer = Trainer(amp_level='O2', use_amp=False)
Make sure you're on a GPU machine.
```python
# DEFAULT
trainer = Trainer(gpus=[0])
trainer = Trainer(gpus=1)
```
---
@@ -78,11 +96,11 @@ trainer = Trainer(gpus=[0])
Make sure you're on a GPU machine. You can set as many GPUs as you want.
In this setting, the model will run on all 8 GPUs at once using DataParallel under the hood.
```python
# to use DataParallel (default)
trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='dp')
# to use DataParallel
trainer = Trainer(gpus=8, distributed_backend='dp')
# RECOMMENDED use DistributedDataParallel
trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='ddp')
trainer = Trainer(gpus=8, distributed_backend='ddp')
```
---
@@ -90,7 +108,7 @@ trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='ddp')
Multi-node training is easily done by specifying these flags.
```python
# train on 12*8 GPUs
trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], nb_gpu_nodes=12)
trainer = Trainer(gpus=8, nb_gpu_nodes=12)
```
In addition, make sure to set up your SLURM job correctly via the [SlurmClusterObject](https://williamfalcon.github.io/test-tube/hpc/SlurmCluster/); in particular, specify the correct number of tasks per node.


@@ -76,7 +76,7 @@ class Trainer(TrainerIO):
val_check_interval=1.0,
log_save_interval=100,
add_log_row_interval=10,
distributed_backend='dp',
distributed_backend=None,
use_amp=False,
print_nan_grads=False,
print_weights_summary=True,
@@ -91,7 +91,7 @@ class Trainer(TrainerIO):
:param gradient_clip: int. 0 means don't clip.
:param process_position: shown in the tqdm bar
:param nb_gpu_nodes: number of GPU nodes
:param gpus: list or string of gpu ids [0, 1] or '0,1'
:param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
:param log_gpu_memory: Bool. If true, adds memory logs
:param show_progress_bar: Bool. If true shows tqdm bar
:param overfit_pct: float. uses this much of all datasets
@@ -168,7 +168,7 @@ class Trainer(TrainerIO):
# accumulated grads
self.__configure_accumulated_gradients(accumulate_grad_batches)
# allow string and gpu list
# allow int, string and gpu list
self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)
# distributed backend choice
@@ -181,7 +181,10 @@ class Trainer(TrainerIO):
self.proc_rank = 0
self.world_size = 1
self.node_rank = 0
self.__configure_slurm_ddp()
self.__configure_slurm_ddp(self.data_parallel_device_ids, nb_gpu_nodes)
# nvidia setup
self.__set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
# can't init progress bar here because starting a new process
# means the prog_bar won't survive pickling
@@ -218,7 +221,8 @@ class Trainer(TrainerIO):
self.weights_save_path = self.checkpoint_callback.filepath
# if weights_save_path is still none here, set to current workingdir
self.weights_save_path = os.getcwd()
if self.weights_save_path is None:
self.weights_save_path = os.getcwd()
def __init_amp(self, use_amp):
self.use_amp = use_amp and APEX_AVAILABLE
@@ -245,7 +249,10 @@ class Trainer(TrainerIO):
raise TypeError("Gradient accumulation supports only int and dict types")
def __parse_gpu_ids(self, gpus):
# gpus come in as a string.
"""
:param gpus: Int, string or list of ids
:return:
"""
# if gpus = -1 then use all available devices
# otherwise, split the string using commas
if gpus is not None:
@@ -256,47 +263,68 @@ class Trainer(TrainerIO):
gpus = list(range(0, torch.cuda.device_count()))
else:
gpus = [int(x.strip()) for x in gpus.split(',')]
elif type(gpus) is int:
gpus = gpus
else:
raise Exception('gpus has to be a string or list of ids')
# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in gpus])
print('VISIBLE GPUS: %r' % os.environ["CUDA_VISIBLE_DEVICES"])
raise Exception('gpus has to be a string, int or list of ints')
return gpus
@property
def num_gpus(self):
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
if type(gpus) is list:
return len(gpus)
if type(gpus) is int:
return gpus
m = 'gpus must be int, none or list of ints'
raise MisconfigurationException(m)
def __set_distributed_mode(self, distributed_backend, nb_gpu_nodes):
# make DP and DDP mutually exclusive
# single GPU will also use DP with devices=[0]
requested_gpus = self.data_parallel_device_ids is not None
if requested_gpus and len(self.data_parallel_device_ids) > 0:
self.use_dp = distributed_backend == 'dp'
self.use_ddp = distributed_backend == 'ddp'
# use ddp automatically if nb_gpu_nodes > 1
if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover
self.use_ddp = True
self.use_dp = False
w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
'Switching to DistributedDataParallel for you. ' \
'To silence this warning set distributed_backend=ddp'
warnings.warn(w)
num_gpus = self.num_gpus
if num_gpus > 0:
# single GPU case
if num_gpus == 1:
self.single_gpu = True
# remove dp and ddp when requesting single gpu
if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1:
self.use_ddp = False
self.use_dp = False
self.single_gpu = True
elif num_gpus > 1 and distributed_backend is not None:
# DP, DDP case
self.use_dp = distributed_backend == 'dp'
self.use_ddp = distributed_backend == 'ddp'
# use ddp automatically if nb_gpu_nodes > 1
if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover
self.use_ddp = True
self.use_dp = False
w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
'Switching to DistributedDataParallel for you. ' \
'To silence this warning set distributed_backend=ddp'
warnings.warn(w)
elif distributed_backend is None:
m = 'When using multiple GPUs set ' \
'Trainer(distributed_backend=dp) (or ddp)'
raise MisconfigurationException(m)
print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
def __configure_slurm_ddp(self):
def __configure_slurm_ddp(self, gpu_ids, nb_gpu_nodes):
self.is_slurm_managing_tasks = False
nb_gpus = len(gpu_ids) if type(gpu_ids) is list else gpu_ids
# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
# otherwise we launch the required number of processes
if self.use_ddp:
self.nb_requested_gpus = len(self.data_parallel_device_ids) * self.nb_gpu_nodes
self.nb_requested_gpus = nb_gpus * nb_gpu_nodes
self.nb_slurm_tasks = 0
try:
self.nb_slurm_tasks = int(os.environ['SLURM_NTASKS'])
@@ -305,6 +333,24 @@ class Trainer(TrainerIO):
# likely not on slurm, so set the slurm managed flag to false
self.is_slurm_managing_tasks = False
def __set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
if data_parallel_device_ids is None:
return
# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# when slurm is managing the task it sets the visible devices
if not is_slurm_managing_tasks:
if type(data_parallel_device_ids) is int:
id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids)))
os.environ["CUDA_VISIBLE_DEVICES"] = id_str
else:
gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str
print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')
@property
def data_parallel(self):
return self.use_dp or self.use_ddp
@@ -412,8 +458,10 @@ class Trainer(TrainerIO):
# CPU, single GPU
if self.single_gpu:
# for single GPU put inputs on gpu manually
gpu_id = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id)
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, root_gpu)
args[0] = data_batch
if test:
@@ -598,7 +646,7 @@ class Trainer(TrainerIO):
If you're not using SLURM, ignore this message!
"""
warnings.warn(msg)
mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, ))
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, ))
# 1 gpu or dp option triggers training using DP module
# easier to avoid NCCL issues
@@ -645,7 +693,10 @@ class Trainer(TrainerIO):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
model.cuda(self.data_parallel_device_ids[0])
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)
if self.use_amp:
# An example
@@ -662,7 +713,10 @@ class Trainer(TrainerIO):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
model.cuda(self.data_parallel_device_ids[0])
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)
# check for this bug (amp + dp + !01 doesn't work)
# https://github.com/NVIDIA/apex/issues/227
@@ -704,8 +758,8 @@ class Trainer(TrainerIO):
self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0
# determine which process we are and world size
self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
self.proc_rank = self.node_rank * self.num_gpus + gpu_nb
self.world_size = self.nb_gpu_nodes * self.num_gpus
# let the exp know the rank to avoid overwriting logs
if self.experiment is not None:
@@ -1045,7 +1099,9 @@ class Trainer(TrainerIO):
elif self.use_dp:
output = self.model(*args)
elif self.single_gpu:
gpu_id = self.data_parallel_device_ids[0]
gpu_id = 0
if type(self.data_parallel_device_ids) is list:
gpu_id = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id)
args[0] = data_batch
output = self.model.training_step(*args)
@@ -1061,7 +1117,7 @@ class Trainer(TrainerIO):
# reduce prog metrics for tqdm when using dp
if self.use_dp:
nb_gpus = len(self.data_parallel_device_ids)
nb_gpus = self.num_gpus
prog_output = reduce_distributed_output(prog_output, nb_gpus)
model_specific_tqdm_metrics_dic = prog_output
@@ -1081,7 +1137,7 @@ class Trainer(TrainerIO):
# when using dp need to reduce the loss
if self.use_dp:
loss = reduce_distributed_output(loss, len(self.data_parallel_device_ids))
loss = reduce_distributed_output(loss, self.num_gpus)
return loss, model_specific_tqdm_metrics_dic


@@ -585,7 +585,7 @@ def test_amp_single_gpu():
trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=1,
gpus=[0],
gpus=1,
distributed_backend='dp',
use_amp=True
)
@@ -675,7 +675,7 @@ def test_amp_gpu_ddp():
trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=1,
gpus=[0, 1],
gpus=2,
distributed_backend='ddp',
use_amp=True
)
@@ -1019,12 +1019,34 @@ def test_single_gpu_model():
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
gpus=[0]
gpus=1
)
run_gpu_model_test(trainer_options, model, hparams)
def test_multi_gpu_none_backend():
"""
Make sure when using multiple GPUs the user can't use
distributed_backend = None
:return:
"""
if not can_run_gpu_test():
return
model, hparams = get_model()
trainer_options = dict(
show_progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
gpus='-1'
)
with pytest.raises(MisconfigurationException):
run_gpu_model_test(trainer_options, model, hparams)
def test_multi_gpu_model_dp():
"""
Make sure DP works
@@ -1036,6 +1058,7 @@ def test_multi_gpu_model_dp():
model, hparams = get_model()
trainer_options = dict(
show_progress_bar=False,
distributed_backend='dp',
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,