training forward refactor (#3134)
This commit is contained in:
parent
0b3cb3c955
commit
4db1a2a323
|
@ -1007,13 +1007,14 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
# FORWARD (TRAINING STEP + TRAIN STEP END)
|
# FORWARD (TRAINING STEP + TRAIN STEP END)
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
with self.profiler.profile('model_forward'):
|
with self.profiler.profile('model_forward'):
|
||||||
|
args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
|
||||||
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
|
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
|
||||||
with torch.cuda.amp.autocast():
|
with torch.cuda.amp.autocast():
|
||||||
training_step_output = self.training_forward(split_batch, batch_idx,
|
training_step_output = self.accelerator_backend.training_step(args)
|
||||||
opt_idx, hiddens)
|
training_step_output = self.call_hook('training_step_end', training_step_output)
|
||||||
else:
|
else:
|
||||||
training_step_output = self.training_forward(split_batch, batch_idx, opt_idx,
|
training_step_output = self.accelerator_backend.training_step(args)
|
||||||
hiddens)
|
training_step_output = self.call_hook('training_step_end', training_step_output)
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# PROCESS THE RESULT
|
# PROCESS THE RESULT
|
||||||
|
@ -1186,26 +1187,6 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def training_forward(self, batch, batch_idx, opt_idx, hiddens):
|
|
||||||
"""
|
|
||||||
Handle forward for each training case (distributed, single gpu, etc...)
|
|
||||||
:param batch:
|
|
||||||
:param batch_idx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# ---------------
|
|
||||||
# FORWARD
|
|
||||||
# ---------------
|
|
||||||
args = self.build_train_args(batch, batch_idx, opt_idx, hiddens)
|
|
||||||
|
|
||||||
# distributed forward
|
|
||||||
output = self.accelerator_backend.training_step(args)
|
|
||||||
|
|
||||||
# Training step end
|
|
||||||
output = self.call_hook('training_step_end', output)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def update_learning_rates(self, interval: str, monitor_metrics=None):
|
def update_learning_rates(self, interval: str, monitor_metrics=None):
|
||||||
"""Update learning rates.
|
"""Update learning rates.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue