training forward refactor (#3134)

Author: William Falcon, 2020-08-24 19:31:31 -04:00 (committed by GitHub)
Parent: 0b3cb3c955
Commit: 4db1a2a323
1 changed file with 5 additions and 24 deletions


@@ -1007,13 +1007,14 @@ class TrainerTrainLoopMixin(ABC):
             # ---------------------------
             # FORWARD (TRAINING STEP + TRAIN STEP END)
             # ---------------------------
             with self.profiler.profile('model_forward'):
+                args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
                 if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                     with torch.cuda.amp.autocast():
-                        training_step_output = self.training_forward(split_batch, batch_idx,
-                                                                     opt_idx, hiddens)
+                        training_step_output = self.accelerator_backend.training_step(args)
+                        training_step_output = self.call_hook('training_step_end', training_step_output)
                 else:
-                    training_step_output = self.training_forward(split_batch, batch_idx, opt_idx,
-                                                                 hiddens)
+                    training_step_output = self.accelerator_backend.training_step(args)
+                    training_step_output = self.call_hook('training_step_end', training_step_output)
             # ----------------------------
             # PROCESS THE RESULT
@@ -1186,26 +1187,6 @@ class TrainerTrainLoopMixin(ABC):
         return args
 
-    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
-        """
-        Handle forward for each training case (distributed, single gpu, etc...)
-        :param batch:
-        :param batch_idx:
-        :return:
-        """
-        # ---------------
-        # FORWARD
-        # ---------------
-        args = self.build_train_args(batch, batch_idx, opt_idx, hiddens)
-
-        # distributed forward
-        output = self.accelerator_backend.training_step(args)
-
-        # Training step end
-        output = self.call_hook('training_step_end', output)
-
-        return output
-
     def update_learning_rates(self, interval: str, monitor_metrics=None):
         """Update learning rates.