diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index dc0fef9079..92725874df 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -1007,13 +1007,14 @@ class TrainerTrainLoopMixin(ABC):
             # FORWARD (TRAINING STEP + TRAIN STEP END)
             # ---------------------------
             with self.profiler.profile('model_forward'):
+                args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
                 if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                     with torch.cuda.amp.autocast():
-                        training_step_output = self.training_forward(split_batch, batch_idx,
-                                                                     opt_idx, hiddens)
+                        training_step_output = self.accelerator_backend.training_step(args)
+                        training_step_output = self.call_hook('training_step_end', training_step_output)
                 else:
-                    training_step_output = self.training_forward(split_batch, batch_idx, opt_idx,
-                                                                 hiddens)
+                    training_step_output = self.accelerator_backend.training_step(args)
+                    training_step_output = self.call_hook('training_step_end', training_step_output)
 
                 # ----------------------------
                 # PROCESS THE RESULT
@@ -1186,26 +1187,6 @@ class TrainerTrainLoopMixin(ABC):
 
         return args
 
-    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
-        """
-        Handle forward for each training case (distributed, single gpu, etc...)
-        :param batch:
-        :param batch_idx:
-        :return:
-        """
-        # ---------------
-        # FORWARD
-        # ---------------
-        args = self.build_train_args(batch, batch_idx, opt_idx, hiddens)
-
-        # distributed forward
-        output = self.accelerator_backend.training_step(args)
-
-        # Training step end
-        output = self.call_hook('training_step_end', output)
-
-        return output
-
     def update_learning_rates(self, interval: str, monitor_metrics=None):
         """Update learning rates.