diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ede0e01396..7399539213 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -1162,16 +1162,7 @@ class TrainerTrainLoopMixin(ABC): model.cpu() torch.cuda.empty_cache() - def training_forward(self, batch, batch_idx, opt_idx, hiddens): - """ - Handle forward for each training case (distributed, single gpu, etc...) - :param batch: - :param batch_idx: - :return: - """ - # --------------- - # FORWARD - # --------------- + def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step args = [batch, batch_idx] @@ -1189,25 +1180,22 @@ class TrainerTrainLoopMixin(ABC): if self.truncated_bptt_steps is not None: args.append(hiddens) + return args + + def training_forward(self, batch, batch_idx, opt_idx, hiddens): + """ + Handle forward for each training case (distributed, single gpu, etc...) + :param batch: + :param batch_idx: + :return: + """ + # --------------- + # FORWARD + # --------------- + args = self.build_train_args(batch, batch_idx, opt_idx, hiddens) + # distributed forward - if self.use_ddp or self.use_ddp2 or self.use_dp: - output = self.accelerator_backend.training_step(args) - - # horovod - elif self.use_horovod: - output = self.accelerator_backend.training_step(args) - - # single GPU forward - elif self.use_single_gpu: - output = self.accelerator_backend.training_step(args) - - # TPU support - elif self.use_tpu: - output = self.accelerator_backend.training_step(args) - - # CPU forward - else: - output = self.accelerator_backend.training_step(args) + output = self.accelerator_backend.training_step(args) is_result_obj = isinstance(output, Result)