diff --git a/bin/prepare_treebank.py b/bin/prepare_treebank.py new file mode 100644 index 000000000..1de2dfdee --- /dev/null +++ b/bin/prepare_treebank.py @@ -0,0 +1,113 @@ +"""Convert OntoNotes into a json format. + +doc: { + id: string, + paragraphs: [{ + raw: string, + segmented: string, + sents: [int], + tokens: [{ + start: int, + tag: string, + head: int, + dep: string}], + brackets: [{ + start: int, + end: int, + label: string, + flabel: int}]}]} +""" +import plac +import json +from os import path +import re + +from spacy.munge import read_ptb +from spacy.munge import read_conll + + +def _iter_raw_files(raw_loc): + files = json.load(open(raw_loc)) + for f in files: + yield f + + +def _get_word_indices(raw_sent, word_idx, offset): + indices = {} + for piece in raw_sent.split(''): + for match in re.finditer(r'\S+', piece): + indices[word_idx] = offset + match.start() + word_idx += 1 + offset += len(piece) + return indices, word_idx, offset + + +def format_doc(section, filename, raw_paras, ptb_loc, dep_loc): + ptb_sents = read_ptb.split(open(ptb_loc).read()) + dep_sents = read_conll.split(open(dep_loc).read()) + + assert len(ptb_sents) == len(dep_sents) + + word_idx = 0 + offset = 0 + i = 0 + doc = {'id': 'wsj_%s%s' % (section, filename), 'paragraphs': []} + for raw_sents in raw_paras: + para = {'raw': ' '.join(sent.replace('', '') for sent in raw_sents), + 'segmented': ''.join(raw_sents), + 'sents': [], + 'tokens': [], + 'brackets': []} + for raw_sent in raw_sents: + para['sents'].append(offset) + _, brackets = read_ptb.parse(ptb_sents[i]) + _, annot = read_conll.parse(dep_sents[i]) + indices, word_idx, offset = _get_word_indices(raw_sent, 0, offset) + + for token in annot: + if token['head'] == -1: + head = indices[token['id']] + else: + head = indices[token['head']] + try: + para['tokens'].append({'start': indices[token['id']], + 'tag': token['tag'], + 'head': head, + 'dep': token['dep']}) + except: + print sorted(indices.items()) + print token + print raw_sent + raise + for label, start, end in brackets: + para['brackets'].append({'label': label, + 'start': indices[start], + 'end': indices[end-1]}) + i += 1 + doc['paragraphs'].append(para) + return doc + + +def main(onto_dir, raw_dir, out_loc): + docs = [] + for i in range(25): + section = str(i) if i >= 10 else ('0' + str(i)) + raw_loc = path.join(raw_dir, 'wsj%s.json' % section) + for j, raw_paras in enumerate(_iter_raw_files(raw_loc)): + if section == '00': + j += 1 + filename = str(j) if j >= 9 else ('0' + str(j)) + if section == '04' and filename == '55': + continue + ptb_loc = path.join(onto_dir, section, 'wsj_%s%s.parse' % (section, filename)) + dep_loc = ptb_loc + '.dep' + if path.exists(ptb_loc) and path.exists(dep_loc): + print ptb_loc + doc = format_doc(section, filename, raw_paras, ptb_loc, dep_loc) + docs.append(doc) + json.dump(docs, open(out_loc, 'w')) + + +if __name__ == '__main__': + plac.call(main) +