unify model test acc (#696)

This commit is contained in:
Jirka Borovec 2020-01-17 11:50:26 +01:00 committed by William Falcon
parent deb1581e26
commit bde549cb36
3 changed files with 11 additions and 17 deletions

View File

@ -114,7 +114,7 @@ def test_running_test_after_fitting(tmpdir):
trainer.test()
# test we have good test accuracy
tutils.assert_ok_test_acc(trainer)
tutils.assert_ok_model_acc(trainer)
def test_running_test_without_val(tmpdir):
@ -152,7 +152,7 @@ def test_running_test_without_val(tmpdir):
trainer.test()
# test we have good test accuracy
tutils.assert_ok_test_acc(trainer)
tutils.assert_ok_model_acc(trainer)
def test_single_gpu_batch_parse():

View File

@ -93,7 +93,7 @@ def test_running_test_pretrained_model(tmpdir):
new_trainer.test(pretrained_model)
# test we have good test accuracy
tutils.assert_ok_test_acc(new_trainer)
tutils.assert_ok_model_acc(new_trainer)
def test_load_model_from_checkpoint(tmpdir):
@ -134,7 +134,7 @@ def test_load_model_from_checkpoint(tmpdir):
new_trainer.test(pretrained_model)
# test we have good test accuracy
tutils.assert_ok_test_acc(new_trainer)
tutils.assert_ok_model_acc(new_trainer)
def test_running_test_pretrained_model_dp(tmpdir):
@ -178,7 +178,7 @@ def test_running_test_pretrained_model_dp(tmpdir):
new_trainer.test(pretrained_model)
# test we have good test accuracy
tutils.assert_ok_test_acc(new_trainer)
tutils.assert_ok_model_acc(new_trainer)
def test_dp_resume(tmpdir):

View File

@ -183,16 +183,10 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
assert acc > min_acc, f'this model is expected to get > {min_acc} in test set (it got {acc})'
def assert_ok_val_acc(trainer):
def assert_ok_model_acc(trainer, key='test_acc', thr=0.4):
# this model should get 0.80+ acc
acc = trainer.training_tqdm_dict['val_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
def assert_ok_test_acc(trainer):
# this model should get 0.80+ acc
acc = trainer.training_tqdm_dict['test_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
acc = trainer.training_tqdm_dict[key]
assert acc > thr, f'Model failed to get expected {thr} accuracy. {key} = {acc}'
def can_run_gpu_test():
@ -208,9 +202,9 @@ def can_run_gpu_test():
def reset_seed():
SEED = RANDOM_SEEDS.pop()
torch.manual_seed(SEED)
np.random.seed(SEED)
seed = RANDOM_SEEDS.pop()
torch.manual_seed(seed)
np.random.seed(seed)
def set_random_master_port():