350 lines
12 KiB
Python
350 lines
12 KiB
Python
|
"""
|
||
|
Copyright (c) 2019-present NAVER Corp.
|
||
|
MIT License
|
||
|
"""
|
||
|
|
||
|
# -*- coding: cp932 -*-
|
||
|
import sys
|
||
|
import os
|
||
|
import time
|
||
|
import argparse
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.backends.cudnn as cudnn
|
||
|
from torch.autograd import Variable
|
||
|
|
||
|
from PIL import Image, ImageDraw, ImageFont
|
||
|
from utils import dataIterator, load_dict, gen_sample, load_mapping
|
||
|
from encoder_decoder import Encoder_Decoder
|
||
|
|
||
|
import cv2
|
||
|
from skimage import io
|
||
|
import numpy as np
|
||
|
import craft_utils
|
||
|
import imgproc
|
||
|
import file_utils
|
||
|
import json
|
||
|
import zipfile
|
||
|
import xml.etree.cElementTree as ET
|
||
|
import xml.dom.minidom as minidom
|
||
|
import codecs
|
||
|
from craft import CRAFT
|
||
|
|
||
|
from collections import OrderedDict
|
||
|
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` component removed.

    Models wrapped in ``torch.nn.DataParallel`` save every key with a
    ``module.`` prefix. Whether to strip it is decided from the first key
    only; when it applies, the first dotted component of every key is dropped.

    Args:
        state_dict: Mapping of parameter names to values (values are copied as-is).

    Returns:
        A new ``OrderedDict`` with the same values in the same order.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        # Nothing to inspect; the original indexed the first key and crashed here.
        return new_state_dict
    first_key = next(iter(state_dict))
    # Compare the whole first component so keys such as "module_list.0.weight"
    # are not mistaken for DataParallel-prefixed keys.
    start_idx = 1 if first_key.split(".", 1)[0] == "module" else 0
    for k, v in state_dict.items():
        new_state_dict[".".join(k.split(".")[start_idx:])] = v
    return new_state_dict
|
||
|
|
||
|
def str2bool(v):
    """Interpret a command-line string as a boolean.

    The comparison is case-insensitive: "yes", "y", "true", "t" and "1" give
    True; every other string gives False.
    """
    normalized = v.lower()
    return normalized in ("yes", "y", "true", "t", "1")
|
||
|
|
||
|
def pil2cv(imgPIL):
    """Convert a 3-channel PIL image (or array-like) to an OpenCV-style array.

    The channel axis is reversed (RGB -> BGR for an RGB input). The result is
    a fresh C-contiguous array: plain ``[:, :, ::-1]`` slicing yields a
    negative-stride view, which several OpenCV drawing functions reject.

    Args:
        imgPIL: A PIL ``Image`` or anything ``np.array`` accepts that has a
            trailing channel axis.

    Returns:
        A new contiguous ``numpy`` array with the last axis reversed; the
        element dtype is whatever ``np.array(imgPIL)`` produces.
    """
    # Convert once (the original built an unused uint8 copy as well); np.array
    # copies, so the returned array never aliases the caller's data.
    arr = np.array(imgPIL)
    return np.ascontiguousarray(arr[:, :, ::-1])
|
||
|
|
||
|
def cv2pil(imgCV):
    """Wrap an OpenCV-style array as a PIL image with its channel order reversed.

    ``imgCV`` must have at least three axes (height, width, channels); the
    last axis is flipped (BGR -> RGB for an OpenCV BGR image) before the
    array is handed to ``Image.fromarray``.
    """
    return Image.fromarray(imgCV[:, :, ::-1])
|
||
|
def cv2_putChar(draw, char, x, y, fontPIL, colorRGB):
    """Render text (one character in this file's usage) onto a PIL drawing context.

    Args:
        draw: ``ImageDraw.Draw`` instance (anything with a compatible ``text`` method).
        char: Text to draw.
        x, y: Position passed as ``xy`` to ``draw.text``.
        fontPIL: Font object passed through as ``font``.
        colorRGB: Fill colour passed through as ``fill``.
    """
    anchor_xy = (x, y)
    draw.text(xy=anchor_xy, text=char, fill=colorRGB, font=fontPIL)
|
||
|
|
||
|
def _glyph_size(draw, char, font):
    """Return the ``(width, height)`` of ``char`` rendered with ``font``.

    ``ImageDraw.textsize`` was deprecated in Pillow 9.2 and removed in
    Pillow 10, so ``textbbox`` (Pillow >= 8.0) is preferred. Measuring from
    the origin and taking the right/bottom edges includes the glyph offset,
    which approximates what ``textsize`` returned. Pillow versions without
    ``textbbox`` fall back to ``textsize``.
    """
    if hasattr(draw, "textbbox"):
        _, _, right, bottom = draw.textbbox((0, 0), char, font=font)
        return right, bottom
    return draw.textsize(char, font=font)


def cv2_putText_1(img, text, org, fontFace, fontScale, color):
    """Draw ``text`` character by character next to a detected line box.

    The image is converted to PIL (channel axis reversed), the text is drawn,
    and the result is converted back (channel axis reversed again). As a
    result, ``color`` is interpreted in the reverse channel order of ``img``
    (e.g. BGR when ``img`` is RGB).

    Boxes at least as wide as they are tall are treated as horizontal lines.
    Glyphs are drawn left to right at ``y = max_y`` starting from ``min_x``.
    Other boxes are treated as vertical lines: glyphs are drawn top to bottom
    from ``min_y`` at ``x = max_x - 10``.

    Args:
        img: 3-channel image array.
        text: String to draw.
        org: ``(min_x, max_x, min_y, max_y)`` of the line box.
        fontFace: Path to a TrueType font file.
        fontScale: Font size passed to ``ImageFont.truetype``.
        color: Fill colour (see channel-order note above).

    Returns:
        A new image array with the text drawn on it.
    """
    min_x, max_x, min_y, max_y = org

    imgPIL = cv2pil(img)
    draw = ImageDraw.Draw(imgPIL)
    fontPIL = ImageFont.truetype(font=fontFace, size=fontScale)
    gap = 10  # extra pixels between consecutive characters

    if max_x - min_x >= max_y - min_y:
        # horizontal line
        x, y = min_x, max_y
        for char in text:
            cv2_putChar(draw, char, x, y, fontPIL, color)
            w, _ = _glyph_size(draw, char, fontPIL)
            x += w + gap
    else:
        # vertical line
        x, y = max_x - 10, min_y
        for char in text:
            cv2_putChar(draw, char, x, y, fontPIL, color)
            _, h = _glyph_size(draw, char, fontPIL)
            y += h + gap

    return pil2cv(imgPIL)
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
parser = argparse.ArgumentParser(description='Kindai document Recognition')

# params for text detection
parser.add_argument('--trained_model', default='./pretrain/synweights_4600.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1000, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=2, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=True, action='store_true', help='show processing time')
# '--show_time' defaults to True with store_true, so it can never be switched
# off; '--hide_time' writes the same destination and provides the way to disable it.
parser.add_argument('--hide_time', dest='show_time', action='store_false', help='do not show processing time')
# The default matches the folder that was previously hard-coded below.
parser.add_argument('--test_folder', default='./data/test', type=str, help='folder path to input images')
parser.add_argument('--result_folder', default='./data/result1/', type=str, help='folder path for result images')

# params for text recognition
parser.add_argument('--model_path', default='./pretrain/WAP_params.pkl', type=str)
parser.add_argument('--dictionary_target', default='./pretrain/kindai_voc.txt', type=str)

args = parser.parse_args()

# For test images in a folder
image_list, _, _ = file_utils.get_files(args.test_folder)

# Guarantee a trailing separator: result paths are built as result_folder + "res_" + name.
result_folder = os.path.join(args.result_folder, '')
# makedirs also creates missing parents (os.mkdir failed when ./data was absent).
os.makedirs(result_folder, exist_ok=True)
|
||
|
|
||
|
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
    """Detect text regions in ``image`` with the CRAFT network.

    Reads ``canvas_size``, ``mag_ratio`` and ``show_time`` from the
    module-level ``args``.

    Args:
        net: CRAFT model (possibly wrapped in ``torch.nn.DataParallel``).
        image: Input image array, passed to ``imgproc.resize_aspect_ratio``.
        text_threshold: Text confidence threshold for ``craft_utils.getDetBoxes``.
        link_threshold: Link confidence threshold for ``craft_utils.getDetBoxes``.
        low_text: Text low-bound score for ``craft_utils.getDetBoxes``.
        cuda: If true, the input tensor is moved to the GPU.
        poly: Forwarded to ``getDetBoxes`` to request polygon output.

    Returns:
        ``(boxes, polys, ret_score_text)``:
            - boxes and polygons rescaled by ``1 / target_ratio``, with any
              ``None`` polygon replaced by its box;
            - a heatmap image of the text and link score maps placed side by side.
    """
    t0 = time.time()

    # resize
    img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing: [h, w, c] -> [c, h, w] -> [b, c, h, w]
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)
    if cuda:
        x = x.cuda()

    # forward pass; no_grad avoids building an autograd graph during inference
    with torch.no_grad():
        y, _ = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().numpy()
    score_link = y[0, :, :, 1].cpu().numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

    # coordinate adjustment back to the input image's scale
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k, p in enumerate(polys):
        if p is None:
            polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional): text map and link map side by side
    ret_score_text = imgproc.cvt2HeatmapImg(np.hstack((score_text, score_link)))

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
|
||
|
|
||
|
|
||
|
|
||
|
def _load_weights(model, checkpoint_path, cuda):
    """Load ``checkpoint_path`` into ``model`` after stripping a ``module.`` prefix via copyStateDict.

    When ``cuda`` is false the tensors are mapped onto the CPU, so a
    GPU-saved checkpoint can still be loaded.
    """
    map_location = None if cuda else 'cpu'
    model.load_state_dict(copyStateDict(torch.load(checkpoint_path, map_location=map_location)))


def _ocr_input_size(w, h):
    """Return ``(rate, new_w, new_h)`` used to resize a line crop before recognition.

    The shorter side is scaled to 20 px. The longer side is scaled by the same
    ``rate`` and rounded to the nearest multiple of 20 px.
    """
    if w < h:
        rate = 20.0 / w
        return rate, int(round(w * rate)), int(round(h * rate / 20.0) * 20)
    rate = 20.0 / h
    return rate, int(round(w * rate / 20.0) * 20), int(round(h * rate))


def test(text_detection_modelpara, ocr_modelpara, dictionary_target):
    """Detect and recognize text lines in every image of the module-level ``image_list``.

    For each image:
        - CRAFT (``test_net``) finds line boxes.
        - Each crop is recognized with ``Encoder_Decoder`` via ``gen_sample``.
        - The recognized text is drawn onto the image.
        - The image is saved with ``file_utils.saveResult`` into ``result_folder``.

    All recognized lines are collected into one XML document written to
    ``./data/result.xml`` (Shift_JIS encoded).

    Args:
        text_detection_modelpara: Path to the CRAFT checkpoint.
        ocr_modelpara: Path to the recognition model checkpoint.
        dictionary_target: Vocabulary file passed to ``load_dict``.
    """
    # load text detection net
    net = CRAFT()
    print('Loading text detection model from checkpoint {}'.format(text_detection_modelpara))
    _load_weights(net, text_detection_modelpara, args.cuda)
    if args.cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False

    # Recognition hyper-parameters; presumably must match the checkpoint — confirm before changing.
    params = {
        'n': 256,
        'm': 256,
        'dim_attention': 512,
        'D': 684,
        'K': 5748,
        'growthRate': 24,
        'reduction': 0.5,
        'bottleneck': True,
        'use_dropout': True,
        'input_channels': 3,
        'cuda': args.cuda,
    }

    # load recognition model
    OCR = Encoder_Decoder(params)
    _load_weights(OCR, ocr_modelpara, args.cuda)
    if args.cuda:
        # NOTE(review): OCR.cuda() is not called (it was commented out in the
        # original), so the model is wrapped without being moved to the GPU.
        # Presumably gen_sample handles device placement — confirm.
        OCR = torch.nn.DataParallel(OCR)
        cudnn.benchmark = False

    OCR.eval()
    net.eval()

    # load dictionary and build the index -> token lookup
    worddicts = load_dict(dictionary_target)
    worddicts_r = [None] * len(worddicts)
    for kk, vv in worddicts.items():
        worddicts_r[vv] = kk
    t = time.time()

    fontPIL = '/usr/share/fonts/truetype/fonts-japanese-gothic.ttf'  # japanese font
    size = 40
    colorBGR = (0, 0, 255)

    paper = ET.Element('paper')
    paper.set('xmlns', "http://codh.rois.ac.jp/modern-magazine/")

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list), image_path), end='\r')

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None for unreadable files; skip instead of crashing in cvtColor.
            print('\nskip unreadable image: {}'.format(image_path))
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[0], image.shape[1]
        print(image_path)

        page = ET.SubElement(paper, "page")
        # NOTE(review): replace() strips '.jpg' anywhere in the name and keeps
        # other extensions; os.path.splitext may be what was intended.
        page.set('file', os.path.basename(image_path).replace('.jpg', ''))
        page.set('height', str(h))
        page.set('width', str(w))
        page.set('dpi', str(100))
        page.set('number', str(1))

        bboxes, polys, _ = test_net(net, image, args.text_threshold, args.link_threshold,
                                    args.low_text, args.cuda, args.poly)
        # Per-line texts and attention-derived character positions; collected but not written out.
        texts = []
        locations = []
        for box in bboxes:
            poly = np.array(box).astype(np.int32)
            min_x = max(np.min(poly[:, 0]), 0)
            max_x = np.max(poly[:, 0])
            min_y = max(np.min(poly[:, 1]), 0)
            max_y = np.max(poly[:, 1])

            input_img = image[min_y:max_y, min_x:max_x]
            if input_img.size == 0:
                # Degenerate box: cv2.resize would raise on an empty crop.
                continue

            w = max_x - min_x + 1
            h = max_y - min_y + 1
            line = ET.SubElement(page, "line")
            line.set("x", str(min_x))
            line.set("y", str(min_y))
            line.set("height", str(h))
            line.set("width", str(w))

            rate, w, h = _ocr_input_size(w, h)
            input_img = cv2.resize(input_img, (w, h))

            # Grayscale with ITU-R BT.601 luma weights (image is RGB here), shape (1, H, W).
            mat = np.zeros([1, h, w], dtype='uint8')
            mat[0, :, :] = 0.299 * input_img[:, :, 0] + 0.587 * input_img[:, :, 1] + 0.114 * input_img[:, :, 2]

            xx_pad = mat.astype(np.float32) / 255.
            xx_pad = torch.from_numpy(xx_pad[None, :, :, :])  # (1, 1, H, W)
            # NOTE(review): the original called ``xx_pad.cuda()`` without assigning
            # the result, so the input always stayed on the CPU. That no-op copy
            # is dropped. Assigning it might cause a device mismatch with the
            # un-moved OCR model — confirm how gen_sample places tensors first.
            with torch.no_grad():
                sample, score, alpha_past_list = gen_sample(OCR, xx_pad, params, args.cuda, k=10, maxlen=600)

            # Choose the hypothesis with the lowest length-normalised score.
            score = score / np.array([len(s) for s in sample])
            best = score.argmin()
            ss = sample[best]
            alpha_past = alpha_past_list[best]

            chars = []
            location = []
            for step, vv in enumerate(ss):
                if vv == 0:  # <eol>
                    break
                # alpha_past appears to accumulate attention; the per-step map is
                # the difference from the previous step.
                alpha = alpha_past[step]
                if step != 0:
                    alpha = alpha_past[step] - alpha_past[step - 1]
                (y, x) = np.unravel_index(np.argmax(alpha, axis=None), alpha.shape)
                # Map the attention cell back to page coordinates (factor 16 per
                # cell, undone by the resize rate).
                location.append([int(16 * x / rate) + min_x, int(16 * y / rate) + min_y])
                chars.append(chr(int(worddicts_r[vv], 16)))
            result = ''.join(chars)

            line.text = result
            texts.append(result)
            locations.append(location)
            image = cv2_putText_1(img=image, text=result, org=(min_x, max_x, min_y, max_y),
                                  fontFace=fontPIL, fontScale=size, color=colorBGR)

        print('save image')
        file_utils.saveResult(image_path, image, polys, dirname=result_folder)

    # ET.tostring already returns Shift_JIS-encoded bytes; write them directly
    # (the file is closed even if the write fails).
    xml_string = ET.tostring(paper, 'Shift_JIS')
    with open('./data/result.xml', 'wb') as fout:
        fout.write(xml_string)

    print("elapsed time : {}s".format(time.time() - t))
|
||
|
|
||
|
if __name__ == "__main__":
    # Run detection + recognition using the checkpoints and dictionary from the command line.
    test(
        text_detection_modelpara=args.trained_model,
        ocr_modelpara=args.model_path,
        dictionary_target=args.dictionary_target,
    )
|