lightning/tests/models/test_module_hooks.py

48 lines
1.6 KiB
Python

import torch
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
import tests.base.utils as tutils
def test_training_epoch_end_metrics_collection(tmpdir):
""" Test that progress bar metrics also get collected at the end of an epoch. """
num_epochs = 3
class CurrentModel(EvalModelTemplate):
def training_step(self, *args, **kwargs):
output = super().training_step(*args, **kwargs)
output['progress_bar'].update({'step_metric': torch.tensor(-1)})
output['progress_bar'].update({'shared_metric': 100})
return output
def training_epoch_end(self, outputs):
epoch = self.current_epoch
# both scalar tensors and Python numbers are accepted
return {
'progress_bar': {
f'epoch_metric_{epoch}': torch.tensor(epoch), # add a new metric key every epoch
'shared_metric': 111,
}
}
model = CurrentModel()
trainer = Trainer(
max_epochs=num_epochs,
default_root_dir=tmpdir,
overfit_pct=0.1,
)
result = trainer.fit(model)
assert result == 1
metrics = trainer.progress_bar_dict
# metrics added in training step should be unchanged by epoch end method
assert metrics['step_metric'] == -1
# a metric shared in both methods gets overwritten by epoch_end
assert metrics['shared_metric'] == 111
# metrics are kept after each epoch
for i in range(num_epochs):
assert metrics[f'epoch_metric_{i}'] == i