add code for both GPU and CPU
This commit is contained in:
parent
7ff9246af5
commit
31e96013c5
File diff suppressed because one or more lines are too long
Binary file not shown.
Before Width: | Height: | Size: 4.0 MiB After Width: | Height: | Size: 4.0 MiB |
21
decoder.py
21
decoder.py
|
@ -6,6 +6,7 @@ import torch.nn as nn
|
||||||
class Gru_cond_layer(nn.Module):
|
class Gru_cond_layer(nn.Module):
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
super(Gru_cond_layer, self).__init__()
|
super(Gru_cond_layer, self).__init__()
|
||||||
|
self.cuda = params['cuda']
|
||||||
# attention
|
# attention
|
||||||
self.conv_Ua = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
|
self.conv_Ua = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
|
||||||
self.fc_Wa = nn.Linear(params['n'], params['dim_attention'], bias=False)
|
self.fc_Wa = nn.Linear(params['n'], params['dim_attention'], bias=False)
|
||||||
|
@ -44,16 +45,24 @@ class Gru_cond_layer(nn.Module):
|
||||||
|
|
||||||
if one_step:
|
if one_step:
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = torch.ones(embedding.shape[0]).cuda()
|
mask = torch.ones(embedding.shape[0])
|
||||||
|
if self.cuda:
|
||||||
|
mask.cuda()
|
||||||
h2ts, cts, alphas, alpha_pasts = self._step_slice(mask, state_below_r, state_below_z, state_below_h,
|
h2ts, cts, alphas, alpha_pasts = self._step_slice(mask, state_below_r, state_below_z, state_below_h,
|
||||||
init_state, context, context_mask, alpha_past, Ua_ctx)
|
init_state, context, context_mask, alpha_past, Ua_ctx)
|
||||||
else:
|
else:
|
||||||
alpha_past = torch.zeros(n_samples, context.shape[2], context.shape[3]).cuda()
|
alpha_past = torch.zeros(n_samples, context.shape[2], context.shape[3])
|
||||||
h2t = init_state
|
h2t = init_state
|
||||||
h2ts = torch.zeros(n_steps, n_samples, params['n']).cuda()
|
h2ts = torch.zeros(n_steps, n_samples, params['n'])
|
||||||
cts = torch.zeros(n_steps, n_samples, params['D']).cuda()
|
cts = torch.zeros(n_steps, n_samples, params['D'])
|
||||||
alphas = (torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3])).cuda()
|
alphas = (torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3]))
|
||||||
alpha_pasts = torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3]).cuda()
|
alpha_pasts = torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3])
|
||||||
|
if self.cuda:
|
||||||
|
alpha_past.cuda()
|
||||||
|
h2ts.cuda()
|
||||||
|
cts.cuda()
|
||||||
|
alphas.cuda()
|
||||||
|
alpha_pasts.cuda()
|
||||||
for i in range(n_steps):
|
for i in range(n_steps):
|
||||||
h2t, ct, alpha, alpha_past = self._step_slice(mask[i], state_below_r[i], state_below_z[i],
|
h2t, ct, alpha, alpha_past = self._step_slice(mask[i], state_below_r[i], state_below_z[i],
|
||||||
state_below_h[i], h2t, context, context_mask, alpha_past,
|
state_below_h[i], h2t, context, context_mask, alpha_past,
|
||||||
|
|
|
@ -21,14 +21,19 @@ class My_Embedding(nn.Module):
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
super(My_Embedding, self).__init__()
|
super(My_Embedding, self).__init__()
|
||||||
self.embedding = nn.Embedding(params['K'], params['m'])
|
self.embedding = nn.Embedding(params['K'], params['m'])
|
||||||
|
self.cuda = params['cuda']
|
||||||
|
|
||||||
def forward(self, params, y):
|
def forward(self, params, y):
|
||||||
if y.sum() < 0.:
|
if y.sum() < 0.:
|
||||||
emb = torch.zeros(1, params['m']).cuda()
|
emb = torch.zeros(1, params['m'])
|
||||||
|
if self.cuda:
|
||||||
|
emb.cuda()
|
||||||
else:
|
else:
|
||||||
emb = self.embedding(y)
|
emb = self.embedding(y)
|
||||||
if len(emb.shape) == 3: # only for training stage
|
if len(emb.shape) == 3: # only for training stage
|
||||||
emb_shifted = torch.zeros([emb.shape[0], emb.shape[1], params['m']], dtype=torch.float32).cuda()
|
emb_shifted = torch.zeros([emb.shape[0], emb.shape[1], params['m']], dtype=torch.float32)
|
||||||
|
if self.cuda:
|
||||||
|
emb_shifted.cuda()
|
||||||
emb_shifted[1:] = emb[:-1]
|
emb_shifted[1:] = emb[:-1]
|
||||||
emb = emb_shifted
|
emb = emb_shifted
|
||||||
return emb
|
return emb
|
||||||
|
@ -43,6 +48,7 @@ class Encoder_Decoder(nn.Module):
|
||||||
self.emb_model = My_Embedding(params)
|
self.emb_model = My_Embedding(params)
|
||||||
self.gru_model = Gru_cond_layer(params)
|
self.gru_model = Gru_cond_layer(params)
|
||||||
self.gru_prob_model = Gru_prob(params)
|
self.gru_prob_model = Gru_prob(params)
|
||||||
|
self.cuda = params['cuda']
|
||||||
|
|
||||||
def forward(self, params, x, x_mask, y, y_mask, one_step=False):
|
def forward(self, params, x, x_mask, y, y_mask, one_step=False):
|
||||||
# recover permute
|
# recover permute
|
||||||
|
|
|
@ -41,7 +41,7 @@ def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
print (np.shape(img))
|
#print (np.shape(img))
|
||||||
result = np.zeros((np.shape(img)[0], np.shape(img)[1]))
|
result = np.zeros((np.shape(img)[0], np.shape(img)[1]))
|
||||||
# make result file list
|
# make result file list
|
||||||
filename, file_ext = os.path.splitext(os.path.basename(img_file))
|
filename, file_ext = os.path.splitext(os.path.basename(img_file))
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
Python==3.7.7
|
||||||
torch==1.4.0
|
torch==1.4.0
|
||||||
torchvision==0.2.1
|
torchvision==0.2.1
|
||||||
opencv-python==3.4.2.17
|
opencv-python==3.4.2.17
|
||||||
scikit-image==0.14.2
|
scikit-image==0.14.2
|
||||||
scipy==1.1.0
|
scipy==1.1.0
|
||||||
Polygon3
|
Polygon3
|
||||||
|
pillow==4.3.0
|
||||||
|
|
||||||
|
|
12
test.py
12
test.py
|
@ -187,6 +187,7 @@ def test(text_detection_modelpara, ocr_modelpara, dictionary_target):
|
||||||
params['bottleneck'] = True
|
params['bottleneck'] = True
|
||||||
params['use_dropout'] = True
|
params['use_dropout'] = True
|
||||||
params['input_channels'] = 3
|
params['input_channels'] = 3
|
||||||
|
params['cuda'] = args.cuda
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
OCR = Encoder_Decoder(params)
|
OCR = Encoder_Decoder(params)
|
||||||
|
@ -194,9 +195,8 @@ def test(text_detection_modelpara, ocr_modelpara, dictionary_target):
|
||||||
OCR.load_state_dict(copyStateDict(torch.load(ocr_modelpara)))
|
OCR.load_state_dict(copyStateDict(torch.load(ocr_modelpara)))
|
||||||
else:
|
else:
|
||||||
OCR.load_state_dict(copyStateDict(torch.load(ocr_modelpara, map_location='cpu')))
|
OCR.load_state_dict(copyStateDict(torch.load(ocr_modelpara, map_location='cpu')))
|
||||||
|
|
||||||
if args.cuda:
|
if args.cuda:
|
||||||
OCR = OCR.cuda()
|
#OCR = OCR.cuda()
|
||||||
OCR = torch.nn.DataParallel(OCR)
|
OCR = torch.nn.DataParallel(OCR)
|
||||||
cudnn.benchmark = False
|
cudnn.benchmark = False
|
||||||
|
|
||||||
|
@ -288,9 +288,11 @@ def test(text_detection_modelpara, ocr_modelpara, dictionary_target):
|
||||||
mat[0,:,:] = 0.299* input_img[:, :, 0] + 0.587 * input_img[:, :, 1] + 0.114 * input_img[:, :, 2]
|
mat[0,:,:] = 0.299* input_img[:, :, 0] + 0.587 * input_img[:, :, 1] + 0.114 * input_img[:, :, 2]
|
||||||
|
|
||||||
xx_pad = mat.astype(np.float32) / 255.
|
xx_pad = mat.astype(np.float32) / 255.
|
||||||
xx_pad = torch.from_numpy(xx_pad[None, :, :, :]).cuda() # (1,1,H,W)
|
xx_pad = torch.from_numpy(xx_pad[None, :, :, :]) # (1,1,H,W)
|
||||||
|
if args.cuda:
|
||||||
|
xx_pad.cuda()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
sample, score, alpha_past_list = gen_sample(OCR, xx_pad, params, True, k=10, maxlen=600)
|
sample, score, alpha_past_list = gen_sample(OCR, xx_pad, params, args.cuda, k=10, maxlen=600)
|
||||||
score = score / np.array([len(s) for s in sample])
|
score = score / np.array([len(s) for s in sample])
|
||||||
ss = sample[score.argmin()]
|
ss = sample[score.argmin()]
|
||||||
alpha_past = alpha_past_list[score.argmin()]
|
alpha_past = alpha_past_list[score.argmin()]
|
||||||
|
@ -327,7 +329,7 @@ def test(text_detection_modelpara, ocr_modelpara, dictionary_target):
|
||||||
image = cv2_putText_1(img = image, text = result, org = (min_x, max_x, min_y, max_y), fontFace = fontPIL, fontScale = size, color = colorBGR)
|
image = cv2_putText_1(img = image, text = result, org = (min_x, max_x, min_y, max_y), fontFace = fontPIL, fontScale = size, color = colorBGR)
|
||||||
|
|
||||||
|
|
||||||
|
print('save image')
|
||||||
# save score text
|
# save score text
|
||||||
filename, file_ext = os.path.splitext(os.path.basename(image_path))
|
filename, file_ext = os.path.splitext(os.path.basename(image_path))
|
||||||
mask_file = result_folder + "/res_" + filename + '_mask.jpg'
|
mask_file = result_folder + "/res_" + filename + '_mask.jpg'
|
||||||
|
|
39
utils.py
39
utils.py
|
@ -114,23 +114,7 @@ def load_mapping(dictFile):
|
||||||
|
|
||||||
# create batch
|
# create batch
|
||||||
def prepare_data(options, images_x, seqs_y, prev_x = None):
|
def prepare_data(options, images_x, seqs_y, prev_x = None):
|
||||||
'''if prev_x!= None and len(images_x) == len(prev_x) and np.random.random_sample() > 0.7:
|
|
||||||
for i in range(len(images_x)):
|
|
||||||
#print(np.shape(images_x[i]))
|
|
||||||
images_x[i] = images_x[i]*0.7 + prev_x[i] * 0.3'''
|
|
||||||
|
|
||||||
'''if np.random.random_sample() > 0.7:
|
|
||||||
for i in range(len(images_x)):
|
|
||||||
#print(np.shape(images_x[i][0]))
|
|
||||||
if np.shape(images_x[i][0])[0] <= 100 or np.shape(images_x[i][0])[1] <= 100: continue
|
|
||||||
img = Image.fromarray(images_x[i][0])
|
|
||||||
img = sk.perform_operation([img])
|
|
||||||
img = br.perform_operation([img[0]])
|
|
||||||
#img[0].save(str(i) + '.jpg')
|
|
||||||
img = np.asarray(img[0])
|
|
||||||
images_x[i][0] = img
|
|
||||||
#print(np.shape(images_x[i][0]))'''
|
|
||||||
|
|
||||||
heights_x = [s.shape[1] for s in images_x]
|
heights_x = [s.shape[1] for s in images_x]
|
||||||
widths_x = [s.shape[2] for s in images_x]
|
widths_x = [s.shape[2] for s in images_x]
|
||||||
lengths_y = [len(s) for s in seqs_y]
|
lengths_y = [len(s) for s in seqs_y]
|
||||||
|
@ -167,14 +151,23 @@ def gen_sample(model, x, params, gpu_flag, k=1, maxlen=30):
|
||||||
else:
|
else:
|
||||||
next_state, ctx0 = model.f_init(x)
|
next_state, ctx0 = model.f_init(x)
|
||||||
next_w = -1 * np.ones((1,)).astype(np.int64)
|
next_w = -1 * np.ones((1,)).astype(np.int64)
|
||||||
next_w = torch.from_numpy(next_w).cuda()
|
next_w = torch.from_numpy(next_w)
|
||||||
next_alpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda()
|
next_alpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3])
|
||||||
ctx0 = ctx0.cpu().numpy()
|
ctx0 = ctx0.cpu().numpy()
|
||||||
|
|
||||||
|
if gpu_flag:
|
||||||
|
next_w.cuda()
|
||||||
|
next_alpha_past.cuda()
|
||||||
|
|
||||||
for ii in range(maxlen):
|
for ii in range(maxlen):
|
||||||
ctx = np.tile(ctx0, [live_k, 1, 1, 1])
|
ctx = np.tile(ctx0, [live_k, 1, 1, 1])
|
||||||
ctx = torch.from_numpy(ctx).cuda()
|
ctx = torch.from_numpy(ctx)
|
||||||
if gpu_flag:
|
if gpu_flag:
|
||||||
|
ctx.cuda()
|
||||||
|
next_w.cuda()
|
||||||
|
next_state.cuda()
|
||||||
|
next_alpha_past.cuda()
|
||||||
|
|
||||||
next_p, next_state, next_alpha_past, alpha = model.module.f_next(params, next_w, None, ctx, None, next_state,
|
next_p, next_state, next_alpha_past, alpha = model.module.f_next(params, next_w, None, ctx, None, next_state,
|
||||||
next_alpha_past, True)
|
next_alpha_past, True)
|
||||||
else:
|
else:
|
||||||
|
@ -235,9 +228,9 @@ def gen_sample(model, x, params, gpu_flag, k=1, maxlen=30):
|
||||||
#next_alpha_past = np.array(hyp_alpha_past)
|
#next_alpha_past = np.array(hyp_alpha_past)
|
||||||
next_alpha_past = np.array([w[-1] for w in hyp_alpha_past])
|
next_alpha_past = np.array([w[-1] for w in hyp_alpha_past])
|
||||||
#print (np.shape(next_alpha_past))
|
#print (np.shape(next_alpha_past))
|
||||||
next_w = torch.from_numpy(next_w).cuda()
|
next_w = torch.from_numpy(next_w)
|
||||||
next_state = torch.from_numpy(next_state).cuda()
|
next_state = torch.from_numpy(next_state)
|
||||||
next_alpha_past = torch.from_numpy(next_alpha_past).cuda()
|
next_alpha_past = torch.from_numpy(next_alpha_past)
|
||||||
return sample, sample_score, sample_alpha
|
return sample, sample_score, sample_alpha
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue