added test docs

This commit is contained in:
William Falcon 2019-07-24 09:04:36 -04:00
parent b776fce2e7
commit 8bbd65c95d
1 changed files with 7 additions and 10 deletions

View File

@ -37,7 +37,7 @@ def get_exp():
def assert_ok_acc(trainer):
# this model should get 0.80+ acc
assert trainer.tng_tqdm_dic['val_acc'] > 0.80
assert trainer.tng_tqdm_dic['val_acc'] > 0.80, "model failed to get expected 0.80 validation accuracy"
def test_cpu_model():
@ -55,11 +55,8 @@ def test_cpu_model():
)
result = trainer.fit(model)
assert result == 1, 'cpu model failed to complete'
metrics = trainer.tng_tqdm_dic
print(metrics)
assert result == 1
assert_ok_acc(trainer)
@ -84,7 +81,7 @@ def test_single_gpu_model():
result = trainer.fit(model)
assert result == 1
assert result == 1, 'single gpu model failed to complete'
assert_ok_acc(trainer)
@ -112,7 +109,7 @@ def test_multi_gpu_model_dp():
result = trainer.fit(model)
assert result == 1
assert result == 1, 'multi-gpu dp model failed to complete'
assert_ok_acc(trainer)
@ -141,7 +138,7 @@ def test_multi_gpu_model_ddp():
result = trainer.fit(model)
assert result == 1
assert result == 1, 'multi-gpu ddp model failed to complete'
assert_ok_acc(trainer)
@ -171,7 +168,7 @@ def test_amp_gpu_ddp():
result = trainer.fit(model)
assert result == 1
assert result == 1, 'amp + ddp model failed to complete'
assert_ok_acc(trainer)
@ -201,7 +198,7 @@ def test_amp_gpu_dp():
result = trainer.fit(model)
assert result == 1
assert result == 1, 'amp + gpu model failed to complete'
assert_ok_acc(trainer)