From d2c2e5004d8e8e3c2616178616c0bef421636cda Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 4 Feb 2021 18:56:45 +0100
Subject: [PATCH] fix tests

---
 tests/checkpointing/test_legacy_checkpoints.py |  1 +
 tests/models/test_hooks.py                     | 15 +++++++--------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 577362e65f..9bde704256 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -47,6 +47,7 @@ CHECKPOINT_EXTENSION = ".ckpt"
     "1.1.3",
     "1.1.4",
     "1.1.5",
+    "1.1.6",
 ])
 def test_resume_legacy_checkpoints(tmpdir, pl_version):
     path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 9f9b03db4c..b0a69eaeda 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -94,11 +94,10 @@ def test_training_epoch_end_metrics_collection(tmpdir):
 
 def test_training_epoch_end_metrics_collection_on_override(tmpdir):
     """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
-    num_epochs = 1
 
     class LoggingCallback(Callback):
 
-        def on_train_epoch_end(self, trainer, pl_module):
+        def on_train_epoch_start(self, trainer, pl_module):
             self.len_outputs = 0
 
         def on_train_epoch_end(self, trainer, pl_module, outputs):
@@ -110,7 +109,6 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir):
             self.num_train_batches = 0
 
         def training_epoch_end(self, outputs):  # Overridden
-            pass
             return
 
         def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
@@ -129,19 +127,20 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir):
 
     callback = LoggingCallback()
     trainer = Trainer(
-        max_epochs=num_epochs,
+        max_epochs=1,
         default_root_dir=tmpdir,
         overfit_batches=2,
         callbacks=[callback],
    )
 
-    result = trainer.fit(overridden_model)
+    trainer.fit(overridden_model)
+    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook
+    # if training_epoch_end is overridden
     assert callback.len_outputs == overridden_model.num_train_batches
-    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
 
-    result = trainer.fit(not_overridden_model)
-    assert callback.len_outputs == 0
+    trainer.fit(not_overridden_model)  # outputs from on_train_batch_end should be empty
+    assert callback.len_outputs == 0
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
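
Note (not part of the patch): the test fix above hinges on the Lightning 1.1/1.2-era hook signatures, where Callback.on_train_epoch_end still received an `outputs` argument that is populated only when the LightningModule overrides training_epoch_end. Below is a minimal, self-contained sketch of that interplay under those assumptions; OutputCounter and TinyModel are hypothetical names, and the `outputs[0]` indexing mirrors the test's own LoggingCallback.

    # Sketch assuming the pytorch-lightning ~1.1/1.2 callback API used in the patch above.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import Callback, LightningModule, Trainer


    class OutputCounter(Callback):

        def on_train_epoch_start(self, trainer, pl_module):
            # Reset at epoch start (as the patch does) so every epoch counts from zero.
            self.len_outputs = 0

        def on_train_epoch_end(self, trainer, pl_module, outputs):
            # `outputs` carries the per-batch results only when the LightningModule
            # overrides training_epoch_end; otherwise it arrives empty.
            self.len_outputs = len(outputs[0])


    class TinyModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).sum()

        def training_epoch_end(self, outputs):
            # Overriding this hook is what keeps the batch outputs flowing to callbacks.
            return

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            # 8 samples at batch_size=4 -> 2 training batches per epoch.
            return DataLoader(TensorDataset(torch.rand(8, 4)), batch_size=4)


    counter = OutputCounter()
    Trainer(max_epochs=1, overfit_batches=2, callbacks=[counter]).fit(TinyModel())
    assert counter.len_outputs == 2  # one collected output per training batch

Dropping the override of training_epoch_end from TinyModel would leave counter.len_outputs at 0, which is exactly the second assertion the patched test makes with not_overridden_model.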