diff --git a/tests/test_models.py b/tests/test_models.py index 0e93338250..152fdfef18 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -462,13 +462,13 @@ def run_prediction(dataloader, trained_model): print(val_acc) - assert val_acc > 0.55, f'this model is expected to get > 0.55 in test set (it got {val_acc})' + assert val_acc > 0.50, f'this model is expected to get > 0.50 in test set (it got {val_acc})' def assert_ok_acc(trainer): # this model should get 0.80+ acc acc = trainer.tng_tqdm_dic['val_acc'] - assert acc > 0.55, f'model failed to get expected 0.55 validation accuracy. Got: {acc}' + assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}' if __name__ == '__main__':