removed self.model refs

This commit is contained in:
William Falcon 2019-06-26 18:17:40 -04:00
parent bf0f5a5cbb
commit 5c8875130b
1 changed files with 3 additions and 4 deletions

View File

@ -407,12 +407,11 @@ class Trainer(TrainerIO):
# forward pass
# return a scalar value and a dic with tqdm metrics
output = self.model(data_batch, batch_nb)
# when DP, we need to aggregate the scalars we received as outputs
# use mean as the reduce function
if self.data_parallel:
output = self.model(data_batch, batch_nb)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
else:
output = self.model.training_step(data_batch, batch_nb)
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
loss = output['loss']