from abc import ABC

import torch


class ValidationEpochEndVariations(ABC):
    """
    Houses all variations of validation_epoch_end steps
    """
    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs

        Args:
            outputs: list of individual outputs of each validation step
        """
        # if validation_step returned a scalar, outputs is a list of scalar tensors
        # and we could return just their average (if we want)
        def _mean(res, key):
            # recursive mean for multilevel dicts
            return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean()
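        # Illustrative shapes handled by _mean (values are made up for illustration):
        #   flat:   [{'val_loss': tensor(0.5)}, {'val_loss': tensor(0.3)}]      -> tensor(0.4)
        #   nested: [[{'val_loss': tensor(0.5)}], [{'val_loss': tensor(0.3)}]]  -> mean of per-dataloader means
        # non-dict entries trigger the recursive call, so multilevel lists of dicts reduce cleanly.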

        # return torch.stack(outputs).mean()
        val_loss_mean = _mean(outputs, 'val_loss')
        val_acc_mean = _mean(outputs, 'val_acc')

        # reduce manually when using dp/ddp2: each step's output then carries one value
        # per GPU, so rebuild the means from the raw outputs in that case
        if self.trainer.use_dp or self.trainer.use_ddp2:
            val_loss_mean = 0
            val_acc_mean = 0
            for output in outputs:
                val_loss = self.get_output_metric(output, 'val_loss')
                val_loss_mean += torch.mean(val_loss)

                val_acc = self.get_output_metric(output, 'val_acc')
                val_acc_mean += torch.mean(val_acc)

            if outputs:  # skip zero divisions
                val_loss_mean /= len(outputs)
                val_acc_mean /= len(outputs)
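        # For example, under dp with 2 GPUs a step's 'val_loss' might arrive as
        # tensor([0.52, 0.48]) (hypothetical values); torch.mean collapses it to a
        # scalar before it is accumulated above.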

        metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
        results = {'progress_bar': metrics_dict, 'log': metrics_dict}
        return results
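

# A minimal sketch of how this mixin is typically combined with a LightningModule that
# provides the matching validation_step; the class name and step body are hypothetical,
# shown only to illustrate the expected 'val_loss'/'val_acc' dict contract:
#
#     class DummyValModel(ValidationEpochEndVariations, LightningModule):
#         def validation_step(self, batch, batch_idx):
#             loss, acc = ...  # compute batch metrics
#             return {'val_loss': loss, 'val_acc': acc}
#
# validation_epoch_end then receives the list of these dicts (or a list of lists, one
# per val dataloader) and returns the aggregated 'progress_bar'/'log' metrics.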