spaCy/spacy/syntax/conll.pyx

246 lines
7.9 KiB
Cython
Raw Normal View History

2015-03-09 05:46:22 +00:00
import numpy
import codecs
2015-05-06 14:27:31 +00:00
import json
2015-03-09 05:46:22 +00:00
from libc.string cimport memset
2015-02-21 16:06:58 +00:00
2015-05-06 14:27:31 +00:00
def read_json_file(loc):
paragraphs = []
for doc in json.load(open(loc)):
for paragraph in doc['paragraphs']:
words = []
ids = []
tags = []
heads = []
labels = []
iob_ents = []
for token in paragraph['tokens']:
2015-05-07 20:53:08 +00:00
#print token['start'], token['orth'], token['head'], token['dep']
2015-05-06 14:27:31 +00:00
words.append(token['orth'])
ids.append(token['start'])
tags.append(token['tag'])
2015-05-07 20:53:08 +00:00
heads.append(token['head'] if token['head'] >= 0 else token['start'])
2015-05-06 14:27:31 +00:00
labels.append(token['dep'])
iob_ents.append(token.get('iob_ent', 'O'))
brackets = []
tokenized = [s.replace('<SEP>', ' ').split(' ')
for s in paragraph['segmented'].split('<SENT>')]
paragraphs.append((paragraph['raw'],
tokenized,
(ids, words, tags, heads, labels, _iob_to_biluo(iob_ents)),
2015-05-11 14:12:03 +00:00
paragraph.get('brackets', [])))
2015-05-06 14:27:31 +00:00
return paragraphs
def read_conll03_file(loc):
sents = []
text = codecs.open(loc, 'r', 'utf8').read().strip()
for doc in text.split('-DOCSTART- -X- O O'):
doc = doc.strip()
if not doc:
continue
for sent_str in doc.split('\n\n'):
words = []
tags = []
iob_ents = []
ids = []
lines = sent_str.strip().split('\n')
idx = 0
for line in lines:
word, tag, chunk, iob = line.split()
if tag == '"':
tag = '``'
if '|' in tag:
tag = tag.split('|')[0]
words.append(word)
tags.append(tag)
iob_ents.append(iob)
ids.append(idx)
idx += len(word) + 1
heads = [-1] * len(words)
labels = ['ROOT'] * len(words)
sents.append((' '.join(words), [words],
(ids, words, tags, heads, labels, _iob_to_biluo(iob_ents))))
return sents
2015-03-09 05:46:22 +00:00
def read_docparse_file(loc):
sents = []
for sent_str in codecs.open(loc, 'r', 'utf8').read().strip().split('\n\n'):
2015-02-21 16:06:58 +00:00
words = []
heads = []
labels = []
tags = []
ids = []
2015-03-09 05:46:22 +00:00
iob_ents = []
2015-02-21 16:06:58 +00:00
lines = sent_str.strip().split('\n')
raw_text = lines.pop(0).strip()
tok_text = lines.pop(0).strip()
for i, line in enumerate(lines):
2015-03-09 05:46:22 +00:00
id_, word, pos_string, head_idx, label, iob_ent = _parse_line(line)
2015-02-21 16:06:58 +00:00
if label == 'root':
label = 'ROOT'
words.append(word)
if head_idx < 0:
head_idx = id_
ids.append(id_)
heads.append(head_idx)
labels.append(label)
tags.append(pos_string)
2015-03-09 05:46:22 +00:00
iob_ents.append(iob_ent)
tokenized = [s.replace('<SEP>', ' ').split(' ')
for s in tok_text.split('<SENT>')]
2015-05-06 14:27:31 +00:00
tuples = (ids, words, tags, heads, labels, iob_ents)
sents.append((raw_text, tokenized, tuples, []))
2015-03-09 05:46:22 +00:00
return sents
def _iob_to_biluo(tags):
out = []
curr_label = None
tags = list(tags)
while tags:
out.extend(_consume_os(tags))
out.extend(_consume_ent(tags))
return out
def _consume_os(tags):
while tags and tags[0] == 'O':
yield tags.pop(0)
def _consume_ent(tags):
if not tags:
return []
target = tags.pop(0).replace('B', 'I')
length = 1
while tags and tags[0] == target:
length += 1
tags.pop(0)
label = target[2:]
if length == 1:
return ['U-' + label]
else:
start = 'B-' + label
end = 'L-' + label
middle = ['I-%s' % label for _ in range(1, length - 1)]
return [start] + middle + [end]
def _parse_line(line):
pieces = line.split()
if len(pieces) == 4:
return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
else:
id_ = int(pieces[0])
word = pieces[1]
pos = pieces[3]
iob_ent = pieces[5]
head_idx = int(pieces[6])
label = pieces[7]
return id_, word, pos, head_idx, label, iob_ent
2015-03-09 05:46:22 +00:00
2015-03-09 05:46:22 +00:00
cdef class GoldParse:
2015-05-11 14:12:03 +00:00
def __init__(self, tokens, annot_tuples, brackets=(,)):
2015-03-09 05:46:22 +00:00
self.mem = Pool()
self.loss = 0
2015-02-23 19:04:53 +00:00
self.length = len(tokens)
2015-03-09 05:46:22 +00:00
# These are filled by the tagger/parser/entity recogniser
self.c_tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
2015-03-09 05:46:22 +00:00
self.c_heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
2015-05-11 14:12:03 +00:00
self.c_brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
for i in range(len(tokens)):
self.c_brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
2015-03-09 05:46:22 +00:00
self.tags = [None] * len(tokens)
self.heads = [-1] * len(tokens)
self.labels = ['MISSING'] * len(tokens)
self.ner = ['O'] * len(tokens)
self.orths = {}
2015-03-09 05:46:22 +00:00
idx_map = {token.idx: token.i for token in tokens}
self.ents = []
ent_start = None
ent_label = None
for idx, orth, tag, head, label, ner in zip(*annot_tuples):
self.orths[idx] = orth
2015-03-09 05:46:22 +00:00
if idx < tokens[0].idx:
pass
elif idx > tokens[-1].idx:
break
elif idx in idx_map:
i = idx_map[idx]
self.tags[i] = tag
2015-03-09 05:46:22 +00:00
self.heads[i] = idx_map.get(head, -1)
self.labels[i] = label
self.tags[i] = tag
if ner == '-':
self.ner[i] = '-'
# Deal with inconsistencies in BILUO arising from tokenization
if ner[0] in ('B', 'U', 'O') and ent_start is not None:
self.ents.append((ent_start, i, ent_label))
ent_start = None
ent_label = None
if ner[0] in ('B', 'U'):
ent_start = i
ent_label = ner[2:]
if ent_start is not None:
self.ents.append((ent_start, self.length, ent_label))
for start, end, label in self.ents:
if start == (end - 1):
self.ner[start] = 'U-%s' % label
else:
self.ner[start] = 'B-%s' % label
for i in range(start+1, end-1):
self.ner[i] = 'I-%s' % label
self.ner[end-1] = 'L-%s' % label
2015-05-11 14:12:03 +00:00
self.brackets = {}
for (start_idx, end_idx, label_str) in brackets:
if start_idx in idx_map and end_idx in idx_map:
start = idx_map[start_idx]
end = idx_map[end_idx]
self.brackets.setdefault(end, {}).setdefault(start, set())
self.brackets[end][start].add(label)
def __len__(self):
return self.length
2015-03-09 05:46:22 +00:00
@property
def n_non_punct(self):
return len([l for l in self.labels if l not in ('P', 'punct')])
2015-03-09 05:46:22 +00:00
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
n = 0
for i in range(self.length):
if not score_punct and self.labels_[i] not in ('P', 'punct'):
2015-03-09 05:46:22 +00:00
continue
if self.heads[i] == -1:
continue
n += (i + tokens[i].head) == self.heads[i]
return n
def is_correct(self, i, head):
return head == self.c_heads[i]
2015-02-21 16:06:58 +00:00
def is_punct_label(label):
return label == 'P' or label.lower() == 'punct'
def _map_indices_to_tokens(ids, heads):
mapped = []
for head in heads:
if head not in ids:
mapped.append(None)
else:
mapped.append(ids.index(head))
return mapped