simplify trainer output
This commit is contained in:
parent
cc12a1c8fa
commit
0929908229
|
@ -611,10 +611,16 @@ class Trainer(TrainerIO):
|
|||
|
||||
try:
|
||||
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
|
||||
except TypeError as e:
|
||||
except Exception as e:
|
||||
model_specific_tqdm_metrics_dic = {}
|
||||
|
||||
loss = output['loss']
|
||||
# if output dict doesn't have the keyword loss
|
||||
# then assume the output=loss if scalar
|
||||
try:
|
||||
loss = output['loss']
|
||||
except Exception as e:
|
||||
if type(loss) is torch.Tensor:
|
||||
loss = output
|
||||
|
||||
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
|
||||
|
||||
|
|
Loading…
Reference in New Issue