From 61394d543c259ab2db8f423807c02e6d292b27c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sat, 14 Nov 2020 22:10:24 +0100 Subject: [PATCH] Improve skipping step tests (#4109) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli Co-authored-by: chaton Co-authored-by: Sean Naren --- .../test_train_loop_flow_scalar_1_0.py | 45 +++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py index c3d4d56477..b8211b55f5 100644 --- a/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py +++ b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py @@ -217,17 +217,56 @@ def test_train_step_no_return(tmpdir): def training_epoch_end(self, outputs) -> None: assert len(outputs) == 0 - model = TestModel() - model.val_dataloader = None + def validation_step(self, batch, batch_idx): + self.validation_step_called = True + def validation_epoch_end(self, outputs): + assert len(outputs) == 0 + + model = TestModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, - max_epochs=1, + max_epochs=2, log_every_n_steps=1, weights_summary=None, ) with pytest.warns(UserWarning, match=r'.*training_step returned None.*'): trainer.fit(model) + assert model.training_step_called + assert model.validation_step_called + + +def test_training_step_no_return_when_even(tmpdir): + """ + Tests correctness when some training steps have been skipped + """ + class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + self.training_step_called = True + loss = self.step(batch[0]) + self.log('a', loss, on_step=True, on_epoch=True) + return loss if batch_idx % 2 else None + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=4, + limit_val_batches=1, + max_epochs=4, + weights_summary=None, + logger=False, + checkpoint_callback=False, + ) + + with pytest.warns(UserWarning, match=r'.*training_step returned None.*'): + trainer.fit(model) + + # manually check a few batches + for batch_idx, batch in enumerate(model.train_dataloader()): + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + if not batch_idx % 2: + assert out.training_step_output_for_epoch_end == [[]] + assert out.signal == 0