reduce parity to 0.22
This commit is contained in:
parent
cef83dbbf8
commit
1f1a20c45f
|
@ -12,7 +12,7 @@ from tests.base.models import ParityModuleMNIST, ParityModuleRNN
|
||||||
# TODO: explore where the time leak comes from
|
# TODO: explore where the time leak comes from
|
||||||
@pytest.mark.parametrize('cls_model,max_diff', [
|
@pytest.mark.parametrize('cls_model,max_diff', [
|
||||||
(ParityModuleRNN, 0.05),
|
(ParityModuleRNN, 0.05),
|
||||||
(ParityModuleMNIST, 0.99)
|
(ParityModuleMNIST, 0.22)
|
||||||
])
|
])
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||||
def test_pytorch_parity(tmpdir, cls_model, max_diff):
|
def test_pytorch_parity(tmpdir, cls_model, max_diff):
|
||||||
|
|
|
@ -117,8 +117,8 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
|
||||||
|
|
||||||
def validation_step_end(self, acc):
|
def validation_step_end(self, acc):
|
||||||
self.validation_step_end_called = True
|
self.validation_step_end_called = True
|
||||||
#self.log('e', acc)
|
# self.log('e', acc)
|
||||||
#self.log('f', acc, on_step=True, on_epoch=True)
|
# self.log('f', acc, on_step=True, on_epoch=True)
|
||||||
return ['random_thing']
|
return ['random_thing']
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs):
|
||||||
|
@ -151,10 +151,10 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
|
||||||
'd_step/epoch_0',
|
'd_step/epoch_0',
|
||||||
'd_step/epoch_1',
|
'd_step/epoch_1',
|
||||||
'd_epoch',
|
'd_epoch',
|
||||||
#'e',
|
# 'e',
|
||||||
#'f_step/epoch_0',
|
# 'f_step/epoch_0',
|
||||||
#'f_step/epoch_1',
|
# 'f_step/epoch_1',
|
||||||
#'f_epoch',
|
# 'f_epoch',
|
||||||
'g',
|
'g',
|
||||||
}
|
}
|
||||||
assert expected_logged_metrics == logged_metrics
|
assert expected_logged_metrics == logged_metrics
|
||||||
|
@ -167,7 +167,7 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
|
||||||
callback_metrics = set(trainer.callback_metrics.keys())
|
callback_metrics = set(trainer.callback_metrics.keys())
|
||||||
callback_metrics.remove('debug_epoch')
|
callback_metrics.remove('debug_epoch')
|
||||||
expected_cb_metrics = {'a', 'b', 'b_epoch', 'c', 'd', 'd_epoch', 'g', 'b_step'}
|
expected_cb_metrics = {'a', 'b', 'b_epoch', 'c', 'd', 'd_epoch', 'g', 'b_step'}
|
||||||
#expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'b_epoch', 'd_epoch', 'f_epoch', 'f', 'g', 'b_step'}
|
# expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'b_epoch', 'd_epoch', 'f_epoch', 'f', 'g', 'b_step'}
|
||||||
assert expected_cb_metrics == callback_metrics
|
assert expected_cb_metrics == callback_metrics
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue