# Python source file: 40 lines, 1.0 KiB.
import torch
|
|
|
|
# Pick the module-wide tensor constructors once, at import time: the CUDA
# variants when torch reports a usable GPU, the CPU variants otherwise.
if torch.cuda.is_available():
    # On GPU, `Tensor` and `FloatTensor` both name the CUDA float type.
    Tensor = FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
    ByteTensor = torch.cuda.ByteTensor
else:
    Tensor = torch.Tensor
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
    ByteTensor = torch.ByteTensor
|
|
|
|
def CUDA_wrapper(tensor):
    """Move ``tensor`` to the GPU when CUDA is available.

    Args:
        tensor: Any object exposing a ``.cuda()`` method (e.g. a
            ``torch.Tensor``).

    Returns:
        ``tensor.cuda()`` if ``torch.cuda.is_available()`` is true,
        otherwise ``tensor`` itself, unchanged.
    """
    # Availability is queried on every call rather than cached at import.
    return tensor.cuda() if torch.cuda.is_available() else tensor
|
|
|
|
class SoftmaxWithTemperature:
    """Callable computing ``softmax(x / temperature)``.

    A default temperature is fixed at construction time and can be
    overridden per call.
    """

    def __init__(self, temperature, dim=None):
        """Store the default temperature and build the softmax layer.

        Args:
            temperature: Default divisor applied to the input before the
                softmax.
            dim: Dimension along which the softmax is computed. The default
                ``None`` keeps the previous behaviour, where PyTorch picks
                the dimension implicitly (deprecated by PyTorch and reported
                with a warning); pass an explicit value such as ``-1`` to
                avoid that.
        """
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=dim)

    def __call__(self, x, temperature=None):
        """Return ``softmax(x / temperature)``.

        Args:
            x: Input tensor.
            temperature: Optional per-call override; when ``None`` the
                temperature given to the constructor is used.

        Returns:
            The softmax of the temperature-scaled input.
        """
        if temperature is None:
            temperature = self.temperature
        return self.softmax(x / temperature)
|
|
|
|
def fill_eye_diag(a):
    """Return ``a`` with the main diagonal of every matrix replaced by 1.

    The result is computed as ``a * (1 - I) + I``, where ``I`` is an identity
    matrix broadcast over the leading (batch) dimensions; ``a`` itself is not
    modified.

    Args:
        a: Tensor whose last two dimensions are the matrix dimensions, e.g.
            ``(batch, rows, cols)``. The matrices need not be square.

    Returns:
        A new tensor of the same shape as ``a`` with ones on the main
        diagonal and the original values elsewhere.

    Raises:
        ValueError: If ``a`` has fewer than two dimensions.
    """
    rows, cols = a.shape[-2:]
    # Build the identity on the input's own device so CPU inputs keep working
    # on machines where CUDA happens to be available.
    eye = torch.eye(rows, cols, device=a.device)
    # Multiplying by (1 - eye) zeroes the diagonal; adding eye sets it to 1.
    return a * (1 - eye) + eye
|