* Add flag to disable loading of word vectors

This commit is contained in:
Matthew Honnibal 2015-05-25 01:02:42 +02:00
parent 89c3364041
commit eba7b34f66
2 changed files with 4 additions and 4 deletions

View File

@@ -64,12 +64,12 @@ class English(object):
ParserTransitionSystem = ArcEager ParserTransitionSystem = ArcEager
EntityTransitionSystem = BiluoPushDown EntityTransitionSystem = BiluoPushDown
def __init__(self, data_dir=''): def __init__(self, data_dir='', load_vectors=True):
if data_dir == '': if data_dir == '':
data_dir = LOCAL_DATA_DIR data_dir = LOCAL_DATA_DIR
self._data_dir = data_dir self._data_dir = data_dir
self.vocab = Vocab(data_dir=path.join(data_dir, 'vocab') if data_dir else None, self.vocab = Vocab(data_dir=path.join(data_dir, 'vocab') if data_dir else None,
get_lex_props=get_lex_props) get_lex_props=get_lex_props, load_vectors=load_vectors)
tag_names = list(POS_TAGS.keys()) tag_names = list(POS_TAGS.keys())
tag_names.sort() tag_names.sort()
if data_dir is None: if data_dir is None:

View File

@@ -30,7 +30,7 @@ EMPTY_LEXEME.repvec = EMPTY_VEC
cdef class Vocab: cdef class Vocab:
'''A map container for a language's LexemeC structs. '''A map container for a language's LexemeC structs.
''' '''
def __init__(self, data_dir=None, get_lex_props=None): def __init__(self, data_dir=None, get_lex_props=None, load_vectors=True):
self.mem = Pool() self.mem = Pool()
self._map = PreshMap(2 ** 20) self._map = PreshMap(2 ** 20)
self.strings = StringStore() self.strings = StringStore()
@@ -45,7 +45,7 @@ cdef class Vocab:
raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir) raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir)
self.load_lexemes(path.join(data_dir, 'strings.txt'), self.load_lexemes(path.join(data_dir, 'strings.txt'),
path.join(data_dir, 'lexemes.bin')) path.join(data_dir, 'lexemes.bin'))
if path.exists(path.join(data_dir, 'vec.bin')): if load_vectors and path.exists(path.join(data_dir, 'vec.bin')):
self.load_rep_vectors(path.join(data_dir, 'vec.bin')) self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
def __len__(self): def __len__(self):