# lightning/pytorch_lightning/metrics/utils.py
#
# (20 lines, 302 B, Python)

import torch
from typing import Any, Callable, Optional, Union
def dim_zero_cat(x):
    """Concatenate metric state tensor(s) along dimension 0.

    Args:
        x: a single tensor, or a list/tuple of tensors. A bare tensor is
            treated as a one-element list. Zero-dimensional (scalar) tensors
            are promoted to shape ``(1,)`` before concatenation.

    Returns:
        One tensor containing all inputs joined along dim 0.

    Raises:
        RuntimeError: propagated from ``torch.cat``, e.g. for an empty list
            or tensors whose non-leading dimensions do not match.
    """
    # ``torch.cat`` requires a sequence of tensors and raises a TypeError on a
    # bare tensor (e.g. a state that is already a single gathered tensor).
    if isinstance(x, torch.Tensor):
        tensors = [x]
    else:
        tensors = list(x)
    # ``torch.cat`` refuses zero-dimensional tensors, so lift scalars to 1-d.
    # Non-tensor elements are passed through so torch reports the error.
    tensors = [
        t.unsqueeze(0) if isinstance(t, torch.Tensor) and t.ndim == 0 else t
        for t in tensors
    ]
    return torch.cat(tensors, dim=0)
def dim_zero_sum(x):
    """Sum the tensor ``x`` over its leading dimension (dim 0).

    Args:
        x: tensor passed straight to ``torch.sum``.

    Returns:
        The result of ``torch.sum(x, dim=0)``.
    """
    reduced = torch.sum(x, dim=0)
    return reduced
def dim_zero_mean(x):
    """Average the tensor ``x`` over its leading dimension (dim 0).

    Args:
        x: tensor passed straight to ``torch.mean``. It must have a
            floating-point (or complex) dtype, as ``torch.mean`` requires.

    Returns:
        The result of ``torch.mean(x, dim=0)``.
    """
    averaged = torch.mean(x, dim=0)
    return averaged
def _flatten(x):
return [item for sublist in x for item in sublist]