from __future__ import unicode_literals, print_function

import numpy
import io
import random
import re
import os
from os import path

from libc.string cimport memset
from cymem.cymem cimport Pool

try:
    import ujson as json
except ImportError:
    import json

from .syntax import nonproj


def tags_to_entities(tags):
    entities = []
    start = None
    for i, tag in enumerate(tags):
        if tag.startswith('O'):
            # TODO: We shouldn't be getting these malformed inputs. Fix this.
            if start is not None:
                start = None
            continue
        elif tag == '-':
            continue
        elif tag.startswith('I'):
            assert start is not None, tags[:i]
            continue
        if tag.startswith('U'):
            entities.append((tag[2:], i, i))
        elif tag.startswith('B'):
            start = i
        elif tag.startswith('L'):
            entities.append((tag[2:], start, i))
            start = None
        else:
            raise Exception(tag)
    return entities


def merge_sents(sents):
    m_deps = [[], [], [], [], [], []]
    m_brackets = []
    i = 0
    for (ids, words, tags, heads, labels, ner), brackets in sents:
        m_deps[0].extend(id_ + i for id_ in ids)
        m_deps[1].extend(words)
        m_deps[2].extend(tags)
        m_deps[3].extend(head + i for head in heads)
        m_deps[4].extend(labels)
        m_deps[5].extend(ner)
        m_brackets.extend((b['first'] + i, b['last'] + i, b['label'])
                          for b in brackets)
        i += len(ids)
    return [(m_deps, m_brackets)]


def align(cand_words, gold_words):
    cost, edit_path = _min_edit_path(cand_words, gold_words)
    alignment = []
    i_of_gold = 0
    for move in edit_path:
        if move == 'M':
            alignment.append(i_of_gold)
            i_of_gold += 1
        elif move == 'S':
            alignment.append(None)
            i_of_gold += 1
        elif move == 'D':
            alignment.append(None)
        elif move == 'I':
            i_of_gold += 1
        else:
            raise Exception(move)
    return alignment


punct_re = re.compile(r'\W')
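

# A minimal usage sketch for `align` (the token lists below are hypothetical,
# chosen only to illustrate the move codes): 'M' and 'S' consume a token on
# both sides, 'I' consumes only a gold token, 'D' only a candidate token.
# Candidate tokens with no one-to-one gold counterpart map to None:
#
#     align(['I', 'like', 'NewYork'], ['I', 'like', 'New', 'York'])
#     # edit path 'MMIS' -> alignment [0, 1, None]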


def _min_edit_path(cand_words, gold_words):
    cdef:
        Pool mem
        int i, j, n_cand, n_gold
        int* curr_costs
        int* prev_costs

    # TODO: Fix this --- just do it properly, make the full edit matrix and
    # then walk back over it...
    # Preprocess inputs
    cand_words = [punct_re.sub('', w) for w in cand_words]
    gold_words = [punct_re.sub('', w) for w in gold_words]

    if cand_words == gold_words:
        return 0, ''.join(['M' for _ in gold_words])
    mem = Pool()
    n_cand = len(cand_words)
    n_gold = len(gold_words)
    # Levenshtein distance, except we need the history, and we may want
    # different costs. Mark operations with a string, and score the history
    # using _edit_cost.
    previous_row = []
    prev_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
    curr_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
    for i in range(n_gold + 1):
        previous_row.append('I' * i)
        prev_costs[i] = i
    for i, cand in enumerate(cand_words):
        current_row = ['D' * (i + 1)]
        curr_costs[0] = i + 1
        for j, gold in enumerate(gold_words):
            if gold.lower() == cand.lower():
                s_cost = prev_costs[j]
                i_cost = curr_costs[j] + 1
                d_cost = prev_costs[j + 1] + 1
            else:
                s_cost = prev_costs[j] + 1
                i_cost = curr_costs[j] + 1
                # Deleting a candidate word that was emptied by the
                # punctuation strip is free.
                d_cost = prev_costs[j + 1] + (1 if cand else 0)

            if s_cost <= i_cost and s_cost <= d_cost:
                best_cost = s_cost
                # Only an exact (case-sensitive) match is recorded as 'M';
                # a case-insensitive match costs nothing but is marked 'S'.
                best_hist = previous_row[j] + ('M' if gold == cand else 'S')
            elif i_cost <= s_cost and i_cost <= d_cost:
                best_cost = i_cost
                best_hist = current_row[j] + 'I'
            else:
                best_cost = d_cost
                best_hist = previous_row[j + 1] + 'D'
            current_row.append(best_hist)
            curr_costs[j + 1] = best_cost
        previous_row = current_row
        for j in range(len(gold_words) + 1):
            prev_costs[j] = curr_costs[j]
            curr_costs[j] = 0
    return prev_costs[n_gold], previous_row[-1]


def read_json_file(loc, docs_filter=None):
    if path.isdir(loc):
        for filename in os.listdir(loc):
            yield from read_json_file(path.join(loc, filename), docs_filter)
    else:
        with open(loc) as file_:
            docs = json.load(file_)
        for doc in docs:
            if docs_filter is not None and not docs_filter(doc):
                continue
            for paragraph in doc['paragraphs']:
                sents = []
                for sent in paragraph['sentences']:
                    words = []
                    ids = []
                    tags = []
                    heads = []
                    labels = []
                    ner = []
                    for i, token in enumerate(sent['tokens']):
                        words.append(token['orth'])
                        ids.append(i)
                        tags.append(token.get('tag', '-'))
                        heads.append(token.get('head', 0) + i)
                        labels.append(token.get('dep', ''))
                        # Normalise any case variant of the ROOT label.
                        if labels[-1].lower() == 'root':
                            labels[-1] = 'ROOT'
                        ner.append(token.get('ner', '-'))
                    sents.append((
                        (ids, words, tags, heads, labels, ner),
                        sent.get('brackets', [])))
                if sents:
                    yield (paragraph.get('raw', None), sents)


def _iob_to_biluo(tags):
    out = []
    tags = list(tags)
    while tags:
        out.extend(_consume_os(tags))
        out.extend(_consume_ent(tags))
    return out


def _consume_os(tags):
    while tags and tags[0] == 'O':
        yield tags.pop(0)


def _consume_ent(tags):
    if not tags:
        return []
    target = tags.pop(0).replace('B', 'I')
    length = 1
    while tags and tags[0] == target:
        length += 1
        tags.pop(0)
    label = target[2:]
    if length == 1:
        return ['U-' + label]
    else:
        start = 'B-' + label
        end = 'L-' + label
        middle = ['I-%s' % label for _ in range(1, length - 1)]
        return [start] + middle + [end]
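

# A small worked example (hypothetical tags, traced by hand): IOB entities
# become BILUO spans, which `tags_to_entities` then turns into
# (label, start_token, end_token) triples:
#
#     _iob_to_biluo(['O', 'B-LOC', 'I-LOC', 'O', 'B-PER'])
#     # ['O', 'B-LOC', 'L-LOC', 'O', 'U-PER']
#     tags_to_entities(['O', 'B-LOC', 'L-LOC', 'O', 'U-PER'])
#     # [('LOC', 1, 2), ('PER', 4, 4)]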


cdef class GoldParse:
    def __init__(self, tokens, annot_tuples, make_projective=False):
        self.mem = Pool()
        self.loss = 0
        self.length = len(tokens)

        # These are filled by the tagger/parser/entity recogniser
        self.c.tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
        self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
        self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
        self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))

        self.tags = [None] * len(tokens)
        self.heads = [None] * len(tokens)
        self.labels = [''] * len(tokens)
        self.ner = ['-'] * len(tokens)

        self.cand_to_gold = align([t.orth_ for t in tokens], annot_tuples[1])
        self.gold_to_cand = align(annot_tuples[1], [t.orth_ for t in tokens])

        self.orig_annot = list(zip(*annot_tuples))

        words = [w.orth_ for w in tokens]
        for i, gold_i in enumerate(self.cand_to_gold):
            if words[i].isspace():
                self.tags[i] = 'SP'
                self.heads[i] = None
                self.labels[i] = None
                self.ner[i] = 'O'
            if gold_i is not None:
                self.tags[i] = annot_tuples[2][gold_i]
                self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]]
                self.labels[i] = annot_tuples[4][gold_i]
                self.ner[i] = annot_tuples[5][gold_i]

        cycle = nonproj.contains_cycle(self.heads)
        if cycle is not None:
            raise Exception("Cycle found: %s" % cycle)

        if make_projective:
            proj_heads, _ = nonproj.PseudoProjectivity.projectivize(
                self.heads, self.labels)
            self.heads = proj_heads

    def __len__(self):
        return self.length

    @property
    def is_projective(self):
        return not nonproj.is_nonproj_tree(self.heads)


def biluo_tags_from_offsets(doc, entities):
    '''Encode labelled spans into per-token tags, using the
    Begin/In/Last/Unit/Out scheme (BILUO).

    Arguments:
        doc (Doc): The document that the entity offsets refer to. The output
            tags will refer to the token boundaries within the document.
        entities (sequence): A sequence of (start, end, label) triples. start
            and end should be character-offset integers denoting the slice
            into the original string.

    Returns:
        tags (list): A list of unicode strings, describing the tags. Each tag
            string will be of the form either "", "O" or "{action}-{label}",
            where action is one of "B", "I", "L", "U". The empty string "" is
            used where the entity offsets don't align with the tokenization
            in the Doc object. The training algorithm will view these as
            missing values. "O" denotes a non-entity token. "B" denotes the
            beginning of a multi-token entity, "I" the inside of an entity of
            three or more tokens, and "L" the end of an entity of two or more
            tokens. "U" denotes a single-token entity.

    Example:
        text = 'I like London.'
        entities = [(len('I like '), len('I like London'), 'LOC')]
        doc = nlp.tokenizer(text)
        tags = biluo_tags_from_offsets(doc, entities)
        assert tags == ['O', 'O', 'U-LOC', 'O']
    '''
    starts = {token.idx: token.i for token in doc}
    ends = {token.idx + len(token): token.i for token in doc}
    biluo = ['' for _ in doc]
    # Handle entity cases
    for start_char, end_char, label in entities:
        start_token = starts.get(start_char)
        end_token = ends.get(end_char)
        # Only interested if the tokenization is correct
        if start_token is not None and end_token is not None:
            if start_token == end_token:
                biluo[start_token] = 'U-%s' % label
            else:
                biluo[start_token] = 'B-%s' % label
                for i in range(start_token + 1, end_token):
                    biluo[i] = 'I-%s' % label
                biluo[end_token] = 'L-%s' % label
    # Now distinguish the O cases from ones where we miss the tokenization
    entity_chars = set()
    for start_char, end_char, label in entities:
        for i in range(start_char, end_char):
            entity_chars.add(i)
    for token in doc:
        for i in range(token.idx, token.idx + len(token)):
            if i in entity_chars:
                break
        else:
            biluo[token.i] = 'O'
    return biluo


def is_punct_label(label):
    return label == 'P' or label.lower() == 'punct'
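

# A sketch of the round trip from character offsets to token-level tags and
# back to entity spans (assumes `nlp` is a loaded pipeline, as in the
# docstring example above; the offsets were counted by hand for this string):
#
#     doc = nlp.tokenizer('I like London and Berlin.')
#     biluo_tags_from_offsets(doc, [(7, 13, 'LOC'), (18, 24, 'LOC')])
#     # ['O', 'O', 'U-LOC', 'O', 'U-LOC', 'O']
#     tags_to_entities(['O', 'O', 'U-LOC', 'O', 'U-LOC', 'O'])
#     # [('LOC', 2, 2), ('LOC', 4, 4)]  -> (label, start_token, end_token)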