reduce to 0.22
This commit is contained in:
parent
b36b9a0145
commit
8e51543af9
|
@ -12,7 +12,7 @@ from tests.base.models import ParityModuleMNIST, ParityModuleRNN
|
|||
# ParityModuleMNIST runs with num_workers=1
|
||||
@pytest.mark.parametrize('cls_model,max_diff', [
|
||||
(ParityModuleRNN, 0.05),
|
||||
(ParityModuleMNIST, 0.18)
|
||||
(ParityModuleMNIST, 0.22)
|
||||
])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
def test_pytorch_parity(tmpdir, cls_model, max_diff):
|
||||
|
|
Loading…
Reference in New Issue