fix tests

This commit is contained in:
Jirka Borovec 2021-02-04 18:56:45 +01:00
parent e633787a3d
commit d2c2e5004d
2 changed files with 8 additions and 8 deletions

View File

@ -47,6 +47,7 @@ CHECKPOINT_EXTENSION = ".ckpt"
"1.1.3",
"1.1.4",
"1.1.5",
"1.1.6",
])
def test_resume_legacy_checkpoints(tmpdir, pl_version):
path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)

View File

@ -94,11 +94,10 @@ def test_training_epoch_end_metrics_collection(tmpdir):
def test_training_epoch_end_metrics_collection_on_override(tmpdir):
""" Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
num_epochs = 1
class LoggingCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
def on_train_epoch_start(self, trainer, pl_module):
self.len_outputs = 0
def on_train_epoch_end(self, trainer, pl_module, outputs):
@ -110,7 +109,6 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir):
self.num_train_batches = 0
def training_epoch_end(self, outputs): # Overridden
pass
return
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
@ -129,19 +127,20 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir):
callback = LoggingCallback()
trainer = Trainer(
max_epochs=num_epochs,
max_epochs=1,
default_root_dir=tmpdir,
overfit_batches=2,
callbacks=[callback],
)
result = trainer.fit(overridden_model)
trainer.fit(overridden_model)
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook
# if training_epoch_end is overridden
assert callback.len_outputs == overridden_model.num_train_batches
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
result = trainer.fit(not_overridden_model)
assert callback.len_outputs == 0
trainer.fit(not_overridden_model)
# outputs from on_train_batch_end should be empty
assert callback.len_outputs == 0
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")