training forward refactor (#3134)
parent 0b3cb3c955
commit 4db1a2a323
@@ -1007,13 +1007,14 @@ class TrainerTrainLoopMixin(ABC):
                 # FORWARD (TRAINING STEP + TRAIN STEP END)
                 # ---------------------------
                 with self.profiler.profile('model_forward'):
+                    args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
                     if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                         with torch.cuda.amp.autocast():
-                            training_step_output = self.training_forward(split_batch, batch_idx,
-                                                                         opt_idx, hiddens)
+                            training_step_output = self.accelerator_backend.training_step(args)
+                            training_step_output = self.call_hook('training_step_end', training_step_output)
                     else:
-                        training_step_output = self.training_forward(split_batch, batch_idx, opt_idx,
-                                                                     hiddens)
+                        training_step_output = self.accelerator_backend.training_step(args)
+                        training_step_output = self.call_hook('training_step_end', training_step_output)

                 # ----------------------------
                 # PROCESS THE RESULT
@@ -1186,26 +1187,6 @@ class TrainerTrainLoopMixin(ABC):

         return args

-    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
-        """
-        Handle forward for each training case (distributed, single gpu, etc...)
-        :param batch:
-        :param batch_idx:
-        :return:
-        """
-        # ---------------
-        # FORWARD
-        # ---------------
-        args = self.build_train_args(batch, batch_idx, opt_idx, hiddens)
-
-        # distributed forward
-        output = self.accelerator_backend.training_step(args)
-
-        # Training step end
-        output = self.call_hook('training_step_end', output)
-
-        return output
-
     def update_learning_rates(self, interval: str, monitor_metrics=None):
         """Update learning rates.

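For orientation, here is a minimal sketch of the call pattern this refactor leaves behind: build the step args once, pick an autocast context for native AMP, then dispatch through the accelerator backend and the 'training_step_end' hook. The ToyAcceleratorBackend and ToyTrainLoop classes, forward_step, and use_native_amp are hypothetical stand-ins for illustration; only build_train_args, accelerator_backend.training_step, and call_hook('training_step_end', ...) come from the diff itself.

from contextlib import nullcontext

import torch


class ToyAcceleratorBackend:
    # Hypothetical stand-in for the accelerator backend: unpacks the
    # packed args and runs the model's training_step.
    def __init__(self, model):
        self.model = model

    def training_step(self, args):
        batch, batch_idx = args
        return self.model.training_step(batch, batch_idx)


class ToyTrainLoop:
    # Hypothetical stand-in showing the pattern introduced by the diff.
    def __init__(self, model, use_native_amp=False):
        self.accelerator_backend = ToyAcceleratorBackend(model)
        self.model = model
        self.use_native_amp = use_native_amp

    def build_train_args(self, batch, batch_idx):
        # Pack the arguments once, shared by both AMP and non-AMP paths.
        return [batch, batch_idx]

    def call_hook(self, name, output):
        # Dispatch to the model hook if it exists, else pass through.
        hook = getattr(self.model, name, None)
        return hook(output) if callable(hook) else output

    def forward_step(self, batch, batch_idx):
        args = self.build_train_args(batch, batch_idx)
        ctx = torch.cuda.amp.autocast() if self.use_native_amp else nullcontext()
        with ctx:
            output = self.accelerator_backend.training_step(args)
            output = self.call_hook('training_step_end', output)
        return output

Building args before the AMP branch means the autocast and non-autocast paths run identical dispatch code, which is what lets the old training_forward helper (second hunk) be deleted outright.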