ref: reduced all simplified_forward (#3126)
* simplified training_forward
parent 6068b29d29
commit 20018b2668
```diff
@@ -1162,16 +1162,7 @@ class TrainerTrainLoopMixin(ABC):
         model.cpu()
         torch.cuda.empty_cache()
 
-    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
-        """
-        Handle forward for each training case (distributed, single gpu, etc...)
-        :param batch:
-        :param batch_idx:
-        :return:
-        """
-        # ---------------
-        # FORWARD
-        # ---------------
+    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
         args = [batch, batch_idx]
```
```diff
@@ -1189,24 +1180,21 @@ class TrainerTrainLoopMixin(ABC):
         if self.truncated_bptt_steps is not None:
             args.append(hiddens)
 
+        return args
+
+    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
+        """
+        Handle forward for each training case (distributed, single gpu, etc...)
+        :param batch:
+        :param batch_idx:
+        :return:
+        """
+        # ---------------
+        # FORWARD
+        # ---------------
+        args = self.build_train_args(batch, batch_idx, opt_idx, hiddens)
 
-        # distributed forward
-        if self.use_ddp or self.use_ddp2 or self.use_dp:
-            output = self.accelerator_backend.training_step(args)
-
-        # horovod
-        elif self.use_horovod:
-            output = self.accelerator_backend.training_step(args)
-
-        # single GPU forward
-        elif self.use_single_gpu:
-            output = self.accelerator_backend.training_step(args)
-
-        # TPU support
-        elif self.use_tpu:
-            output = self.accelerator_backend.training_step(args)
-
-        # CPU forward
-        else:
-            output = self.accelerator_backend.training_step(args)
 
         is_result_obj = isinstance(output, Result)
```
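Every branch in the hunk above makes the identical `self.accelerator_backend.training_step(args)` call, which is the point of the refactor: once the accelerator backends share one entry point, `training_forward` no longer needs to know which accelerator (ddp/ddp2/dp, horovod, single GPU, TPU, or CPU) is active. A minimal sketch of the reduced shape under that reading; `Backend` and `Result` here are illustrative stand-ins, not Lightning's real classes:

```python
# Sketch of the reduced training_forward; Backend and Result are minimal
# stand-ins for Lightning's accelerator backend and Result object.
class Result(dict):
    pass


class Backend:
    def training_step(self, args):
        batch, batch_idx = args[0], args[1]
        return Result(loss=0.0)  # placeholder "forward"


class LoopSketch:
    def __init__(self):
        self.accelerator_backend = Backend()
        self.truncated_bptt_steps = None

    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
        args = [batch, batch_idx]
        if self.truncated_bptt_steps is not None:
            args.append(hiddens)
        return args

    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
        # all accelerator cases resolve to the same backend call,
        # so no per-accelerator branching is needed
        args = self.build_train_args(batch, batch_idx, opt_idx, hiddens)
        output = self.accelerator_backend.training_step(args)
        is_result_obj = isinstance(output, Result)
        return output, is_result_obj


loop = LoopSketch()
print(loop.training_forward(batch=[1, 2], batch_idx=0, opt_idx=0, hiddens=None))
```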