reduce parity to 0.22

This commit is contained in:
tchaton 2020-11-27 18:36:18 +00:00
parent cef83dbbf8
commit 1f1a20c45f
2 changed files with 8 additions and 8 deletions

View File

@ -12,7 +12,7 @@ from tests.base.models import ParityModuleMNIST, ParityModuleRNN
# TODO: explore where the time leak comes from # TODO: explore where the time leak comes from
@pytest.mark.parametrize('cls_model,max_diff', [ @pytest.mark.parametrize('cls_model,max_diff', [
(ParityModuleRNN, 0.05), (ParityModuleRNN, 0.05),
(ParityModuleMNIST, 0.99) (ParityModuleMNIST, 0.22)
]) ])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff): def test_pytorch_parity(tmpdir, cls_model, max_diff):

View File

@ -117,8 +117,8 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
def validation_step_end(self, acc): def validation_step_end(self, acc):
self.validation_step_end_called = True self.validation_step_end_called = True
#self.log('e', acc) # self.log('e', acc)
#self.log('f', acc, on_step=True, on_epoch=True) # self.log('f', acc, on_step=True, on_epoch=True)
return ['random_thing'] return ['random_thing']
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs):
@ -151,10 +151,10 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
'd_step/epoch_0', 'd_step/epoch_0',
'd_step/epoch_1', 'd_step/epoch_1',
'd_epoch', 'd_epoch',
#'e', # 'e',
#'f_step/epoch_0', # 'f_step/epoch_0',
#'f_step/epoch_1', # 'f_step/epoch_1',
#'f_epoch', # 'f_epoch',
'g', 'g',
} }
assert expected_logged_metrics == logged_metrics assert expected_logged_metrics == logged_metrics
@ -167,7 +167,7 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
callback_metrics = set(trainer.callback_metrics.keys()) callback_metrics = set(trainer.callback_metrics.keys())
callback_metrics.remove('debug_epoch') callback_metrics.remove('debug_epoch')
expected_cb_metrics = {'a', 'b', 'b_epoch', 'c', 'd', 'd_epoch', 'g', 'b_step'} expected_cb_metrics = {'a', 'b', 'b_epoch', 'c', 'd', 'd_epoch', 'g', 'b_step'}
#expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'b_epoch', 'd_epoch', 'f_epoch', 'f', 'g', 'b_step'} # expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'b_epoch', 'd_epoch', 'f_epoch', 'f', 'g', 'b_step'}
assert expected_cb_metrics == callback_metrics assert expected_cb_metrics == callback_metrics