unify model test acc (#696)
This commit is contained in:
parent
deb1581e26
commit
bde549cb36
|
@ -114,7 +114,7 @@ def test_running_test_after_fitting(tmpdir):
|
|||
trainer.test()
|
||||
|
||||
# test we have good test accuracy
|
||||
tutils.assert_ok_test_acc(trainer)
|
||||
tutils.assert_ok_model_acc(trainer)
|
||||
|
||||
|
||||
def test_running_test_without_val(tmpdir):
|
||||
|
@ -152,7 +152,7 @@ def test_running_test_without_val(tmpdir):
|
|||
trainer.test()
|
||||
|
||||
# test we have good test accuracy
|
||||
tutils.assert_ok_test_acc(trainer)
|
||||
tutils.assert_ok_model_acc(trainer)
|
||||
|
||||
|
||||
def test_single_gpu_batch_parse():
|
||||
|
|
|
@ -93,7 +93,7 @@ def test_running_test_pretrained_model(tmpdir):
|
|||
new_trainer.test(pretrained_model)
|
||||
|
||||
# test we have good test accuracy
|
||||
tutils.assert_ok_test_acc(new_trainer)
|
||||
tutils.assert_ok_model_acc(new_trainer)
|
||||
|
||||
|
||||
def test_load_model_from_checkpoint(tmpdir):
|
||||
|
@ -134,7 +134,7 @@ def test_load_model_from_checkpoint(tmpdir):
|
|||
new_trainer.test(pretrained_model)
|
||||
|
||||
# test we have good test accuracy
|
||||
tutils.assert_ok_test_acc(new_trainer)
|
||||
tutils.assert_ok_model_acc(new_trainer)
|
||||
|
||||
|
||||
def test_running_test_pretrained_model_dp(tmpdir):
|
||||
|
@ -178,7 +178,7 @@ def test_running_test_pretrained_model_dp(tmpdir):
|
|||
new_trainer.test(pretrained_model)
|
||||
|
||||
# test we have good test accuracy
|
||||
tutils.assert_ok_test_acc(new_trainer)
|
||||
tutils.assert_ok_model_acc(new_trainer)
|
||||
|
||||
|
||||
def test_dp_resume(tmpdir):
|
||||
|
|
|
@ -183,16 +183,10 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
|
|||
assert acc > min_acc, f'this model is expected to get > {min_acc} in test set (it got {acc})'
|
||||
|
||||
|
||||
def assert_ok_val_acc(trainer):
|
||||
def assert_ok_model_acc(trainer, key='test_acc', thr=0.4):
|
||||
# this model should get 0.80+ acc
|
||||
acc = trainer.training_tqdm_dict['val_acc']
|
||||
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
|
||||
|
||||
|
||||
def assert_ok_test_acc(trainer):
|
||||
# this model should get 0.80+ acc
|
||||
acc = trainer.training_tqdm_dict['test_acc']
|
||||
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
|
||||
acc = trainer.training_tqdm_dict[key]
|
||||
assert acc > thr, f'Model failed to get expected {thr} accuracy. {key} = {acc}'
|
||||
|
||||
|
||||
def can_run_gpu_test():
|
||||
|
@ -208,9 +202,9 @@ def can_run_gpu_test():
|
|||
|
||||
|
||||
def reset_seed():
|
||||
SEED = RANDOM_SEEDS.pop()
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
seed = RANDOM_SEEDS.pop()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
def set_random_master_port():
|
||||
|
|
Loading…
Reference in New Issue