20 lines
302 B
Python
20 lines
302 B
Python
|
import torch
|
||
|
|
||
|
from typing import Any, Callable, Optional, Union
|
||
|
|
||
|
|
||
|
def dim_zero_cat(x: Union[torch.Tensor, list, tuple]) -> torch.Tensor:
    """Concatenate tensors along dimension 0.

    Args:
        x: Either a single tensor or a sequence (list/tuple) of tensors.
            A single tensor is treated as a one-element sequence.
            Zero-dimensional (scalar) tensors are lifted to shape ``(1,)``
            first, because ``torch.cat`` cannot concatenate 0-dim tensors.

    Returns:
        The result of ``torch.cat`` over the prepared tensors with ``dim=0``.

    Note:
        An empty sequence is forwarded unchanged to ``torch.cat``, which
        rejects it; no additional validation is performed here.
    """
    # torch.cat only accepts a sequence of tensors, so wrap a lone tensor.
    if isinstance(x, torch.Tensor):
        x = [x]
    # Scalars carry no dimension to concatenate along; give them one.
    prepared = [
        t.unsqueeze(0) if isinstance(t, torch.Tensor) and t.ndim == 0 else t
        for t in x
    ]
    return torch.cat(prepared, dim=0)
|
||
|
|
||
|
|
||
|
def dim_zero_sum(x):
    """Sum ``x`` over dimension 0 via ``torch.sum`` and return the result."""
    summed = torch.sum(x, dim=0)
    return summed
|
||
|
|
||
|
|
||
|
def dim_zero_mean(x):
    """Average ``x`` over dimension 0 via ``torch.mean`` and return the result."""
    averaged = torch.mean(x, dim=0)
    return averaged
|
||
|
|
||
|
|
||
|
def _flatten(x):
|
||
|
return [item for sublist in x for item in sublist]
|