Improve skipping step tests (#4109)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
Carlos Mocholí 2020-11-14 22:10:24 +01:00 committed by GitHub
parent 504a669015
commit 61394d543c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 42 additions and 3 deletions

View File

@ -217,17 +217,56 @@ def test_train_step_no_return(tmpdir):
def training_epoch_end(self, outputs) -> None:
assert len(outputs) == 0
model = TestModel()
model.val_dataloader = None
def validation_step(self, batch, batch_idx):
self.validation_step_called = True
def validation_epoch_end(self, outputs):
assert len(outputs) == 0
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
)
with pytest.warns(UserWarning, match=r'.*training_step returned None.*'):
trainer.fit(model)
assert model.training_step_called
assert model.validation_step_called
def test_training_step_no_return_when_even(tmpdir):
"""
Tests correctness when some training steps have been skipped
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
self.training_step_called = True
loss = self.step(batch[0])
self.log('a', loss, on_step=True, on_epoch=True)
return loss if batch_idx % 2 else None
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=4,
limit_val_batches=1,
max_epochs=4,
weights_summary=None,
logger=False,
checkpoint_callback=False,
)
with pytest.warns(UserWarning, match=r'.*training_step returned None.*'):
trainer.fit(model)
# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
if not batch_idx % 2:
assert out.training_step_output_for_epoch_end == [[]]
assert out.signal == 0