Simplified GPU API. Lightning no longer manages NVIDIA flags for clusters (#213)
* added nvidia flag set
* added simple cluster template
* sets correct backend for possible combinations of gpu inputs
parent b3434943c7
commit 10d190e045
@@ -10,10 +10,13 @@ For multi-node training you must use DistributedDataParallel.

You can toggle between each mode by setting this flag.
``` {.python}
# DEFAULT uses DataParallel
# DEFAULT (when using single GPU or no GPUs)
trainer = Trainer(distributed_backend=None)

# Change to DataParallel (gpus > 1)
trainer = Trainer(distributed_backend='dp')

# change to distributed data parallel
# change to distributed data parallel (gpus > 1)
trainer = Trainer(distributed_backend='ddp')
```
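
For reference, a minimal sketch (not part of the diff) of how the new `distributed_backend=None` default is expected to resolve, based on the backend-selection logic and tests changed later in this commit; the import path is assumed for this release:

```python
# Sketch only: expected behaviour of the new default backend selection.
from pytorch_lightning import Trainer  # assumed import path for this era

Trainer(gpus=None)                          # CPU training, no backend needed
Trainer(gpus=1)                             # single-GPU mode, no DP/DDP wrapper
Trainer(gpus=2, distributed_backend='dp')   # DataParallel
Trainer(gpus=2, distributed_backend='ddp')  # DistributedDataParallel
Trainer(gpus=2)                             # raises MisconfigurationException: pick 'dp' or 'ddp'
```
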
@@ -32,12 +35,24 @@ Below are the possible configurations we support.

| 1 GPU | 1+ GPUs | DP | DDP | 16-bit | command |
|---|---|---|---|---|---|
| Y | | | | | ```Trainer(gpus=[0])``` |
| Y | | | | Y | ```Trainer(gpus=[0], use_amp=True)``` |
| | Y | Y | | | ```Trainer(gpus=[0, ...])``` |
| | Y | | Y | | ```Trainer(gpus=[0, ...], distributed_backend='ddp')``` |
| | Y | | Y | Y | ```Trainer(gpus=[0, ...], distributed_backend='ddp', use_amp=True)``` |
| Y | | | | | ```Trainer(gpus=1)``` |
| Y | | | | Y | ```Trainer(gpus=1, use_amp=True)``` |
| | Y | Y | | | ```Trainer(gpus=k)``` |
| | Y | | Y | | ```Trainer(gpus=k, distributed_backend='ddp')``` |
| | Y | | Y | Y | ```Trainer(gpus=k, distributed_backend='ddp', use_amp=True)``` |

You also have the option of specifying which GPUs to use by passing a list:

```python
# DEFAULT (int)
Trainer(gpus=k)

# You specify which GPUs (don't use if running on cluster)
Trainer(gpus=[0, 1])

# can also be a string
Trainer(gpus='0, 1')
```

---
#### CUDA flags
@@ -49,6 +64,9 @@ Lightning sets these for you automatically, there's NO NEED to do this yourself.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
```

However, when using a cluster, Lightning will NOT set these flags (and you should not either).
SLURM will set these for you.

---
#### 16-bit mixed precision
16 bit precision can cut your memory footprint by half. If using volta architecture GPUs it can give a dramatic training speed-up as well.
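
As a quick illustration (a sketch, not part of the diff), 16-bit can be combined with the GPU flags from the table above; `amp_level='O2'` is the level already used elsewhere in these docs:

```python
# sketch: enable 16-bit (NVIDIA apex) training on a single GPU
trainer = Trainer(gpus=1, use_amp=True, amp_level='O2')
```
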
@@ -70,7 +88,7 @@ trainer = Trainer(amp_level='O2', use_amp=False)
Make sure you're on a GPU machine.
```python
# DEFAULT
trainer = Trainer(gpus=[0])
trainer = Trainer(gpus=1)
```

---
@@ -78,11 +96,11 @@ trainer = Trainer(gpus=[0])
Make sure you're on a GPU machine. You can set as many GPUs as you want.
In this setting, the model will run on all 8 GPUs at once using DataParallel under the hood.
```python
# to use DataParallel (default)
trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='dp')
# to use DataParallel
trainer = Trainer(gpus=8, distributed_backend='dp')

# RECOMMENDED use DistributedDataParallel
trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='ddp')
trainer = Trainer(gpus=8, distributed_backend='ddp')
```

---
@@ -90,7 +108,7 @@ trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='ddp')
Multi-node training is easily done by specifying these flags.
```python
# train on 12*8 GPUs
trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], nb_gpu_nodes=12)
trainer = Trainer(gpus=8, nb_gpu_nodes=12)
```

In addition, make sure to set up your SLURM job correctly via the [SlurmClusterObject](https://williamfalcon.github.io/test-tube/hpc/SlurmCluster/). In particular, specify the number of tasks per node correctly.
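
To make the note above concrete, a small sanity-check sketch (not part of the diff; assumes the job is launched under SLURM): the total SLURM task count should equal gpus-per-node times nodes, and `SLURM_NTASKS` is the same variable the Trainer inspects in this commit:

```python
import os

# sketch: verify the SLURM allocation matches the Trainer flags above
gpus_per_node = 8      # must equal the `gpus` value passed to Trainer
nb_gpu_nodes = 12      # must equal `nb_gpu_nodes`

ntasks = int(os.environ.get('SLURM_NTASKS', 0))
if ntasks != gpus_per_node * nb_gpu_nodes:
    print(f'expected {gpus_per_node * nb_gpu_nodes} SLURM tasks, got {ntasks}')
```
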
@@ -76,7 +76,7 @@ class Trainer(TrainerIO):
val_check_interval=1.0,
log_save_interval=100,
add_log_row_interval=10,
distributed_backend='dp',
distributed_backend=None,
use_amp=False,
print_nan_grads=False,
print_weights_summary=True,
@@ -91,7 +91,7 @@ class Trainer(TrainerIO):
:param gradient_clip: int. 0 means don't clip.
:param process_position: shown in the tqdm bar
:param nb_gpu_nodes: number of GPU nodes
:param gpus: list or string of gpu ids [0, 1] or '0,1'
:param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
:param log_gpu_memory: Bool. If true, adds memory logs
:param show_progress_bar: Bool. If true shows tqdm bar
:param overfit_pct: float. uses this much of all datasets
@@ -168,7 +168,7 @@ class Trainer(TrainerIO):
# accumulated grads
self.__configure_accumulated_gradients(accumulate_grad_batches)

# allow string and gpu list
# allow int, string and gpu list
self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)

# distributed backend choice
@@ -181,7 +181,10 @@ class Trainer(TrainerIO):
self.proc_rank = 0
self.world_size = 1
self.node_rank = 0
self.__configure_slurm_ddp()
self.__configure_slurm_ddp(self.data_parallel_device_ids, nb_gpu_nodes)

# nvidia setup
self.__set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

# can't init progress bar here because starting a new process
# means the prog_bar won't survive pickling
@@ -218,7 +221,8 @@ class Trainer(TrainerIO):
self.weights_save_path = self.checkpoint_callback.filepath

# if weights_save_path is still none here, set to current workingdir
self.weights_save_path = os.getcwd()
if self.weights_save_path is None:
self.weights_save_path = os.getcwd()

def __init_amp(self, use_amp):
self.use_amp = use_amp and APEX_AVAILABLE
@@ -245,7 +249,10 @@ class Trainer(TrainerIO):
raise TypeError("Gradient accumulation supports only int and dict types")

def __parse_gpu_ids(self, gpus):
# gpus come in as a string.
"""
:param gpus: Int, string or list of ids
:return:
"""
# if gpus = -1 then use all available devices
# otherwise, split the string using commas
if gpus is not None:
@@ -256,47 +263,68 @@ class Trainer(TrainerIO):
gpus = list(range(0, torch.cuda.device_count()))
else:
gpus = [int(x.strip()) for x in gpus.split(',')]
elif type(gpus) is int:
gpus = gpus
else:
raise Exception('gpus has to be a string or list of ids')

# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in gpus])
print('VISIBLE GPUS: %r' % os.environ["CUDA_VISIBLE_DEVICES"])
raise Exception('gpus has to be a string, int or list of ints')

return gpus

@property
def num_gpus(self):
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
if type(gpus) is list:
return len(gpus)
if type(gpus) is int:
return gpus

m = 'gpus must be int, none or list of ints'
raise MisconfigurationException(m)

def __set_distributed_mode(self, distributed_backend, nb_gpu_nodes):
# make DP and DDP mutually exclusive
# single GPU will also use DP with devices=[0]
requested_gpus = self.data_parallel_device_ids is not None
if requested_gpus and len(self.data_parallel_device_ids) > 0:
self.use_dp = distributed_backend == 'dp'
self.use_ddp = distributed_backend == 'ddp'

# use ddp automatically if nb_gpu_nodes > 1
if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover
self.use_ddp = True
self.use_dp = False
w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
'Switching to DistributedDataParallel for you. ' \
'To silence this warning set distributed_backend=ddp'
warnings.warn(w)
num_gpus = self.num_gpus
if num_gpus > 0:
# single GPU case
if num_gpus == 1:
self.single_gpu = True

# remove dp and ddp when requesting single gpu
if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1:
self.use_ddp = False
self.use_dp = False
self.single_gpu = True
elif num_gpus > 1 and distributed_backend is not None:
# DP, DDP case
self.use_dp = distributed_backend == 'dp'
self.use_ddp = distributed_backend == 'ddp'

# use ddp automatically if nb_gpu_nodes > 1
if nb_gpu_nodes > 1 and self.use_dp: # pragma: no cover
self.use_ddp = True
self.use_dp = False
w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
'Switching to DistributedDataParallel for you. ' \
'To silence this warning set distributed_backend=ddp'
warnings.warn(w)

elif distributed_backend is None:
m = 'When using multiple GPUs set ' \
'Trainer(distributed_backend=dp) (or ddp)'
raise MisconfigurationException(m)

print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))

def __configure_slurm_ddp(self):
def __configure_slurm_ddp(self, gpu_ids, nb_gpu_nodes):
self.is_slurm_managing_tasks = False

nb_gpus = len(gpu_ids) if type(gpu_ids) is list else gpu_ids

# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
# otherwise we launch the required number of processes
if self.use_ddp:
self.nb_requested_gpus = len(self.data_parallel_device_ids) * self.nb_gpu_nodes
self.nb_requested_gpus = nb_gpus * nb_gpu_nodes
self.nb_slurm_tasks = 0
try:
self.nb_slurm_tasks = int(os.environ['SLURM_NTASKS'])
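
To summarize the parsing rules introduced above, a standalone sketch (the helper name is hypothetical, not part of this commit) of the accepted `gpus` values and what they resolve to:

```python
# sketch of the accepted `gpus` values after this commit (helper name is hypothetical)
def parse_gpus_example(gpus):
    if gpus is None:
        return None                       # CPU
    if type(gpus) is list:
        return gpus                       # explicit ids, e.g. [0, 1]
    if type(gpus) is str:
        if gpus == '-1':
            return 'all visible devices'  # expands to range(torch.cuda.device_count())
        return [int(x.strip()) for x in gpus.split(',')]
    if type(gpus) is int:
        return gpus                       # a count, e.g. 2 -> first two devices
    raise Exception('gpus has to be a string, int or list of ints')

print(parse_gpus_example(2))        # 2
print(parse_gpus_example([0, 3]))   # [0, 3]
print(parse_gpus_example('0, 1'))   # [0, 1]
print(parse_gpus_example('-1'))     # all visible devices
```
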
@@ -305,6 +333,24 @@ class Trainer(TrainerIO):
# likely not on slurm, so set the slurm managed flag to false
self.is_slurm_managing_tasks = False

def __set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
if data_parallel_device_ids is None:
return

# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# when slurm is managing the task it sets the visible devices
if not is_slurm_managing_tasks:
if type(data_parallel_device_ids) is int:
id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids)))
os.environ["CUDA_VISIBLE_DEVICES"] = id_str
else:
gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str

print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')

@property
def data_parallel(self):
return self.use_dp or self.use_ddp
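
For illustration, a standalone sketch (hypothetical function name, not part of this commit) of what `CUDA_VISIBLE_DEVICES` ends up as under the new rules; when SLURM manages the tasks the variable is left untouched:

```python
import os

# sketch: mirror of the new CUDA_VISIBLE_DEVICES behaviour (name is hypothetical)
def visible_devices_example(gpus, slurm_managing_tasks=False):
    if gpus is None or slurm_managing_tasks:
        return os.environ.get('CUDA_VISIBLE_DEVICES')  # left to SLURM / unchanged
    if type(gpus) is int:
        return ','.join(str(x) for x in range(gpus))   # gpus=2 -> '0,1'
    return ','.join(str(x) for x in gpus)              # gpus=[0, 3] -> '0,3'

print(visible_devices_example(2))                             # '0,1'
print(visible_devices_example([0, 3]))                        # '0,3'
print(visible_devices_example(8, slurm_managing_tasks=True))  # whatever SLURM set
```
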
@@ -412,8 +458,10 @@ class Trainer(TrainerIO):
# CPU, single GPU
if self.single_gpu:
# for single GPU put inputs on gpu manually
gpu_id = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id)
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, root_gpu)
args[0] = data_batch

if test:
@@ -598,7 +646,7 @@ class Trainer(TrainerIO):
If you're not using SLURM, ignore this message!
"""
warnings.warn(msg)
mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, ))
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, ))

# 1 gpu or dp option triggers training using DP module
# easier to avoid NCCL issues
@@ -645,7 +693,10 @@ class Trainer(TrainerIO):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

model.cuda(self.data_parallel_device_ids[0])
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)

if self.use_amp:
# An example
@@ -662,7 +713,10 @@ class Trainer(TrainerIO):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

model.cuda(self.data_parallel_device_ids[0])
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)

# check for this bug (amp + dp + !01 doesn't work)
# https://github.com/NVIDIA/apex/issues/227
@@ -704,8 +758,8 @@ class Trainer(TrainerIO):
self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0

# determine which process we are and world size
self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
self.proc_rank = self.node_rank * self.num_gpus + gpu_nb
self.world_size = self.nb_gpu_nodes * self.num_gpus

# let the exp know the rank to avoid overwriting logs
if self.experiment is not None:
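
A quick worked example of the rank arithmetic above (illustrative numbers only): with 12 nodes and 8 GPUs per node, the process driving local GPU 3 on node 1 gets global rank 11 and the world size is 96:

```python
# illustrative numbers, mirroring the formulas above
nb_gpu_nodes = 12
num_gpus = 8       # GPUs per node
node_rank = 1
gpu_nb = 3         # local GPU index on this node

proc_rank = node_rank * num_gpus + gpu_nb   # 1 * 8 + 3 = 11
world_size = nb_gpu_nodes * num_gpus        # 12 * 8 = 96
print(proc_rank, world_size)                # 11 96
```
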
@@ -1045,7 +1099,9 @@ class Trainer(TrainerIO):
elif self.use_dp:
output = self.model(*args)
elif self.single_gpu:
gpu_id = self.data_parallel_device_ids[0]
gpu_id = 0
if type(self.data_parallel_device_ids) is list:
gpu_id = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id)
args[0] = data_batch
output = self.model.training_step(*args)
@@ -1061,7 +1117,7 @@ class Trainer(TrainerIO):

# reduce prog metrics for tqdm when using dp
if self.use_dp:
nb_gpus = len(self.data_parallel_device_ids)
nb_gpus = self.num_gpus
prog_output = reduce_distributed_output(prog_output, nb_gpus)

model_specific_tqdm_metrics_dic = prog_output
@@ -1081,7 +1137,7 @@ class Trainer(TrainerIO):

# when using dp need to reduce the loss
if self.use_dp:
loss = reduce_distributed_output(loss, len(self.data_parallel_device_ids))
loss = reduce_distributed_output(loss, self.num_gpus)

return loss, model_specific_tqdm_metrics_dic
@@ -585,7 +585,7 @@ def test_amp_single_gpu():
trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=1,
gpus=[0],
gpus=1,
distributed_backend='dp',
use_amp=True
)
@@ -675,7 +675,7 @@ def test_amp_gpu_ddp():
trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=1,
gpus=[0, 1],
gpus=2,
distributed_backend='ddp',
use_amp=True
)
@@ -1019,12 +1019,34 @@ def test_single_gpu_model():
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
gpus=[0]
gpus=1
)

run_gpu_model_test(trainer_options, model, hparams)


def test_multi_gpu_none_backend():
"""
Make sure when using multiple GPUs the user can't use
distributed_backend = None
:return:
"""
if not can_run_gpu_test():
return

model, hparams = get_model()
trainer_options = dict(
show_progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
gpus='-1'
)

with pytest.raises(MisconfigurationException):
run_gpu_model_test(trainer_options, model, hparams)


def test_multi_gpu_model_dp():
"""
Make sure DP works
@@ -1036,6 +1058,7 @@ def test_multi_gpu_model_dp():
model, hparams = get_model()
trainer_options = dict(
show_progress_bar=False,
distributed_backend='dp',
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,