ref: reduced all simplified_forward (#3126)

* simplified training_forward

* simplified training_forward

* simplified training_forward
William Falcon 2020-08-24 13:05:58 -04:00 committed by GitHub
parent 6068b29d29
commit 20018b2668
1 changed file with 16 additions and 28 deletions


@@ -1162,16 +1162,7 @@ class TrainerTrainLoopMixin(ABC):
             model.cpu()
             torch.cuda.empty_cache()
 
-    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
-        """
-        Handle forward for each training case (distributed, single gpu, etc...)
-        :param batch:
-        :param batch_idx:
-        :return:
-        """
-        # ---------------
-        # FORWARD
-        # ---------------
+    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
         args = [batch, batch_idx]
@@ -1189,25 +1180,22 @@ class TrainerTrainLoopMixin(ABC):
         if self.truncated_bptt_steps is not None:
             args.append(hiddens)
 
+        return args
+
+    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
+        """
+        Handle forward for each training case (distributed, single gpu, etc...)
+        :param batch:
+        :param batch_idx:
+        :return:
+        """
+        # ---------------
+        # FORWARD
+        # ---------------
+        args = self.build_train_args(batch, batch_idx, opt_idx, hiddens)
+
-        # distributed forward
-        if self.use_ddp or self.use_ddp2 or self.use_dp:
-            output = self.accelerator_backend.training_step(args)
-        # horovod
-        elif self.use_horovod:
-            output = self.accelerator_backend.training_step(args)
-        # single GPU forward
-        elif self.use_single_gpu:
-            output = self.accelerator_backend.training_step(args)
-        # TPU support
-        elif self.use_tpu:
-            output = self.accelerator_backend.training_step(args)
-        # CPU forward
-        else:
-            output = self.accelerator_backend.training_step(args)
+        output = self.accelerator_backend.training_step(args)
 
         is_result_obj = isinstance(output, Result)
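The simplification works because every accelerator backend already exposes the same training_step(args) entry point, so the trainer no longer needs to branch on use_ddp, use_horovod, use_single_gpu, use_tpu, or the CPU fallback. Below is a minimal, self-contained sketch of that pattern; it is not Lightning's actual class hierarchy, and names such as AcceleratorBackend, CPUBackend, and SingleGPUBackend are hypothetical placeholders used only to illustrate why the if/elif chain could collapse to a single call.

from abc import ABC, abstractmethod


class AcceleratorBackend(ABC):
    """Illustrative base class: each backend handles its own device specifics."""

    def __init__(self, model):
        self.model = model

    @abstractmethod
    def training_step(self, args):
        """Run one training step with the already-built args list."""


class CPUBackend(AcceleratorBackend):
    def training_step(self, args):
        # hypothetical: delegate straight to the user's training_step
        return self.model.training_step(*args)


class SingleGPUBackend(AcceleratorBackend):
    def training_step(self, args):
        # hypothetical: assume the batch was already moved to the GPU elsewhere
        return self.model.training_step(*args)


def training_forward(backend, batch, batch_idx, opt_idx, hiddens):
    # mirrors the simplified trainer code: build args once, then make one
    # polymorphic call instead of branching per accelerator
    args = [batch, batch_idx]
    if hiddens is not None:
        args.append(hiddens)
    return backend.training_step(args)

Because the dispatch lives in the backend objects rather than in training_forward, adding a new accelerator means implementing training_step on a new backend class instead of growing the conditional in the training loop.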