added cpu + amp error
This commit is contained in:
parent
fcda19aa25
commit
a5756d91be
|
@ -462,13 +462,13 @@ def run_prediction(dataloader, trained_model):
|
|||
|
||||
print(val_acc)
|
||||
|
||||
assert val_acc > 0.55, f'this model is expected to get > 0.55 in test set (it got {val_acc})'
|
||||
assert val_acc > 0.50, f'this model is expected to get > 0.50 in test set (it got {val_acc})'
|
||||
|
||||
|
||||
def assert_ok_acc(trainer):
|
||||
# this model should get 0.80+ acc
|
||||
acc = trainer.tng_tqdm_dic['val_acc']
|
||||
assert acc > 0.55, f'model failed to get expected 0.55 validation accuracy. Got: {acc}'
|
||||
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue