40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
from abc import ABC
|
|
|
|
import torch
|
|
|
|
|
|
class TestEpochEndVariations(ABC):
|
|
|
|
def test_epoch_end(self, outputs):
|
|
"""
|
|
Called at the end of validation to aggregate outputs
|
|
:param outputs: list of individual outputs of each validation step
|
|
:return:
|
|
"""
|
|
# if returned a scalar from test_step, outputs is a list of tensor scalars
|
|
# we return just the average in this case (if we want)
|
|
# return torch.stack(outputs).mean()
|
|
test_loss_mean = 0
|
|
test_acc_mean = 0
|
|
for output in outputs:
|
|
test_loss = self.get_output_metric(output, 'test_loss')
|
|
|
|
# reduce manually when using dp
|
|
if self.trainer.use_dp:
|
|
test_loss = torch.mean(test_loss)
|
|
test_loss_mean += test_loss
|
|
|
|
# reduce manually when using dp
|
|
test_acc = self.get_output_metric(output, 'test_acc')
|
|
if self.trainer.use_dp:
|
|
test_acc = torch.mean(test_acc)
|
|
|
|
test_acc_mean += test_acc
|
|
|
|
test_loss_mean /= len(outputs)
|
|
test_acc_mean /= len(outputs)
|
|
|
|
metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
|
|
result = {'progress_bar': metrics_dict, 'log': metrics_dict}
|
|
return result
|