* Rework interfaces in vocab

Matthew Honnibal 2015-08-26 19:21:46 +02:00
parent 2d521768a3
commit 1302d35dff
1 changed file with 33 additions and 20 deletions

@@ -21,6 +21,7 @@ from .cfile cimport CFile
 from cymem.cymem cimport Address
 from . import util
 from .serialize.packer cimport Packer
+from .attrs cimport PROB

 DEF MAX_VEC_SIZE = 100000
@@ -35,27 +36,37 @@ EMPTY_LEXEME.repvec = EMPTY_VEC
 cdef class Vocab:
     '''A map container for a language's LexemeC structs.
     '''
-    def __init__(self, data_dir=None, get_lex_attr=None, load_vectors=True, pos_tags=None):
+    @classmethod
+    def default_morphology(cls):
+        return Morphology({'VBZ': ['VERB', {}]}, [], None)
+
+    def __init__(self, get_lex_attr=None, morphology=None, vectors=None):
+        self.get_lex_attr = get_lex_attr
+        if morphology is None:
+            morphology = self.default_morphology()
+        self.morphology = morphology
         self.mem = Pool()
         self._by_hash = PreshMap()
         self._by_orth = PreshMap()
         self.strings = StringStore()
-        self.get_lex_attr = get_lex_attr
-        self.repvec_length = 0
         self.length = 1
-        self.pos_tags = pos_tags
-        if data_dir is not None:
-            if not path.exists(data_dir):
-                raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
-            if not path.isdir(data_dir):
-                raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir)
-            self.load_lexemes(path.join(data_dir, 'strings.txt'),
-                              path.join(data_dir, 'lexemes.bin'))
-            if load_vectors and path.exists(path.join(data_dir, 'vec.bin')):
-                self.repvec_length = self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
         self._serializer = None
-        self.data_dir = data_dir
+
+    @classmethod
+    def from_dir(cls, data_dir, get_lex_attr=None, morphology=None, vectors=None):
+        if not path.exists(data_dir):
+            raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
+        if not path.isdir(data_dir):
+            raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir)
+        cdef Vocab self = cls(get_lex_attr=get_lex_attr, vectors=vectors,
+                              morphology=morphology)
+        self.load_lexemes(path.join(data_dir, 'strings.txt'),
+                          path.join(data_dir, 'lexemes.bin'))
+        if vectors is None and path.exists(path.join(data_dir, 'vec.bin')):
+            self.repvec_length = self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
+        return self

     property serializer:
         def __get__(self):
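
Usage note (not part of the diff): the rework splits construction into two entry points. Calling Vocab() now builds an in-memory vocab and no longer takes data_dir, load_vectors or pos_tags; loading from disk moves to the new Vocab.from_dir() classmethod. A minimal sketch of the two call patterns follows; the attribute table and the directory path are illustrative, not from the commit:

    # Hypothetical caller code -- the path and attr table are made up.
    from spacy.vocab import Vocab
    from spacy.attrs import PROB

    # In-memory vocab: morphology falls back to Vocab.default_morphology().
    get_lex_attr = {PROB: lambda string: -10.0}
    vocab = Vocab(get_lex_attr=get_lex_attr)

    # On-disk vocab: from_dir() loads strings.txt and lexemes.bin, and reads
    # vec.bin as well when no vectors are passed in.
    vocab = Vocab.from_dir('data/vocab', get_lex_attr=get_lex_attr)
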
@@ -83,7 +94,6 @@ cdef class Vocab:
         lex = <LexemeC*>self._by_hash.get(key)
         cdef size_t addr
         if lex != NULL:
-            print string, lex.orth, self.strings[string]
             assert lex.orth == self.strings[string]
             return lex
         else:
@@ -106,17 +116,21 @@ cdef class Vocab:
         cdef hash_t key
         cdef bint is_oov = mem is not self.mem
         mem = self.mem
-        #if len(string) < 3:
-        #    mem = self.mem
+        if len(string) < 3:
+            mem = self.mem
         lex = <LexemeC*>mem.alloc(sizeof(LexemeC), 1)
         lex.orth = self.strings[string]
+        lex.length = len(string)
         lex.id = self.length
         if self.get_lex_attr is not None:
             for attr, func in self.get_lex_attr.items():
                 value = func(string)
                 if isinstance(value, unicode):
                     value = self.strings[value]
-                Lexeme.set_struct_attr(lex, attr, value)
+                if attr == PROB:
+                    lex.prob = value
+                else:
+                    Lexeme.set_struct_attr(lex, attr, value)
         if is_oov:
             lex.id = 0
         else:
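
Note (not part of the diff): the attribute loop now routes PROB to lex.prob directly, presumably because prob is a float field on LexemeC while Lexeme.set_struct_attr stores integer-valued attributes. A rough plain-Python analogue of that dispatch, using made-up stand-ins rather than the real C structs:

    # Illustrative stand-ins only; the real code operates on LexemeC in Cython.
    PROB, LOWER = 1, 2                     # stand-ins for spacy.attrs constants

    def set_struct_attr(lex, attr, value):
        # Mimics an integer-only struct setter; a float log-probability
        # would be truncated if it went through here.
        lex[attr] = int(value)

    def fill_attrs(lex, string, get_lex_attr, string_store):
        for attr, func in get_lex_attr.items():
            value = func(string)
            if isinstance(value, str):      # intern unicode values as string IDs
                value = string_store[value]
            if attr == PROB:
                lex['prob'] = value         # float kept as-is
            else:
                set_struct_attr(lex, attr, value)
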
@@ -128,7 +142,6 @@ cdef class Vocab:
     cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex) except -1:
         self._by_hash.set(key, <void*>lex)
         self._by_orth.set(lex.orth, <void*>lex)
-        print "Add lex", key, lex.orth, self.strings[lex.orth]
         self.length += 1

     def __iter__(self):