40 lines
1.0 KiB
Python
40 lines
1.0 KiB
Python
![]() |
import torch
|
||
|
|
||
|
# Legacy tensor-type aliases, chosen once at import time: the CUDA types when
# a GPU is available, the CPU types otherwise. Note that ``Tensor`` maps to
# ``torch.cuda.FloatTensor`` on GPU but to ``torch.Tensor`` on CPU.
Tensor, FloatTensor, LongTensor, ByteTensor = (
    (
        torch.cuda.FloatTensor,
        torch.cuda.FloatTensor,
        torch.cuda.LongTensor,
        torch.cuda.ByteTensor,
    )
    if torch.cuda.is_available()
    else (
        torch.Tensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.ByteTensor,
    )
)
|
||
|
|
||
|
def CUDA_wrapper(tensor):
    """Move ``tensor`` to the GPU when CUDA is available.

    Returns ``tensor.cuda()`` if ``torch.cuda.is_available()`` is true;
    otherwise the argument is returned unchanged.
    """
    return tensor.cuda() if torch.cuda.is_available() else tensor
|
||
|
|
||
|
class SoftmaxWithTemperature:
    """Callable computing ``softmax(x / temperature)``.

    A default temperature is fixed at construction time and can be
    overridden for a single call.
    """

    def __init__(self, temperature, dim=None):
        """Store the default temperature and build the softmax module.

        Args:
            temperature: Default divisor applied to the input before softmax.
            dim: Dimension along which softmax is computed, forwarded to
                ``torch.nn.Softmax``. The default ``None`` keeps the previous
                behaviour, where PyTorch picks the dimension implicitly
                (deprecated by PyTorch); pass an explicit value such as
                ``-1`` to avoid relying on that choice.
        """
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=dim)

    def __call__(self, x, temperature=None):
        """Return ``softmax(x / temperature)``.

        Args:
            x: Input tensor.
            temperature: Per-call override; when ``None`` the temperature
                given to the constructor is used.
        """
        if temperature is None:
            temperature = self.temperature
        return self.softmax(x / temperature)
|
||
|
|
||
|
def fill_eye_diag(a):
    """Return ``a`` with the main diagonal of every matrix set to 1.

    Args:
        a: 3-D tensor; the last two dimensions are treated as one matrix
            per leading index.

    Returns:
        A new tensor equal to ``a * (1 - eye) + eye``, where ``eye`` has the
        shape of the last two dimensions of ``a``. ``a`` is not modified.
        Because the diagonal is cleared by multiplying with zero, NaN or
        infinite diagonal entries in ``a`` yield NaN rather than 1.
    """
    # Imported locally: the module itself only imports ``torch``, so the
    # bare name ``Variable`` was previously undefined here.
    from torch.autograd import Variable

    _, rows, cols = a.size()
    eye = torch.eye(rows, cols)
    # Place the mask on the same device as the input instead of deciding by
    # global CUDA availability, so CPU inputs work on GPU machines too.
    if a.is_cuda:
        eye = eye.cuda(a.get_device())
    eye = Variable(eye)
    return a * (1 - eye) + eye
|