added single gpu train

parent 73b50abb57
commit afa4548b12

@@ -105,7 +105,7 @@ class Trainer(TrainerIO):
         :param log_save_interval:
         :param add_log_row_interval:
         :param distributed_backend:
-        'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
+        'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
         :param use_amp:
         :param print_nan_grads:
         :param print_weights_summary:

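For orientation, here is a hedged usage sketch of the documented backend values. It is not part of the commit: only `distributed_backend` appears in this docstring hunk, and the `gpus` keyword is an assumed argument name used purely for illustration.

    # Hypothetical usage sketch -- 'gpus' is an assumed argument name,
    # only 'distributed_backend' is documented in the hunk above.
    trainer = Trainer(distributed_backend='dp', gpus=[0, 1])  # DataParallel across two devices
    trainer = Trainer(distributed_backend='n', gpus=[0])      # no distributed backend, one device
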
@@ -147,6 +147,7 @@ class Trainer(TrainerIO):
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
+        self.single_gpu = False

         # training bookeeping
         self.total_batch_nb = 0

@@ -194,6 +195,12 @@ class Trainer(TrainerIO):
                 'To silence this warning set distributed_backend=ddp'
             warnings.warn(w)

+        # remove dp and ddp when requesting single gpu
+        if len(self.data_parallel_device_ids) == 1:
+            self.use_ddp = False
+            self.use_dp = False
+            self.single_gpu = True
+
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes

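A minimal standalone sketch of what the override above amounts to, with a hypothetical `resolve_backend` helper standing in for the constructor logic; it is illustration only, not code from the commit.

    # Hypothetical helper mirroring the single-GPU override above.
    def resolve_backend(device_ids, use_ddp, use_dp):
        single_gpu = False
        # with exactly one device there is nothing to parallelize across,
        # so both data-parallel modes are switched off
        if device_ids is not None and len(device_ids) == 1:
            use_ddp, use_dp, single_gpu = False, False, True
        return use_ddp, use_dp, single_gpu

    # resolve_backend([0], use_ddp=False, use_dp=True)    -> (False, False, True)
    # resolve_backend([0, 1], use_ddp=False, use_dp=True) -> (False, True, False)
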
@@ -463,6 +470,9 @@ If you're not using SLURM, ignore this message!
         elif self.use_dp:
             self.__dp_train(model)

+        elif self.single_gpu:
+            self.__single_gpu_train(model)
+
         # ON CPU
         else:
             # run through amp wrapper

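The dispatch order implied by this hunk is ddp, then dp, then the new single-GPU path, then CPU. A condensed, hypothetical view of that routing (the plain names below are placeholders for the name-mangled methods in the hunk):

    # Hypothetical condensed routing; the real fit() calls the private __*_train methods.
    def pick_route(trainer):
        if trainer.use_ddp:
            return 'ddp'
        elif trainer.use_dp:
            return 'dp'
        elif trainer.single_gpu:
            return 'single_gpu'
        return 'cpu'  # CPU fallback (optionally wrapped in amp in the real code)
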
@@ -482,6 +492,24 @@ If you're not using SLURM, ignore this message!
             # used for testing or when we need to know that training succeeded
             return 1

+    def __single_gpu_train(self, model):
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
+
+        model.cuda(self.data_parallel_device_ids[0])
+
+        if self.use_amp:
+            # An example
+            model, optimizers = amp.initialize(
+                model, self.optimizers, opt_level=self.amp_level,
+            )
+            self.optimizers = optimizers
+
+        self.__run_pretrain_routine(model)
+
     def __dp_train(self, model):

         # CHOOSE OPTIMIZER

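The `use_amp` branch in `__single_gpu_train` follows NVIDIA apex's initialize-then-scale pattern. Below is a self-contained sketch of that pattern outside the Trainer, assuming apex is installed and a CUDA device is available; the linear model and the 'O1' opt level are placeholders.

    import torch
    from torch import nn
    from apex import amp  # NVIDIA apex, assumed installed

    model = nn.Linear(10, 1).cuda(0)  # move to the single target device first, as above
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # amp.initialize patches model and optimizer for mixed precision;
    # it also accepts a list of optimizers, which is what the Trainer passes.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    loss = model(torch.randn(4, 10).cuda(0)).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()  # backward runs on the scaled loss
    optimizer.step()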