diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py
index 03fe976c44..475e6ec5f2 100644
--- a/tests/test_cpu_models.py
+++ b/tests/test_cpu_models.py
@@ -114,7 +114,7 @@ def test_running_test_after_fitting(tmpdir):
     trainer.test()
 
     # test we have good test accuracy
-    tutils.assert_ok_test_acc(trainer)
+    tutils.assert_ok_model_acc(trainer)
 
 
 def test_running_test_without_val(tmpdir):
@@ -152,7 +152,7 @@ def test_running_test_without_val(tmpdir):
     trainer.test()
 
     # test we have good test accuracy
-    tutils.assert_ok_test_acc(trainer)
+    tutils.assert_ok_model_acc(trainer)
 
 
 def test_single_gpu_batch_parse():
diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py
index 02d57b3155..498b312913 100644
--- a/tests/test_restore_models.py
+++ b/tests/test_restore_models.py
@@ -93,7 +93,7 @@ def test_running_test_pretrained_model(tmpdir):
     new_trainer.test(pretrained_model)
 
     # test we have good test accuracy
-    tutils.assert_ok_test_acc(new_trainer)
+    tutils.assert_ok_model_acc(new_trainer)
 
 
 def test_load_model_from_checkpoint(tmpdir):
@@ -134,7 +134,7 @@ def test_load_model_from_checkpoint(tmpdir):
     new_trainer.test(pretrained_model)
 
     # test we have good test accuracy
-    tutils.assert_ok_test_acc(new_trainer)
+    tutils.assert_ok_model_acc(new_trainer)
 
 
 def test_running_test_pretrained_model_dp(tmpdir):
@@ -178,7 +178,7 @@ def test_running_test_pretrained_model_dp(tmpdir):
     new_trainer.test(pretrained_model)
 
     # test we have good test accuracy
-    tutils.assert_ok_test_acc(new_trainer)
+    tutils.assert_ok_model_acc(new_trainer)
 
 
 def test_dp_resume(tmpdir):
diff --git a/tests/utils.py b/tests/utils.py
index baec6274b3..70ecccb28d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -183,16 +183,10 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
     assert acc > min_acc, f'this model is expected to get > {min_acc} in test set (it got {acc})'
 
 
-def assert_ok_val_acc(trainer):
+def assert_ok_model_acc(trainer, key='test_acc', thr=0.4):
     # this model should get 0.80+ acc
-    acc = trainer.training_tqdm_dict['val_acc']
-    assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
-
-
-def assert_ok_test_acc(trainer):
-    # this model should get 0.80+ acc
-    acc = trainer.training_tqdm_dict['test_acc']
-    assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
+    acc = trainer.training_tqdm_dict[key]
+    assert acc > thr, f'Model failed to get expected {thr} accuracy. {key} = {acc}'
 
 
 def can_run_gpu_test():
@@ -208,9 +202,9 @@ def can_run_gpu_test():
 
 
 def reset_seed():
-    SEED = RANDOM_SEEDS.pop()
-    torch.manual_seed(SEED)
-    np.random.seed(SEED)
+    seed = RANDOM_SEEDS.pop()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
 
 
 def set_random_master_port():
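A minimal sketch (not part of the diff) of how the consolidated helper is expected to cover both of the removed assertions. `FakeTrainer` and the metric values are hypothetical stand-ins used only so the snippet runs standalone; `assert_ok_model_acc` mirrors the helper introduced in tests/utils.py above.

```python
class FakeTrainer:
    """Hypothetical stand-in exposing only what the helper reads."""
    def __init__(self, metrics):
        self.training_tqdm_dict = metrics


def assert_ok_model_acc(trainer, key='test_acc', thr=0.4):
    # same logic as the helper added in tests/utils.py
    acc = trainer.training_tqdm_dict[key]
    assert acc > thr, f'Model failed to get expected {thr} accuracy. {key} = {acc}'


if __name__ == '__main__':
    trainer = FakeTrainer({'test_acc': 0.62, 'val_acc': 0.55})

    # default arguments replace the removed assert_ok_test_acc(trainer)
    assert_ok_model_acc(trainer)

    # key/thr arguments replace the removed assert_ok_val_acc(trainer)
    assert_ok_model_acc(trainer, key='val_acc', thr=0.5)
```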