Make nonproj methods top-level functions, instead of class methods

Matthew Honnibal 2017-05-22 04:48:02 -05:00
parent c998776c25
commit 2a5eb9f61e
4 changed files with 126 additions and 133 deletions

spacy/gold.pyx

@@ -173,12 +173,11 @@ class GoldCorpus(object):
         if shuffle:
             random.shuffle(self.train_locs)
         if projectivize:
-            train_tuples = nonproj.PseudoProjectivity.preprocess_training_data(
+            train_tuples = nonproj.preprocess_training_data(
                 self.train_tuples)
-        gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc)
         if shuffle:
-            gold_docs = util.itershuffle(gold_docs, bufsize=shuffle*1000)
-        gold_docs = nlp.preprocess_gold(gold_docs)
+            random.shuffle(train_tuples)
+        gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc)
         yield from gold_docs
 
     def dev_docs(self, nlp):
@@ -236,7 +235,7 @@ class GoldCorpus(object):
         return locs
 
 
-def read_json_file(loc, docs_filter=None, limit=1000):
+def read_json_file(loc, docs_filter=None, limit=None):
     loc = ensure_path(loc)
     if loc.is_dir():
         for filename in loc.iterdir():
@@ -390,7 +389,7 @@ cdef class GoldParse:
             raise Exception("Cycle found: %s" % cycle)
 
         if make_projective:
-            proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads, self.labels)
+            proj_heads,_ = nonproj.projectivize(self.heads, self.labels)
             self.heads = proj_heads
 
     def __len__(self):
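
All of these call sites pass nonproj.preprocess_training_data() the same nested gold-tuples structure that read_json_file() produces. As a point of reference, a minimal sketch of that layout together with the new function-style call; the sentence and its annotations are invented for illustration:

    from spacy.syntax import nonproj

    # Each entry is (raw_text, sentences); each sentence is
    # ((ids, words, tags, heads, labels, iob), ctnts).
    gold_tuples = [
        ('A sentence.', [
            (([0, 1, 2], ['A', 'sentence', '.'],
              ['DT', 'NN', '.'], [1, 1, 1],
              ['det', 'ROOT', 'punct'], ['O', 'O', 'O']), []),
        ]),
    ]
    train_tuples = nonproj.preprocess_training_data(gold_tuples)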

spacy/language.py

@@ -13,7 +13,7 @@ from .vocab import Vocab
 from .tagger import Tagger
 from .lemmatizer import Lemmatizer
 from .syntax.parser import get_templates
-from .syntax.nonproj import PseudoProjectivity
+from .syntax import nonproj
 from .pipeline import NeuralDependencyParser, EntityRecognizer
 from .pipeline import TokenVectorEncoder, NeuralTagger, NeuralEntityRecognizer
 from .pipeline import NeuralLabeller
@@ -97,7 +97,7 @@ class BaseDefaults(object):
         'tags': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)],
         'dependencies': lambda nlp, **cfg: [
             NeuralDependencyParser(nlp.vocab, **cfg),
-            PseudoProjectivity.deprojectivize],
+            nonproj.deprojectivize],
         'entities': lambda nlp, **cfg: [NeuralEntityRecognizer(nlp.vocab, **cfg)],
     }
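
Because deprojectivize is now a module-level function, the 'dependencies' factory can list it directly as a bare pipeline callable after the parser. A rough sketch of how such a pipeline gets applied; the dep_parser instance and the apply_pipeline helper are hypothetical:

    from spacy.syntax import nonproj

    pipeline = [dep_parser, nonproj.deprojectivize]

    def apply_pipeline(doc):
        for proc in pipeline:
            # components may annotate doc in place (the parser) or
            # rewrite decorated deps and return the tokens (deprojectivize)
            result = proc(doc)
            doc = result if result is not None else doc
        return doc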

spacy/syntax/parser.pyx

@@ -47,7 +47,7 @@ from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from .stateclass cimport StateClass
 from ._state cimport StateC
-from .nonproj import PseudoProjectivity
+from . import nonproj
 from .transition_system import OracleError
 from .transition_system cimport TransitionSystem, Transition
 from ..structs cimport TokenC
@@ -435,7 +435,7 @@ cdef class Parser:
     def begin_training(self, gold_tuples, **cfg):
         if 'model' in cfg:
             self.model = cfg['model']
-        gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
+        gold_tuples = nonproj.preprocess_training_data(gold_tuples)
         actions = self.moves.get_actions(gold_parses=gold_tuples)
         for action, labels in actions.items():
             for label in labels:
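
The ordering here matters: projectivization can rewrite a plain label into a decorated one such as 'advmod||dobj', so begin_training has to preprocess the gold tuples before get_actions() collects the label inventory. Condensed, with a hypothetical shape for the resulting action table:

    gold_tuples = nonproj.preprocess_training_data(gold_tuples)
    actions = self.moves.get_actions(gold_parses=gold_tuples)
    # actions maps each transition type to its labels, hypothetically e.g.
    # {LEFT_ARC: ['nsubj', 'advmod||dobj', ...], RIGHT_ARC: [...], ...}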

spacy/syntax/nonproj.pyx

@@ -1,10 +1,17 @@
 # coding: utf-8
+"""
+Implements the projectivize/deprojectivize mechanism in Nivre & Nilsson 2005
+for doing pseudo-projective parsing. The implementation uses the HEAD
+decoration scheme.
+"""
 from __future__ import unicode_literals
 from copy import copy
 
 from ..tokens.doc cimport Doc
 from ..attrs import DEP, HEAD
 
+DELIMITER = '||'
+
 
 def ancestors(tokenid, heads):
     # returns all words going from the word up the path to the root
@@ -60,89 +67,79 @@ def is_nonproj_tree(heads):
     return any( is_nonproj_arc(word,heads) for word in range(len(heads)) )
 
 
-class PseudoProjectivity:
-    # implements the projectivize/deprojectivize mechanism in Nivre & Nilsson 2005
-    # for doing pseudo-projective parsing
-    # implementation uses the HEAD decoration scheme
-
-    delimiter = '||'
-
-    @classmethod
-    def decompose(cls, label):
-        return label.partition(cls.delimiter)[::2]
-
-    @classmethod
-    def is_decorated(cls, label):
-        return label.find(cls.delimiter) != -1
-
-    @classmethod
-    def preprocess_training_data(cls, gold_tuples, label_freq_cutoff=30):
+def decompose(label):
+    return label.partition(DELIMITER)[::2]
+
+
+def is_decorated(label):
+    return label.find(DELIMITER) != -1
+
+
+def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
     preprocessed = []
     freqs = {}
     for raw_text, sents in gold_tuples:
         prepro_sents = []
         for (ids, words, tags, heads, labels, iob), ctnts in sents:
-            proj_heads,deco_labels = cls.projectivize(heads,labels)
+            proj_heads,deco_labels = projectivize(heads,labels)
             # set the label to ROOT for each root dependent
             deco_labels = [ 'ROOT' if head == i else deco_labels[i] for i,head in enumerate(proj_heads) ]
             # count label frequencies
             if label_freq_cutoff > 0:
                 for label in deco_labels:
-                    if cls.is_decorated(label):
+                    if is_decorated(label):
                         freqs[label] = freqs.get(label,0) + 1
             prepro_sents.append(((ids,words,tags,proj_heads,deco_labels,iob), ctnts))
         preprocessed.append((raw_text, prepro_sents))
 
     if label_freq_cutoff > 0:
-        return cls._filter_labels(preprocessed,label_freq_cutoff,freqs)
+        return _filter_labels(preprocessed,label_freq_cutoff,freqs)
     return preprocessed
 
 
-    @classmethod
-    def projectivize(cls, heads, labels):
+def projectivize(heads, labels):
     # use the algorithm by Nivre & Nilsson 2005
     # assumes heads to be a proper tree, i.e. connected and cycle-free
     # returns a new pair (heads,labels) which encode
     # a projective and decorated tree
     proj_heads = copy(heads)
-    smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
+    smallest_np_arc = _get_smallest_nonproj_arc(proj_heads)
     if smallest_np_arc == None: # this sentence is already projective
         return proj_heads, copy(labels)
     while smallest_np_arc != None:
-        cls._lift(smallest_np_arc, proj_heads)
-        smallest_np_arc = cls._get_smallest_nonproj_arc(proj_heads)
-    deco_labels = cls._decorate(heads, proj_heads, labels)
+        _lift(smallest_np_arc, proj_heads)
+        smallest_np_arc = _get_smallest_nonproj_arc(proj_heads)
+    deco_labels = _decorate(heads, proj_heads, labels)
     return proj_heads, deco_labels
 
 
-    @classmethod
-    def deprojectivize(cls, tokens):
+def deprojectivize(tokens):
     # reattach arcs with decorated labels (following HEAD scheme)
     # for each decorated arc X||Y, search top-down, left-to-right,
     # breadth-first until hitting a Y then make this the new head
     for token in tokens:
-        if cls.is_decorated(token.dep_):
-            newlabel,headlabel = cls.decompose(token.dep_)
-            newhead = cls._find_new_head(token,headlabel)
+        if is_decorated(token.dep_):
+            newlabel,headlabel = decompose(token.dep_)
+            newhead = _find_new_head(token,headlabel)
             token.head = newhead
             token.dep_ = newlabel
     return tokens
 
 
-    @classmethod
-    def _decorate(cls, heads, proj_heads, labels):
+def _decorate(heads, proj_heads, labels):
     # uses decoration scheme HEAD from Nivre & Nilsson 2005
     assert(len(heads) == len(proj_heads) == len(labels))
     deco_labels = []
     for tokenid,head in enumerate(heads):
         if head != proj_heads[tokenid]:
-            deco_labels.append('%s%s%s' % (labels[tokenid],cls.delimiter,labels[head]))
+            deco_labels.append('%s%s%s' % (labels[tokenid], DELIMITER, labels[head]))
         else:
             deco_labels.append(labels[tokenid])
     return deco_labels
 
 
-    @classmethod
-    def _get_smallest_nonproj_arc(cls, heads):
+def _get_smallest_nonproj_arc(heads):
     # return the smallest non-proj arc or None
     # where size is defined as the distance between dep and head
     # and ties are broken left to right
@@ -156,8 +153,7 @@ class PseudoProjectivity:
     return smallest_np_arc
 
 
-    @classmethod
-    def _lift(cls, tokenid, heads):
+def _lift(tokenid, heads):
     # reattaches a word to its grandfather
     head = heads[tokenid]
     ghead = heads[head]
@@ -165,8 +161,7 @@ class PseudoProjectivity:
     heads[tokenid] = ghead if head != ghead else tokenid
 
 
-    @classmethod
-    def _find_new_head(cls, token, headlabel):
+def _find_new_head(token, headlabel):
     # search through the tree starting from the head of the given token
     # returns the id of the first descendant with the given label
     # if there is none, return the current head (no change)
@@ -184,15 +179,14 @@ class PseudoProjectivity:
     return token.head
 
 
-    @classmethod
-    def _filter_labels(cls, gold_tuples, cutoff, freqs):
+def _filter_labels(gold_tuples, cutoff, freqs):
     # throw away infrequent decorated labels
     # can't learn them reliably anyway and keeps label set smaller
     filtered = []
     for raw_text, sents in gold_tuples:
         filtered_sents = []
         for (ids, words, tags, heads, labels, iob), ctnts in sents:
-            filtered_labels = [ cls.decompose(label)[0] if freqs.get(label,cutoff) < cutoff else label for label in labels ]
+            filtered_labels = [ decompose(label)[0] if freqs.get(label,cutoff) < cutoff else label for label in labels ]
             filtered_sents.append(((ids,words,tags,heads,filtered_labels,iob), ctnts))
         filtered.append((raw_text, filtered_sents))
     return filtered
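
To see the new top-level API end to end, here is a small worked example; the four-token tree is invented, with the arc from token 1 to token 3 spanning the root (token 2) and therefore non-projective:

    from spacy.syntax import nonproj

    heads  = [2, 0, 2, 1]
    labels = ['nsubj', 'dobj', 'ROOT', 'advmod']

    proj_heads, deco_labels = nonproj.projectivize(heads, labels)
    # The offending arc is lifted until the tree is projective, and the
    # original head's label is recorded behind the delimiter:
    # proj_heads  == [2, 0, 2, 2]
    # deco_labels == ['nsubj', 'dobj', 'ROOT', 'advmod||dobj']

    nonproj.decompose('advmod||dobj')     # ('advmod', 'dobj')
    nonproj.is_decorated('advmod||dobj')  # True

At parse time, deprojectivize(tokens) reverses the encoding: for each token whose dep_ contains the delimiter, it searches breadth-first below the token's head for a descendant carrying the stored label ('dobj' here) and reattaches the token there.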