Improve skipping step tests (#4109)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
parent
504a669015
commit
61394d543c
|
@ -217,17 +217,56 @@ def test_train_step_no_return(tmpdir):
|
|||
def training_epoch_end(self, outputs) -> None:
|
||||
assert len(outputs) == 0
|
||||
|
||||
model = TestModel()
|
||||
model.val_dataloader = None
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.validation_step_called = True
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
assert len(outputs) == 0
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
max_epochs=2,
|
||||
log_every_n_steps=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
with pytest.warns(UserWarning, match=r'.*training_step returned None.*'):
|
||||
trainer.fit(model)
|
||||
assert model.training_step_called
|
||||
assert model.validation_step_called
|
||||
|
||||
|
||||
def test_training_step_no_return_when_even(tmpdir):
|
||||
"""
|
||||
Tests correctness when some training steps have been skipped
|
||||
"""
|
||||
class TestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
self.training_step_called = True
|
||||
loss = self.step(batch[0])
|
||||
self.log('a', loss, on_step=True, on_epoch=True)
|
||||
return loss if batch_idx % 2 else None
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=4,
|
||||
limit_val_batches=1,
|
||||
max_epochs=4,
|
||||
weights_summary=None,
|
||||
logger=False,
|
||||
checkpoint_callback=False,
|
||||
)
|
||||
|
||||
with pytest.warns(UserWarning, match=r'.*training_step returned None.*'):
|
||||
trainer.fit(model)
|
||||
|
||||
# manually check a few batches
|
||||
for batch_idx, batch in enumerate(model.train_dataloader()):
|
||||
out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
|
||||
if not batch_idx % 2:
|
||||
assert out.training_step_output_for_epoch_end == [[]]
|
||||
assert out.signal == 0
|
||||
|
|
Loading…
Reference in New Issue