* Rework interfaces in vocab

This commit is contained in:
Matthew Honnibal 2015-08-26 19:21:46 +02:00
parent 2d521768a3
commit 1302d35dff
1 changed files with 33 additions and 20 deletions

View File

@ -21,6 +21,7 @@ from .cfile cimport CFile
from cymem.cymem cimport Address
from . import util
from .serialize.packer cimport Packer
from .attrs cimport PROB
DEF MAX_VEC_SIZE = 100000
@ -35,27 +36,37 @@ EMPTY_LEXEME.repvec = EMPTY_VEC
cdef class Vocab:
'''A map container for a language's LexemeC structs.
'''
def __init__(self, data_dir=None, get_lex_attr=None, load_vectors=True, pos_tags=None):
@classmethod
def default_morphology(cls):
return Morphology({'VBZ': ['VERB', {}]}, [], None)
def __init__(self, get_lex_attr=None, morphology=None, vectors=None):
self.get_lex_attr = get_lex_attr
if morphology is None:
morphology = self.default_morphology()
self.morphology = morphology
self.mem = Pool()
self._by_hash = PreshMap()
self._by_orth = PreshMap()
self.strings = StringStore()
self.get_lex_attr = get_lex_attr
self.repvec_length = 0
self.length = 1
self.pos_tags = pos_tags
if data_dir is not None:
self._serializer = None
@classmethod
def from_dir(cls, data_dir, get_lex_attr=None, morphology=None, vectors=None):
if not path.exists(data_dir):
raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
if not path.isdir(data_dir):
raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir)
cdef Vocab self = cls(get_lex_attr=get_lex_attr, vectors=vectors,
morphology=morphology)
self.load_lexemes(path.join(data_dir, 'strings.txt'),
path.join(data_dir, 'lexemes.bin'))
if load_vectors and path.exists(path.join(data_dir, 'vec.bin')):
if vectors is None and path.exists(path.join(data_dir, 'vec.bin')):
self.repvec_length = self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
self._serializer = None
self.data_dir = data_dir
return self
property serializer:
def __get__(self):
@ -83,7 +94,6 @@ cdef class Vocab:
lex = <LexemeC*>self._by_hash.get(key)
cdef size_t addr
if lex != NULL:
print string, lex.orth, self.strings[string]
assert lex.orth == self.strings[string]
return lex
else:
@ -106,16 +116,20 @@ cdef class Vocab:
cdef hash_t key
cdef bint is_oov = mem is not self.mem
mem = self.mem
#if len(string) < 3:
# mem = self.mem
if len(string) < 3:
mem = self.mem
lex = <LexemeC*>mem.alloc(sizeof(LexemeC), 1)
lex.orth = self.strings[string]
lex.length = len(string)
lex.id = self.length
if self.get_lex_attr is not None:
for attr, func in self.get_lex_attr.items():
value = func(string)
if isinstance(value, unicode):
value = self.strings[value]
if attr == PROB:
lex.prob = value
else:
Lexeme.set_struct_attr(lex, attr, value)
if is_oov:
lex.id = 0
@ -128,7 +142,6 @@ cdef class Vocab:
cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex) except -1:
self._by_hash.set(key, <void*>lex)
self._by_orth.set(lex.orth, <void*>lex)
print "Add lex", key, lex.orth, self.strings[lex.orth]
self.length += 1
def __iter__(self):