added test docs
This commit is contained in:
parent
b776fce2e7
commit
8bbd65c95d
|
@ -37,7 +37,7 @@ def get_exp():
|
|||
|
||||
def assert_ok_acc(trainer):
|
||||
# this model should get 0.80+ acc
|
||||
assert trainer.tng_tqdm_dic['val_acc'] > 0.80
|
||||
assert trainer.tng_tqdm_dic['val_acc'] > 0.80, "model failed to get expected 0.80 validation accuracy"
|
||||
|
||||
|
||||
def test_cpu_model():
|
||||
|
@ -55,11 +55,8 @@ def test_cpu_model():
|
|||
)
|
||||
|
||||
result = trainer.fit(model)
|
||||
assert result == 1, 'cpu model failed to complete'
|
||||
|
||||
metrics = trainer.tng_tqdm_dic
|
||||
print(metrics)
|
||||
|
||||
assert result == 1
|
||||
assert_ok_acc(trainer)
|
||||
|
||||
|
||||
|
@ -84,7 +81,7 @@ def test_single_gpu_model():
|
|||
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
assert result == 1, 'single gpu model failed to complete'
|
||||
assert_ok_acc(trainer)
|
||||
|
||||
|
||||
|
@ -112,7 +109,7 @@ def test_multi_gpu_model_dp():
|
|||
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
assert result == 1, 'multi-gpu dp model failed to complete'
|
||||
assert_ok_acc(trainer)
|
||||
|
||||
|
||||
|
@ -141,7 +138,7 @@ def test_multi_gpu_model_ddp():
|
|||
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
assert result == 1, 'multi-gpu ddp model failed to complete'
|
||||
assert_ok_acc(trainer)
|
||||
|
||||
|
||||
|
@ -171,7 +168,7 @@ def test_amp_gpu_ddp():
|
|||
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
assert result == 1, 'amp + ddp model failed to complete'
|
||||
assert_ok_acc(trainer)
|
||||
|
||||
|
||||
|
@ -201,7 +198,7 @@ def test_amp_gpu_dp():
|
|||
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
assert result == 1, 'amp + gpu model failed to complete'
|
||||
assert_ok_acc(trainer)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue