#!/usr/bin/env python
import numpy as np
import copy
import sys
import pickle as pkl
import torch
from torch import nn
# load data
def dataIterator(feature_file, label_file, dictionary, batch_size, batch_Imagesize, maxlen, maxImagesize):
    # offline-train.pkl
    with open(feature_file, 'rb') as fp:
        features = pkl.load(fp, encoding='latin1')
    # train_caption.txt
    with open(label_file, 'r') as fp2:
        labels = fp2.readlines()

    targets = {}
    # map each word to an int with the dictionary; unseen words get a new id
    for l in labels:
        tmp = l.strip().split()
        uid = tmp[0]
        w_list = []
        for w in tmp[1:]:
            if w in dictionary:
                w_list.append(dictionary[w])
            else:
                # word not in the dictionary: report it and append it
                print(w + '\t' + str(len(dictionary)))
                dictionary[w] = len(dictionary)
        targets[uid] = w_list

    imageSize = {}
    for uid, fea in features.items():
        imageSize[uid] = fea.shape[1] * fea.shape[2]
    # sort by image size (H * W); yields a list of (uid, size) pairs
    imageSize = sorted(imageSize.items(), key=lambda d: d[1])

    feature_batch = []
    label_batch = []
    feature_total = []
    label_total = []
    uidList = []

    biggest_image_size = 0
    i = 0
    for uid, size in imageSize:
        if size > biggest_image_size:
            biggest_image_size = size
        fea = features[uid]
        lab = targets[uid]
        # padded area the batch would occupy if this image were added
        batch_image_size = biggest_image_size * (i + 1)
        if len(lab) > maxlen:
            print('sentence', uid, 'length bigger than', maxlen, 'ignore')
        elif size > maxImagesize:
            print(size)
            print('image', uid, 'size bigger than', maxImagesize, 'ignore')
        else:
            uidList.append(uid)
            if batch_image_size > batch_Imagesize or i == batch_size:  # a batch is full
                feature_total.append(feature_batch)
                label_total.append(label_batch)
                i = 0
                biggest_image_size = size
                feature_batch = []
                label_batch = []
                feature_batch.append(fea)
                label_batch.append(lab)
                i += 1
            else:
                feature_batch.append(fea)
                label_batch.append(lab)
                i += 1

    # last batch
    feature_total.append(feature_batch)
    label_total.append(label_batch)
    print('total ', len(feature_total), 'batch data loaded')
    return list(zip(feature_total, label_total)), uidList
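
# Example usage (a sketch, not part of the original pipeline): the pickle and
# caption file names are the ones mentioned in the comments above; the batch
# limits are placeholder values, not fixed by this repo.
#
#   worddicts = load_dict('dictionary.txt')  # hypothetical path
#   train, train_uids = dataIterator(
#       'offline-train.pkl', 'train_caption.txt', worddicts,
#       batch_size=8, batch_Imagesize=500000, maxlen=200, maxImagesize=500000)
#   for feature_batch, label_batch in train:
#       ...  # each element pairs a list of images with a list of label-id lists
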
# load dictionary
def load_dict(dictFile):
    with open(dictFile) as fp:
        stuff = fp.readlines()
    lexicon = {}
    for l in stuff:
        w = l.strip().split()
        lexicon[w[0]] = int(w[1])
    print('total words/phones', len(lexicon))
    return lexicon
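
# The dictionary file is plain text with one "<token> <int_id>" pair per line;
# prepare_data below relies on <eos> having id 0.
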
# load mapping
def load_mapping(dictFile):
    print(dictFile)
    lexicon = {}
    lexicon_r = {}
    # the mapping file is Shift_JIS X 0213 encoded, one "<character> <id>" per
    # line; decoding via the codec replaces the Python 2 unicode() calls
    with open(dictFile, 'r', encoding='shift_jisx0213') as f:
        for line in f:
            sp = line.split()
            lexicon[sp[1]] = sp[0]
            lexicon_r[sp[0]] = sp[1]
    print('total words/phones', len(lexicon))
    return lexicon, lexicon_r
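
# Example (sketch; 'mapping.txt' is a placeholder name): lexicon maps an id
# string to its character, lexicon_r maps the character back to the id string.
#
#   lexicon, lexicon_r = load_mapping('mapping.txt')
#   ch = lexicon['10']       # id -> character
#   idx = lexicon_r[ch]      # character -> id string
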
# create batch
def prepare_data(options, images_x, seqs_y, prev_x=None):
    # disabled augmentation: blend each image with the previous batch
    '''if prev_x != None and len(images_x) == len(prev_x) and np.random.random_sample() > 0.7:
        for i in range(len(images_x)):
            images_x[i] = images_x[i] * 0.7 + prev_x[i] * 0.3'''
    # disabled augmentation: random distortions (Image, sk, br are not imported here)
    '''if np.random.random_sample() > 0.7:
        for i in range(len(images_x)):
            if np.shape(images_x[i][0])[0] <= 100 or np.shape(images_x[i][0])[1] <= 100: continue
            img = Image.fromarray(images_x[i][0])
            img = sk.perform_operation([img])
            img = br.perform_operation([img[0]])
            img = np.asarray(img[0])
            images_x[i][0] = img'''
    heights_x = [s.shape[1] for s in images_x]
    widths_x = [s.shape[2] for s in images_x]
    lengths_y = [len(s) for s in seqs_y]
    n_samples = len(heights_x)
    max_height_x = np.max(heights_x)
    max_width_x = np.max(widths_x)
    maxlen_y = np.max(lengths_y) + 1
    x = np.zeros((n_samples, options['input_channels'], max_height_x, max_width_x)).astype(np.float32)
    y = np.zeros((maxlen_y, n_samples)).astype(np.int64)  # <eos> must be 0 in the dict
    x_mask = np.zeros((n_samples, max_height_x, max_width_x)).astype(np.float32)
    y_mask = np.zeros((maxlen_y, n_samples)).astype(np.float32)
    for idx, [s_x, s_y] in enumerate(zip(images_x, seqs_y)):
        x[idx, :, :heights_x[idx], :widths_x[idx]] = s_x / 255.
        x_mask[idx, :heights_x[idx], :widths_x[idx]] = 1.
        y[:lengths_y[idx], idx] = s_y
        y_mask[:lengths_y[idx] + 1, idx] = 1.
    return x, x_mask, y, y_mask
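
# Example (sketch): turn one (feature_batch, label_batch) pair produced by
# dataIterator into padded arrays; only 'input_channels' is read from options.
#
#   x, x_mask, y, y_mask = prepare_data({'input_channels': 1},
#                                       feature_batch, label_batch)
#   x = torch.from_numpy(x)   # (n_samples, channels, max_H, max_W), scaled to [0, 1]
#   y = torch.from_numpy(y)   # (maxlen_y, n_samples); positions past a label are 0 (<eos>)
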
# beam search
def gen_sample(model, x, params, gpu_flag, k=1, maxlen=30):
    sample = []
    sample_score = []
    sample_alpha = []

    live_k = 1
    dead_k = 0
    hyp_samples = [[]] * live_k
    hyp_scores = np.zeros(live_k).astype(np.float32)
    hyp_alpha_past = [[]] * live_k

    # encode the image; with DataParallel the model is wrapped in .module
    if gpu_flag:
        next_state, ctx0 = model.module.f_init(x)
    else:
        next_state, ctx0 = model.f_init(x)
    next_w = -1 * np.ones((1,)).astype(np.int64)  # -1 marks the first decoding step
    next_w = torch.from_numpy(next_w).cuda()
    next_alpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda()
    ctx0 = ctx0.cpu().numpy()

    for ii in range(maxlen):
        # replicate the encoder context for every live hypothesis
        ctx = np.tile(ctx0, [live_k, 1, 1, 1])
        ctx = torch.from_numpy(ctx).cuda()
        if gpu_flag:
            next_p, next_state, next_alpha_past, alpha = model.module.f_next(
                params, next_w, None, ctx, None, next_state, next_alpha_past, True)
        else:
            next_p, next_state, next_alpha_past, alpha = model.f_next(
                params, next_w, None, ctx, None, next_state, next_alpha_past, True)
        next_p = next_p.cpu().numpy()
        next_state = next_state.cpu().numpy()
        next_alpha_past = next_alpha_past.cpu().numpy()

        # expand every live hypothesis with every word; keep the k - dead_k best
        cand_scores = hyp_scores[:, None] - np.log(next_p)
        cand_flat = cand_scores.flatten()
        ranks_flat = cand_flat.argsort()[:(k - dead_k)]
        voc_size = next_p.shape[1]
        trans_indices = ranks_flat // voc_size
        word_indices = ranks_flat % voc_size
        costs = cand_flat[ranks_flat]

        new_hyp_samples = []
        new_hyp_scores = np.zeros(k - dead_k).astype(np.float32)
        new_hyp_states = []
        new_hyp_alpha_past = []
        for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
            new_hyp_samples.append(hyp_samples[ti] + [wi])
            new_hyp_scores[idx] = copy.copy(costs[idx])
            new_hyp_states.append(copy.copy(next_state[ti]))
            new_hyp_alpha_past.append(hyp_alpha_past[ti] + [copy.copy(next_alpha_past[ti])])

        # hypotheses ending in <eos> (id 0) are finished; the rest stay live
        new_live_k = 0
        hyp_samples = []
        hyp_scores = []
        hyp_states = []
        hyp_alpha_past = []
        for idx in range(len(new_hyp_samples)):
            if new_hyp_samples[idx][-1] == 0:
                sample.append(new_hyp_samples[idx])
                sample_score.append(new_hyp_scores[idx])
                sample_alpha.append(new_hyp_alpha_past[idx])
                dead_k += 1
            else:
                new_live_k += 1
                hyp_samples.append(new_hyp_samples[idx])
                hyp_scores.append(new_hyp_scores[idx])
                hyp_states.append(new_hyp_states[idx])
                hyp_alpha_past.append(new_hyp_alpha_past[idx])
        hyp_scores = np.array(hyp_scores)
        live_k = new_live_k

        # stop when no hypothesis is live or the beam is exhausted
        if new_live_k < 1:
            break
        if dead_k >= k:
            break

        next_w = np.array([w[-1] for w in hyp_samples])
        next_state = np.array(hyp_states)
        next_alpha_past = np.array([w[-1] for w in hyp_alpha_past])
        next_w = torch.from_numpy(next_w).cuda()
        next_state = torch.from_numpy(next_state).cuda()
        next_alpha_past = torch.from_numpy(next_alpha_past).cuda()

    return sample, sample_score, sample_alpha
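
# Example (sketch): decode one padded image with a beam of 5 and map the best
# hypothesis back to characters; 'model', 'params', and 'mapping.txt' are
# assumed to come from the surrounding training/inference code, and scores are
# summed negative log-probabilities (lower is better).
#
#   lexicon, _ = load_mapping('mapping.txt')
#   samples, scores, alphas = gen_sample(model, x, params, gpu_flag=False, k=5, maxlen=100)
#   best = samples[int(np.argmin(scores))]
#   text = ''.join(lexicon[str(w)] for w in best if w != 0)  # drop the <eos> id
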
# init model params
def weight_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.)
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.)
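
# Typical usage: apply recursively to every submodule before training.
#
#   model.apply(weight_init)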