From 4a3babcb10feb2c7ea06bbe1d98a7c8c16716bbf Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Wed, 28 Oct 2015 17:36:11 +0100
Subject: [PATCH] * Add displacy service

---
 services/displacy.py | 340 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 340 insertions(+)
 create mode 100644 services/displacy.py

diff --git a/services/displacy.py b/services/displacy.py
new file mode 100644
index 000000000..40451daf2
--- /dev/null
+++ b/services/displacy.py
@@ -0,0 +1,340 @@
+#!/usr/bin/env python
+from __future__ import unicode_literals
+from __future__ import print_function
+import sys
+
+import falcon
+import json
+from collections import defaultdict
+import pprint
+import numpy
+
+import spacy.en
+from spacy.attrs import POS, ENT_IOB, ENT_TYPE
+from spacy.parts_of_speech import NAMES as UNIV_POS_NAMES
+
+try:
+    unicode
+except NameError:
+    unicode = str
+
+
+NLU = spacy.en.English()
+
+
+def merge_entities(doc):
+    ents = [(e[0].idx, e[len(e)-1].idx + len(e[len(e)-1]), e.label_, e.text)
+            for e in doc.ents if len(e) >= 2]
+    for start, end, label, text in ents:
+        merged = doc.merge(start, end, label, text, label)
+        assert merged is not None
+
+
+def merge_nps(doc):
+    nps = [(np[0].idx, np[-1].idx + len(np[-1]), np.root.tag_, np.text)
+           for np in doc.noun_chunks if len(np) >= 2]
+
+    for start, end, ent_type, lemma in nps:
+        doc.merge(start, end, u'NP', lemma, ent_type)
+
+
+def merge_punct(tokens):
+    # Merge runs of tokens not separated by whitespace, so punctuation is
+    # displayed together with the word it attaches to.
+    collect = False
+    start = None
+    merges = []
+
+    for word in tokens:
+        if word.whitespace_:
+            if collect:
+                span = tokens[start:word.i+1]
+                if len(span) >= 2:
+                    merges.append((
+                        span[0].idx,
+                        span[-1].idx + len(span[-1]),
+                        span.root.tag_,
+                        span.root.lemma_,
+                        span.root.ent_type_))
+                collect = False
+                start = None
+        elif not collect:
+            collect = True
+            start = word.i
+    if collect:
+        span = tokens[start:len(tokens)]
+        merges.append((span[0].idx, span[-1].idx + len(span[-1]),
+                       span.root.tag_, span.root.lemma_, span.root.ent_type_))
+    for merge in merges:
+        tokens.merge(*merge)
+
+
+def get_actions(parse_state, n_actions):
+    actions = []
+    actions.append({'label': 'shift', 'key': 'S', 'binding': 38,
+                    'is_valid': NLU.parser.moves.is_valid(parse_state, 'S')})
+    actions.append({'label': 'left', 'key': 'L', 'binding': 37,
+                    'is_valid': NLU.parser.moves.is_valid(parse_state, 'L-det')})
+    actions.append({'label': 'predict', 'key': '_', 'binding': 32,
+                    'is_valid': bool(parse_state.queue or parse_state.stack)})
+    actions.append({'label': 'right', 'key': 'R', 'binding': 39,
+                    'is_valid': NLU.parser.moves.is_valid(parse_state, 'R-dobj')})
+    actions.append({'label': 'undo', 'key': '-', 'binding': 8,
+                    'is_valid': n_actions != 0})
+    actions.append({'label': 'reduce', 'key': 'D', 'binding': 40,
+                    'is_valid': NLU.parser.moves.is_valid(parse_state, 'D')})
+    return actions
+
+
+class Model(object):
+    def to_json(self):
+        return {name: _as_json(value) for name, value in self.__dict__.items()
+                if not name.startswith('_')}
+
+
+def _as_json(value):
+    if hasattr(value, 'to_json'):
+        return value.to_json()
+    elif isinstance(value, list):
+        return [_as_json(v) for v in value]
+    elif isinstance(value, set):
+        return {key: True for key in value}
+    else:
+        return value
+
+
+def _parse_history(history):
+    if history and history.endswith(','):
+        history = history[:-1]
+    history = history.strip().split(',') if history else tuple()
+    new_hist = []
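+    # A '-' entry in the raw history means "undo": it removes the most
+    # recent action instead of being appended.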
+    history_length = len(history)
+    for action in history:
+        if action == '-':
+            if new_hist:
+                new_hist.pop()
+        else:
+            new_hist.append(action)
+    return new_hist, history_length
+
+
+def apply_edits(tokens, word_edits, tag_edits):
+    new_words = []
+    attrs = (POS, ENT_TYPE, ENT_IOB)
+    new_analysis = numpy.zeros(shape=(len(tokens), len(attrs)), dtype=numpy.int32)
+    for word in tokens:
+        key = str(word.i)
+        new_words.append(word_edits.get(key, word.orth_))
+        tag = tag_edits.get(key, word.pos_)
+        if tag in UNIV_POS_NAMES:
+            new_analysis[word.i, 0] = UNIV_POS_NAMES[tag]
+            # Set ent_type=0 and IOB="O"
+            new_analysis[word.i, 1] = 0
+            new_analysis[word.i, 2] = 2
+        else:
+            # The edited tag is not a universal POS tag, so treat it as an
+            # entity type instead (IOB=3 means "B").
+            new_analysis[word.i, 0] = word.pos
+            new_analysis[word.i, 1] = NLU.vocab.strings[tag]
+            new_analysis[word.i, 2] = 3
+
+    doc = NLU.tokenizer.tokens_from_list(new_words)
+    doc.from_array(attrs, new_analysis)
+    NLU.parser(doc)
+    return doc
+
+
+class Parse(Model):
+    def __init__(self, doc, states, actions, **kwargs):
+        word_edits = kwargs.get('words', {})
+        tag_edits = kwargs.get('tags', {})
+        if word_edits or tag_edits:
+            doc = apply_edits(doc, word_edits, tag_edits)
+        notes = kwargs.get('notes', {})
+        self.actions = actions
+        self.words = [Word(w, w.i in word_edits, w.i in tag_edits) for w in doc]
+        self.states = states
+        self.notes = notes
+        for word in doc:
+            print(word.orth_, word.head.orth_)
+
+    @classmethod
+    def from_text(cls, text, **kwargs):
+        tokens = NLU(text)
+        #merge_entities(tokens)
+        merge_nps(tokens)
+        #merge_punct(tokens)
+        return cls(tokens, [State.from_doc(tokens)], [], **kwargs)
+
+    @classmethod
+    def from_history(cls, text, history, **kwargs):
+        if not isinstance(text, unicode):
+            text = text.decode('utf8')
+        text = text.replace('-SLASH-', '/')
+        history, history_length = _parse_history(history)
+
+        tokens = NLU.tokenizer(text)
+        NLU.tagger(tokens)
+        NLU.matcher(tokens)
+
+        with NLU.parser.step_through(tokens) as state:
+            for action in history:
+                state.transition(action)
+
+        NLU.entity(tokens)
+        actions = get_actions(state.stcls, len(history))
+        return cls(tokens, [State(state.heads, state.deps, state.stack, state.queue)],
+                   actions, **kwargs)
+
+    @classmethod
+    def with_history(cls, text):
+        tokens = NLU.tokenizer(text)
+        NLU.tagger(tokens)
+        NLU.matcher(tokens)
+
+        with NLU.parser.step_through(tokens) as state:
+            states = []
+            while not state.is_final:
+                action = state.predict()
+                state.transition(action)
+                states.append(State(state.heads, state.deps, state.stack, state.queue))
+        actions = [
+            {'label': 'prev', 'key': 'P', 'binding': 37, 'is_valid': True},
+            {'label': 'next', 'key': 'N', 'binding': 39, 'is_valid': True}
+        ]
+        return cls(state.doc, states, actions)
+
+
+class Word(Model):
+    def __init__(self, token, is_w_edit=False, is_t_edit=False):
+        self.word = token.orth_
+        self.tag = token.pos_ if not token.ent_type_ else token.ent_type_
+        self.is_entity = token.ent_iob in (1, 3)
+        self.is_w_edit = is_w_edit
+        self.is_t_edit = is_t_edit
+        self.prob = token.prob
+
+
+class State(Model):
+    def __init__(self, heads, deps, stack, queue):
+        queue = [w for w in queue if w >= 0]
+        self.focus = min(queue) if queue else -1
+        self.is_final = bool(not stack and not queue)
+        self.stack = set(stack)
+        self.arrows = self._get_arrows(heads, deps)
+
+    @classmethod
+    def from_doc(cls, doc):
+        return cls([w.head.i for w in doc], [w.dep_ for w in doc], [], [])
+
+    def _get_arrows(self, heads, deps):
+        arcs = defaultdict(dict)
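+        # Bucket each arc by its length ("level") and by the index of its
+        # leftmost token, so the client can stack short arcs beneath long ones.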
+        for i, (head, dep) in enumerate(zip(heads, deps)):
+            if i < head:
+                arcs[head - i][i] = Arrow(i, head, dep)
+            elif i > head:
+                arcs[i - head][head] = Arrow(i, head, dep)
+        output = []
+        for level in range(1, len(heads)):
+            level_arcs = []
+            for i in range(len(heads) - level):
+                level_arcs.append(arcs[level].get(i))
+            output.append(level_arcs)
+        # Trim empty levels from the top.
+        while output and all(arc is None for arc in output[-1]):
+            output.pop()
+        return output
+
+
+class Arrow(Model):
+    def __init__(self, word, head, label):
+        self.dir = 'left' if head > word else 'right'
+        self.label = label
+
+
+class Endpoint(object):
+    def set_header(self, resp):
+        resp.content_type = 'application/json'
+        resp.append_header('Access-Control-Allow-Origin', "*")
+        resp.status = falcon.HTTP_200
+
+    def set_body(self, resp, parse):
+        resp.body = json.dumps(parse.to_json(), indent=4)
+
+    def on_get(self, req, resp, text):
+        if not isinstance(text, unicode):
+            text = text.decode('utf8')
+        self.set_body(resp, self.get_parse(text))
+        self.set_header(resp)
+
+    def on_post(self, req, resp):
+        try:
+            body_bytes = req.stream.read()
+            json_data = json.loads(body_bytes.decode('utf8'))
+            text = json_data['text']
+            if not isinstance(text, unicode):
+                text = text.decode('utf8')
+            self.set_body(resp, self.get_parse(text))
+            self.set_header(resp)
+        except (ValueError, KeyError):
+            resp.status = falcon.HTTP_400
+
+
+class ParseEP(Endpoint):
+    def get_parse(self, text, **kwargs):
+        return Parse.from_text(text, **kwargs)
+
+
+class StepsEP(Endpoint):
+    def get_parse(self, text):
+        print('Step=', repr(text))
+        return Parse.with_history(text)
+
+
+class ManualEP(Endpoint):
+    def get_parse(self, text, **kwargs):
+        print('Manual=', repr(text))
+        if '/' in text:
+            text, actions = text.rsplit('/', 1)
+        else:
+            actions = ''
+        return Parse.from_history(text, actions, **kwargs)
+
+    def on_get(self, req, resp, text, actions=''):
+        if not isinstance(text, unicode):
+            text = text.decode('utf8')
+        self.set_body(resp, self.get_parse(text + '/' + actions))
+        self.set_header(resp)
+
+    def on_post(self, req, resp):
+        self.set_header(resp)
+        body_bytes = req.stream.read()
+        json_data = json.loads(body_bytes.decode('utf8'))
+        print(json_data)
+        params = json_data.get('params', {})
+        self.set_body(resp, self.get_parse(json_data['text'], **params))
+
+
+app = falcon.API()
+
+remote_man = ManualEP()
+remote_parse = ParseEP()
+remote_steps = StepsEP()
+
+app.add_route('/api/displacy/parse/', remote_parse)
+app.add_route('/api/displacy/parse/{text}/', remote_parse)
+
+app.add_route('/api/displacy/steps/', remote_steps)
+app.add_route('/api/displacy/steps/{text}/', remote_steps)
+
+app.add_route('/api/displacy/manual/', remote_man)
+app.add_route('/api/displacy/manual/{text}/', remote_man)
+app.add_route('/api/displacy/manual/{text}/{actions}', remote_man)
+
+
+if __name__ == '__main__':
+    text, actions = open(sys.argv[1]).read().strip().split('\n')
+    parse = Parse.from_text(text)
+    pprint.pprint(parse.to_json())
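
A minimal sketch of how the service might be exercised once the patch is applied.
The gunicorn invocation, host, and port below are illustrative assumptions, not
part of this patch, and the requests library is likewise only assumed for the
example:

    # Serve the app with any WSGI server, e.g.:
    #   gunicorn services.displacy:app --bind 127.0.0.1:8000
    import json
    import requests

    # Full parse via the ParseEP POST endpoint; the response is the JSON
    # produced by Parse.to_json().
    resp = requests.post('http://127.0.0.1:8000/api/displacy/parse/',
                         data=json.dumps({'text': 'They ate the pizza with anchovies.'}))
    parse = resp.json()
    # Noun chunks are merged, so e.g. "the pizza" should come back as one word.
    print([w['word'] for w in parse['words']])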