updated args
This commit is contained in:
parent
51305697c1
commit
f49c2f4c25
|
@ -110,13 +110,13 @@ class ExampleModel(RootModule):
|
|||
:return:
|
||||
"""
|
||||
val_loss_mean = 0
|
||||
accs = []
|
||||
val_acc_mean = 0
|
||||
for output in outputs:
|
||||
val_loss_mean += output['val_loss']
|
||||
accs.append(output['val_acc'])
|
||||
val_acc_mean += output['val_acc']
|
||||
|
||||
val_loss_mean /= len(outputs)
|
||||
tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': torch.mean(accs).item()}
|
||||
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
|
||||
return tqdm_dic
|
||||
|
||||
def update_tng_log_metrics(self, logs):
|
||||
|
|
Loading…
Reference in New Issue