fix tests
This commit is contained in:
parent
e633787a3d
commit
d2c2e5004d
|
@ -47,6 +47,7 @@ CHECKPOINT_EXTENSION = ".ckpt"
|
|||
"1.1.3",
|
||||
"1.1.4",
|
||||
"1.1.5",
|
||||
"1.1.6",
|
||||
])
|
||||
def test_resume_legacy_checkpoints(tmpdir, pl_version):
|
||||
path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
|
||||
|
|
|
@ -94,11 +94,10 @@ def test_training_epoch_end_metrics_collection(tmpdir):
|
|||
|
||||
def test_training_epoch_end_metrics_collection_on_override(tmpdir):
|
||||
""" Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
|
||||
num_epochs = 1
|
||||
|
||||
class LoggingCallback(Callback):
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
self.len_outputs = 0
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
||||
|
@ -110,7 +109,6 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir):
|
|||
self.num_train_batches = 0
|
||||
|
||||
def training_epoch_end(self, outputs): # Overridden
|
||||
pass
|
||||
return
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
|
@ -129,19 +127,20 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir):
|
|||
|
||||
callback = LoggingCallback()
|
||||
trainer = Trainer(
|
||||
max_epochs=num_epochs,
|
||||
max_epochs=1,
|
||||
default_root_dir=tmpdir,
|
||||
overfit_batches=2,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
result = trainer.fit(overridden_model)
|
||||
trainer.fit(overridden_model)
|
||||
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook
|
||||
# if training_epoch_end is overridden
|
||||
assert callback.len_outputs == overridden_model.num_train_batches
|
||||
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
|
||||
|
||||
result = trainer.fit(not_overridden_model)
|
||||
assert callback.len_outputs == 0
|
||||
trainer.fit(not_overridden_model)
|
||||
# outputs from on_train_batch_end should be empty
|
||||
assert callback.len_outputs == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
|
|
Loading…
Reference in New Issue