simplify trainer output

This commit is contained in:
William Falcon 2019-07-11 15:08:45 -04:00
parent cc12a1c8fa
commit 0929908229
1 changed files with 8 additions and 2 deletions

View File

@ -611,10 +611,16 @@ class Trainer(TrainerIO):
try:
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
except TypeError as e:
except Exception as e:
model_specific_tqdm_metrics_dic = {}
loss = output['loss']
# if output dict doesn't have the keyword loss
# then assume the output=loss if scalar
try:
loss = output['loss']
except Exception as e:
if type(loss) is torch.Tensor:
loss = output
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)