* Tmp commit. Working on whole document parsing

This commit is contained in:
Matthew Honnibal 2015-05-24 02:49:56 +02:00
parent 983d954ef4
commit 20f1d868a3
7 changed files with 145 additions and 79 deletions

View File

@ -1,29 +1,28 @@
"""Align the raw sentences from Read et al (2012) to the PTB tokenization,
outputing the format:
[{
section: int,
file: string,
paragraphs: [{
raw: string,
segmented: string,
tokens: [int]}]}]
outputting as a .json file. Used in bin/prepare_treebank.py
"""
import plac
from pathlib import Path
import json
from os import path
import os
from spacy.munge import read_ptb
from spacy.munge.read_ontonotes import sgml_extract
def read_unsegmented(section_loc):
def read_odc(section_loc):
# Arbitrary patches applied to the _raw_ text to promote alignment.
patches = (
('. . . .', '...'),
('....', '...'),
('Co..', 'Co.'),
("`", "'"),
# OntoNotes specific
(" S$", " US$"),
("Showtime or a sister service", "Showtime or a service"),
("The hotel and gaming company", "The hotel and Gaming company"),
("I'm-coming-down-your-throat", "I-'m coming-down-your-throat"),
)
paragraphs = []
@ -48,6 +47,7 @@ def read_ptb_sec(ptb_sec_dir):
for loc in ptb_sec_dir.iterdir():
if not str(loc).endswith('parse') and not str(loc).endswith('mrg'):
continue
filename = loc.parts[-1].split('.')[0]
with loc.open() as file_:
text = file_.read()
sents = []
@ -55,7 +55,7 @@ def read_ptb_sec(ptb_sec_dir):
words, brackets = read_ptb.parse(parse_str, strip_bad_periods=True)
words = [_reform_ptb_word(word) for word in words]
string = ' '.join(words)
sents.append(string)
sents.append((filename, string))
files.append(sents)
return files
@ -77,20 +77,36 @@ def get_alignment(raw_by_para, ptb_by_file):
# These are list-of-lists, by paragraph and file respectively.
# Flatten them into a list of (outer_id, inner_id, item) triples
raw_sents = _flatten(raw_by_para)
ptb_sents = _flatten(ptb_by_file)
assert len(raw_sents) == len(ptb_sents)
ptb_sents = list(_flatten(ptb_by_file))
output = []
for (p_id, p_sent_id, raw), (f_id, f_sent_id, ptb) in zip(raw_sents, ptb_sents):
ptb_idx = 0
n_skipped = 0
skips = []
for (p_id, p_sent_id, raw) in raw_sents:
#print raw
if ptb_idx >= len(ptb_sents):
n_skipped += 1
continue
f_id, f_sent_id, (ptb_id, ptb) = ptb_sents[ptb_idx]
alignment = align_chars(raw, ptb)
if not alignment:
skips.append((ptb, raw))
n_skipped += 1
continue
ptb_idx += 1
sepped = []
for i, c in enumerate(ptb):
if alignment[i] is False:
sepped.append('<SEP>')
else:
sepped.append(c)
output.append((f_id, p_id, f_sent_id, ''.join(sepped)))
output.append((f_id, p_id, f_sent_id, (ptb_id, ''.join(sepped))))
if n_skipped + len(ptb_sents) != len(raw_sents):
for ptb, raw in skips:
print ptb
print raw
raise Exception
return output
@ -102,6 +118,8 @@ def _flatten(nested):
def align_chars(raw, ptb):
if raw.replace(' ', '') != ptb.replace(' ', ''):
return None
i = 0
j = 0
@ -124,16 +142,20 @@ def align_chars(raw, ptb):
def group_into_files(sents):
last_id = 0
last_fn = None
this = []
output = []
for f_id, p_id, s_id, sent in sents:
for f_id, p_id, s_id, (filename, sent) in sents:
if f_id != last_id:
output.append(this)
assert last_fn is not None
output.append((last_fn, this))
this = []
last_fn = filename
this.append((f_id, p_id, s_id, sent))
last_id = f_id
if this:
output.append(this)
assert last_fn is not None
output.append((last_fn, this))
return output
@ -145,7 +167,7 @@ def group_into_paras(sents):
if p_id != last_id and this:
output.append(this)
this = []
this.append((sent))
this.append(sent)
last_id = p_id
if this:
output.append(this)
@ -161,15 +183,57 @@ def get_sections(odc_dir, ptb_dir, out_dir):
yield odc_loc, ptb_sec, out_loc
def main(odc_dir, ptb_dir, out_dir):
def do_wsj(odc_dir, ptb_dir, out_dir):
for odc_loc, ptb_sec_dir, out_loc in get_sections(odc_dir, ptb_dir, out_dir):
raw_paragraphs = read_unsegmented(odc_loc)
raw_paragraphs = read_odc(odc_loc)
ptb_files = read_ptb_sec(ptb_sec_dir)
aligned = get_alignment(raw_paragraphs, ptb_files)
files = [group_into_paras(f) for f in group_into_files(aligned)]
files = [(fn, group_into_paras(sents))
for fn, sents in group_into_files(aligned)]
with open(out_loc, 'w') as file_:
json.dump(files, file_)
def do_web(src_dir, onto_dir, out_dir):
mapping = dict(line.split() for line in open(path.join(onto_dir, 'map.txt'))
if len(line.split()) == 2)
for annot_fn, src_fn in mapping.items():
if not annot_fn.startswith('eng'):
continue
ptb_loc = path.join(onto_dir, annot_fn + '.parse')
src_loc = path.join(src_dir, src_fn + '.sgm')
if path.exists(ptb_loc) and path.exists(src_loc):
src_doc = sgml_extract(open(src_loc).read())
ptb_doc = [read_ptb.parse(parse_str, strip_bad_periods=True)[0]
for parse_str in read_ptb.split(open(ptb_loc).read())]
print 'Found'
else:
print 'Miss'
def may_mkdir(parent, *subdirs):
if not path.exists(parent):
os.mkdir(parent)
for i in range(1, len(subdirs)):
directories = (parent,) + subdirs[:i]
subdir = path.join(*directories)
if not path.exists(subdir):
os.mkdir(subdir)
def main(odc_dir, onto_dir, out_dir):
may_mkdir(out_dir, 'wsj', 'align')
may_mkdir(out_dir, 'web', 'align')
#do_wsj(odc_dir, path.join(ontonotes_dir, 'wsj', 'orig'),
# path.join(out_dir, 'wsj', 'align'))
do_web(
path.join(onto_dir, 'data', 'english', 'metadata', 'context', 'wb', 'sel'),
path.join(onto_dir, 'data', 'english', 'annotations', 'wb'),
path.join(out_dir, 'web', 'align'))
if __name__ == '__main__':
plac.call(main)

View File

@ -12,7 +12,7 @@ def parse(sent_text, strip_bad_periods=False):
words = []
id_map = {}
for i, line in enumerate(sent_text.split('\n')):
word, tag, head, dep = line.split()
word, tag, head, dep = _parse_line(line)
id_map[i] = len(words)
if strip_bad_periods and words and _is_bad_period(words[-1], word):
continue
@ -40,3 +40,10 @@ def _is_bad_period(prev, period):
return True
def _parse_line(line):
pieces = line.split()
if len(pieces) == 4:
return pieces
else:
return pieces[1], pieces[3], pieces[5], pieces[6]

View File

@ -16,7 +16,12 @@ class Scorer(object):
@property
def tags_acc(self):
return ((self.tags_corr - self.mistokened) / (self.n_tokens - self.mistokened)) * 100
return (self.tags_corr / (self.n_tokens - self.mistokened)) * 100
@property
def token_acc(self):
return (self.mistokened / self.n_tokens) * 100
@property
def uas(self):
@ -42,17 +47,18 @@ class Scorer(object):
assert len(tokens) == len(gold)
for i, token in enumerate(tokens):
if gold.orths.get(token.idx) != token.orth_:
self.mistokened += 1
if token.orth_.isspace():
continue
if not self.skip_token(i, token, gold):
self.total += 1
if verbose:
print token.orth_, token.dep_, token.head.orth_, token.head.i == gold.heads[i]
print token.orth_, token.tag_, token.dep_, token.head.orth_, token.head.i == gold.heads[i]
if token.head.i == gold.heads[i]:
self.heads_corr += 1
self.labels_corr += token.dep_ == gold.labels[i]
self.tags_corr += token.tag_ == gold.tags[i]
self.n_tokens += 1
self.labels_corr += token.dep_.lower() == gold.labels[i].lower()
if gold.tags[i] != None:
self.tags_corr += token.tag_ == gold.tags[i]
self.n_tokens += 1
gold_ents = set((start, end, label) for (start, end, label) in gold.ents)
guess_ents = set((e.start, e.end, e.label_) for e in tokens.ents)
if verbose and gold_ents:
@ -71,4 +77,4 @@ class Scorer(object):
self.ents_fp += len(guess_ents - gold_ents)
def skip_token(self, i, token, gold):
return gold.labels[i] in ('P', 'punct')
return gold.labels[i] in ('P', 'punct') and gold.heads[i] != None

View File

@ -54,7 +54,7 @@ cdef class ArcEager(TransitionSystem):
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
CONSTITUENT: {}, ADJUST: {'': True}}
for raw_text, segmented, (ids, words, tags, heads, labels, iob), ctnts in gold_parses:
for raw_text, (ids, words, tags, heads, labels, iob), ctnts in gold_parses:
for child, head, label in zip(ids, heads, labels):
if label != 'ROOT':
if head < child:
@ -67,8 +67,12 @@ cdef class ArcEager(TransitionSystem):
cdef int preprocess_gold(self, GoldParse gold) except -1:
for i in range(gold.length):
gold.c_heads[i] = gold.heads[i]
gold.c_labels[i] = self.strings[gold.labels[i]]
if gold.heads[i] is None: # Missing values
gold.c_heads[i] = i
gold.c_labels[i] = self.strings['']
else:
gold.c_heads[i] = gold.heads[i]
gold.c_labels[i] = self.strings[gold.labels[i]]
for end, brackets in gold.brackets.items():
for start, label_strs in brackets.items():
gold.c_brackets[start][end] = 1

View File

@ -1,6 +1,8 @@
import numpy
import codecs
import json
import random
from spacy.munge.alignment import align
from libc.string cimport memset
@ -16,19 +18,15 @@ def read_json_file(loc):
labels = []
iob_ents = []
for token in paragraph['tokens']:
#print token['start'], token['orth'], token['head'], token['dep']
words.append(token['orth'])
ids.append(token['start'])
ids.append(token['id'])
tags.append(token['tag'])
heads.append(token['head'] if token['head'] >= 0 else token['start'])
heads.append(token['head'] if token['head'] >= 0 else token['id'])
labels.append(token['dep'])
iob_ents.append(token.get('iob_ent', 'O'))
iob_ents.append(token.get('iob_ent', '-'))
brackets = []
tokenized = [s.replace('<SEP>', ' ').split(' ')
for s in paragraph['segmented'].split('<SENT>')]
paragraphs.append((paragraph['raw'],
tokenized,
(ids, words, tags, heads, labels, _iob_to_biluo(iob_ents)),
paragraph.get('brackets', [])))
return paragraphs
@ -160,39 +158,24 @@ cdef class GoldParse:
self.c_brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.tags = [None] * len(tokens)
self.heads = [-1] * len(tokens)
self.labels = ['MISSING'] * len(tokens)
self.ner = ['O'] * len(tokens)
self.orths = {}
self.heads = [None] * len(tokens)
self.labels = [''] * len(tokens)
self.ner = ['-'] * len(tokens)
cand_to_gold = align([t.orth_ for t in tokens], annot_tuples[1])
gold_to_cand = align(annot_tuples[1], [t.orth_ for t in tokens])
idx_map = {token.idx: token.i for token in tokens}
self.ents = []
ent_start = None
ent_label = None
for idx, orth, tag, head, label, ner in zip(*annot_tuples):
self.orths[idx] = orth
if idx < tokens[0].idx:
for i, gold_i in enumerate(cand_to_gold):
if gold_i is None:
# TODO: What do we do for missing values again?
pass
elif idx > tokens[-1].idx:
break
elif idx in idx_map:
i = idx_map[idx]
self.tags[i] = tag
self.heads[i] = idx_map.get(head, -1)
self.labels[i] = label
self.tags[i] = tag
if ner == '-':
self.ner[i] = '-'
# Deal with inconsistencies in BILUO arising from tokenization
if ner[0] in ('B', 'U', 'O') and ent_start is not None:
self.ents.append((ent_start, i, ent_label))
ent_start = None
ent_label = None
if ner[0] in ('B', 'U'):
ent_start = i
ent_label = ner[2:]
if ent_start is not None:
self.ents.append((ent_start, self.length, ent_label))
else:
self.tags[i] = annot_tuples[2][gold_i]
self.heads[i] = gold_to_cand[annot_tuples[3][gold_i]]
self.labels[i] = annot_tuples[4][gold_i]
# TODO: Declare NER information MISSING if tokenization incorrect
for start, end, label in self.ents:
if start == (end - 1):
self.ner[start] = 'U-%s' % label
@ -203,11 +186,11 @@ cdef class GoldParse:
self.ner[end-1] = 'L-%s' % label
self.brackets = {}
for (start_idx, end_idx, label_str) in brackets:
if start_idx in idx_map and end_idx in idx_map:
start = idx_map[start_idx]
end = idx_map[end_idx]
self.brackets.setdefault(end, {}).setdefault(start, set())
for (gold_start, gold_end, label_str) in brackets:
start = gold_to_cand[gold_start]
end = gold_to_cand[gold_end]
if start is not None and end is not None:
self.brackets.setdefault(start, {}).setdefault(end, set())
self.brackets[end][start].add(label)
def __len__(self):

View File

@ -73,7 +73,7 @@ cdef class BiluoPushDown(TransitionSystem):
move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
OUT: {'': True}}
moves = ('M', 'B', 'I', 'L', 'U')
for (raw_text, toks, tuples, ctnt) in gold_tuples:
for (raw_text, tuples, ctnt) in gold_tuples:
ids, words, tags, heads, labels, biluo = tuples
for i, ner_tag in enumerate(biluo):
if ner_tag != 'O' and ner_tag != '-':

View File

@ -76,7 +76,9 @@ cdef class Tokenizer:
cdef bint in_ws = Py_UNICODE_ISSPACE(chars[0])
cdef UniStr span
for i in range(1, length):
if Py_UNICODE_ISSPACE(chars[i]) != in_ws:
# TODO: Allow control of hyphenation
if (Py_UNICODE_ISSPACE(chars[i]) or chars[i] == '-') != in_ws:
#if Py_UNICODE_ISSPACE(chars[i]) != in_ws:
if start < i:
slice_unicode(&span, chars, start, i)
cache_hit = self._try_cache(start, span.key, tokens)