added single gpu train
parent 73b50abb57
commit afa4548b12
@@ -105,7 +105,7 @@ class Trainer(TrainerIO):
         :param log_save_interval:
         :param add_log_row_interval:
         :param distributed_backend:
-            'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
+            'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
         :param use_amp:
         :param print_nan_grads:
         :param print_weights_summary:
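For context, the docstring above documents the user-facing `distributed_backend` flag. A minimal usage sketch follows; the `gpus` keyword and the import path are assumptions for illustration and are not confirmed by this diff, which only shows the docstring and the `data_parallel_device_ids` attribute.

# A minimal sketch, not taken from this commit. 'dp' follows the docstring
# above; 'ddp' also appears in the warning text later in this file.
from pytorch_lightning import Trainer  # assumed import path

trainer = Trainer(
    distributed_backend='dp',  # parallelism backend documented above
    gpus=[0, 1],               # assumed name for the device-id list
)

With a single device id (e.g. `gpus=[0]`), the request should fall into the new single-GPU path added by the hunks below.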
@@ -147,6 +147,7 @@ class Trainer(TrainerIO):
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
+        self.single_gpu = False

         # training bookeeping
         self.total_batch_nb = 0
@@ -194,6 +195,12 @@ class Trainer(TrainerIO):
                 'To silence this warning set distributed_backend=ddp'
             warnings.warn(w)

+        # remove dp and ddp when requesting single gpu
+        if len(self.data_parallel_device_ids) == 1:
+            self.use_ddp = False
+            self.use_dp = False
+            self.single_gpu = True
+
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
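The block added above drops the dp/ddp flags whenever exactly one device id is requested. A self-contained sketch of the same rule, using plain arguments rather than Trainer state (`select_mode` is a hypothetical helper, not part of the commit):

def select_mode(use_ddp, use_dp, device_ids):
    # Mirror of the rule above: one requested device means no parallel wrapper.
    single_gpu = False
    if len(device_ids) == 1:
        use_ddp = False
        use_dp = False
        single_gpu = True
    return use_ddp, use_dp, single_gpu

# A dp request on a single device falls back to the plain single-GPU path.
assert select_mode(use_ddp=False, use_dp=True, device_ids=[0]) == (False, False, True)
assert select_mode(use_ddp=False, use_dp=True, device_ids=[0, 1]) == (False, True, False)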
@@ -463,6 +470,9 @@ If you're not using SLURM, ignore this message!
         elif self.use_dp:
             self.__dp_train(model)

+        elif self.single_gpu:
+            self.__single_gpu_train(model)
+
         # ON CPU
         else:
             # run through amp wrapper
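The new branch slots in after dp and before the CPU fallback, so the flags are checked in a fixed order. A small sketch of that precedence (a hypothetical standalone function, not Trainer code):

def pick_path(use_ddp, use_dp, single_gpu):
    # First matching flag wins: ddp, then dp, then single GPU, then CPU.
    if use_ddp:
        return 'ddp'
    elif use_dp:
        return 'dp'
    elif single_gpu:
        return 'single_gpu'
    return 'cpu'

assert pick_path(False, False, True) == 'single_gpu'
assert pick_path(False, False, False) == 'cpu'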
@@ -482,6 +492,24 @@ If you're not using SLURM, ignore this message!
         # used for testing or when we need to know that training succeeded
         return 1

+    def __single_gpu_train(self, model):
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
+
+        model.cuda(self.data_parallel_device_ids[0])
+
+        if self.use_amp:
+            # An example
+            model, optimizers = amp.initialize(
+                model, self.optimizers, opt_level=self.amp_level,
+            )
+            self.optimizers = optimizers
+
+        self.__run_pretrain_routine(model)
+
     def __dp_train(self, model):

         # CHOOSE OPTIMIZER
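Outside the Trainer, the same single-GPU setup looks roughly like the sketch below: configure optimizers, move the model to one device, optionally wrap with NVIDIA apex amp, then start the training routine. `configure_optimizers`, `model.cuda(device_id)` and `amp.initialize` come from the method above; the helper name and the tuple handling are simplified for illustration, and the amp branch assumes apex is installed.

def single_gpu_setup(model, device_id=0, use_amp=False, amp_level='O1'):
    # Sketch of the flow in __single_gpu_train above, outside the Trainer class.
    optimizers = model.configure_optimizers()  # LightningModule hook used above
    lr_schedulers = None
    if isinstance(optimizers, (list, tuple)) and len(optimizers) == 2:
        # optional (optimizers, lr_schedulers) pair, as handled above
        optimizers, lr_schedulers = optimizers

    model.cuda(device_id)  # one device, no DataParallel/DistributedDataParallel wrapper

    if use_amp:
        from apex import amp  # optional dependency, same as in the diff
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)

    return model, optimizers, lr_schedulers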