* Hacks to conll.pyx. Should clean these up.

This commit is contained in:
Matthew Honnibal 2015-03-08 00:14:48 -05:00
parent f321b2b2eb
commit 5278c7504b
1 changed file with 17 additions and 1 deletion

View File

@ -12,12 +12,25 @@ cdef class GoldParse:
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int)) self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int)) self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
@property
def n_non_punct(self):
    """Return how many entries of ``self.labels`` differ from ``'P'``.

    NOTE(review): ``'P'`` is presumably the punctuation dependency label
    (judging by the property name) -- confirm against the label scheme.
    """
    return sum(1 for label in self.labels if label != 'P')
@property
def py_heads(self):
    """Return the first ``self.length`` entries of ``self.c_heads`` as a list.

    Gives Python callers a plain list view of the head values that are
    otherwise held in the ``c_heads`` array.
    """
    heads = []
    for idx in range(self.length):
        heads.append(self.c_heads[idx])
    return heads
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
    # Count the positions i in [0, self.length) where the head derived from
    # `tokens` equals the stored head in `self.c_heads`.
    #
    # `tokens[i].head` is added to `i` before the comparison, i.e. it is
    # treated as an offset from the token's own position, while
    # `self.c_heads[i]` is compared as-is.
    #
    # Unless `score_punct` is true, positions whose label equals 'P' are
    # left out of the count.
    # NOTE(review): 'P' is presumably the punctuation label -- confirm.
    #
    # Returns the number of matching positions. The `except -1` clause
    # reserves -1 as the error-propagation value for this cdef function.
    n_correct = 0
    for i in range(self.length):
        if score_punct or self.labels[i] != 'P':
            n_correct += (i + tokens[i].head) == self.c_heads[i]
    return n_correct
def is_correct(self, i, head):
    """Return whether ``head`` equals the head stored at position ``i``.

    ``head`` is compared directly against ``self.c_heads[i]``; no offset is
    applied to either value.
    """
    gold_head = self.c_heads[i]
    return head == gold_head
@classmethod @classmethod
def from_conll(cls, unicode sent_str): def from_conll(cls, unicode sent_str):
ids = [] ids = []
@ -96,6 +109,10 @@ cdef class GoldParse:
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int)) self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int)) self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
self.ids = [token.idx for token in tokens] self.ids = [token.idx for token in tokens]
self.map_heads(label_ids)
return self.loss
def map_heads(self, label_ids):
mapped_heads = _map_indices_to_tokens(self.ids, self.heads) mapped_heads = _map_indices_to_tokens(self.ids, self.heads)
for i in range(self.length): for i in range(self.length):
if mapped_heads[i] is None: if mapped_heads[i] is None:
@ -121,7 +138,6 @@ def _map_indices_to_tokens(ids, heads):
return mapped return mapped
def _parse_line(line): def _parse_line(line):
pieces = line.split() pieces = line.split()
if len(pieces) == 4: if len(pieces) == 4: