diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 2dbf46eb14..b0ce3d1a19 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -105,7 +105,7 @@ class Trainer(TrainerIO):
         :param log_save_interval:
         :param add_log_row_interval:
         :param distributed_backend:
-            'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
+            'ddp' to use DistributedDataParallel, 'dp' to use DataParallel, None to use none
         :param use_amp:
         :param print_nan_grads:
         :param print_weights_summary:
@@ -147,6 +147,7 @@ class Trainer(TrainerIO):
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
+        self.single_gpu = False
 
         # training bookeeping
         self.total_batch_nb = 0
@@ -194,6 +195,12 @@ class Trainer(TrainerIO):
                 'To silence this warning set distributed_backend=ddp'
             warnings.warn(w)
 
+        # remove dp and ddp when requesting single gpu
+        if len(self.data_parallel_device_ids) == 1:
+            self.use_ddp = False
+            self.use_dp = False
+            self.single_gpu = True
+
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
@@ -463,6 +470,9 @@ If you're not using SLURM, ignore this message!
         elif self.use_dp:
             self.__dp_train(model)
 
+        elif self.single_gpu:
+            self.__single_gpu_train(model)
+
         # ON CPU
         else:
             # run through amp wrapper
@@ -482,6 +492,24 @@ If you're not using SLURM, ignore this message!
         # used for testing or when we need to know that training succeeded
         return 1
 
+    def __single_gpu_train(self, model):
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
+
+        model.cuda(self.data_parallel_device_ids[0])
+
+        if self.use_amp:
+            # wrap model and optimizers for mixed precision (apex amp)
+            model, optimizers = amp.initialize(
+                model, self.optimizers, opt_level=self.amp_level,
+            )
+            self.optimizers = optimizers
+
+        self.__run_pretrain_routine(model)
+
     def __dp_train(self, model):
 
         # CHOOSE OPTIMIZER
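
For context, the sketch below shows, outside the Trainer, the single-GPU path that `__single_gpu_train` follows: get the optimizers from the model hook (optionally with LR schedulers), move the model onto the one requested device, and wrap model and optimizers with apex amp when mixed precision is enabled. This is not Lightning's internal API; the function name and the `device_id`/`amp_level` parameters are illustrative assumptions, and only `configure_optimizers` and `amp.initialize` mirror calls that appear in the diff.

```python
# Standalone sketch of the single-GPU setup added by __single_gpu_train above.
# Names here (single_gpu_setup, device_id, amp_level) are illustrative, not
# part of Lightning's API; apex is only needed when use_amp=True.


def single_gpu_setup(model, device_id=0, use_amp=False, amp_level='O1'):
    # CHOOSE OPTIMIZER / allow for lr schedulers as well, mirroring the
    # diff's convention that the hook may return (optimizers, lr_schedulers)
    optimizers = model.configure_optimizers()
    lr_schedulers = []
    if isinstance(optimizers, tuple) and len(optimizers) == 2:
        optimizers, lr_schedulers = optimizers

    # place the model on the single requested GPU
    model.cuda(device_id)

    if use_amp:
        from apex import amp  # NVIDIA apex, assumed installed for mixed precision
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)

    return model, optimizers, lr_schedulers
```

With the diff applied, requesting exactly one device id clears `use_ddp`/`use_dp`, sets `single_gpu`, and routes training through the equivalent of this setup instead of the DP/DDP branches.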