resolve test

This commit is contained in:
tchaton 2020-11-27 19:34:45 +00:00
parent 8e51543af9
commit c6502adba1
1 changed files with 4 additions and 3 deletions

View File

@ -583,7 +583,6 @@ def test_log_works_in_train_callback(tmpdir):
def on_batch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
"""
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_train_batch_end', 7, on_steps=self.choices,
@ -597,12 +596,12 @@ def test_log_works_in_train_callback(tmpdir):
def on_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
"""
def on_train_epoch_end(self, trainer, pl_module, outputs):
self.make_logging(pl_module, 'on_train_epoch_end', 9, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
class TestModel(BoringModel):
manual_loss = []
@ -639,8 +638,10 @@ def test_log_works_in_train_callback(tmpdir):
assert test_callback.funcs_called_count["on_train_batch_start"] == 4
assert test_callback.funcs_called_count["on_batch_end"] == 4
assert test_callback.funcs_called_count["on_epoch_end"] == 2
"""
assert test_callback.funcs_called_count["on_train_batch_end"] == 4
"""
assert test_callback.funcs_called_count["on_epoch_end"] == 2
assert test_callback.funcs_called_count["on_train_epoch_end"] == 2
# Make sure the func_name exists within callback_metrics. If not, we missed some