Revert "Revert "Merge pull request #836 from raphael0202/load_vectors (closes #834)""

This reverts commit ea05f78660.
This commit is contained in:
ines 2017-02-16 23:26:21 +01:00
parent 2f82d68430
commit 85d249d451
2 changed files with 20 additions and 8 deletions

View File

@ -0,0 +1,15 @@
# coding: utf-8
from __future__ import unicode_literals
from io import StringIO
word2vec_str = """, -0.046107 -0.035951 -0.560418
de -0.648927 -0.400976 -0.527124
. 0.113685 0.439990 -0.634510
  -1.499184 -0.184280 -0.598371"""
def test_issue834(en_vocab):
f = StringIO(word2vec_str)
vector_length = en_vocab.load_vectors(f)
assert vector_length == 3

View File

@ -1,22 +1,17 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
from libc.string cimport memset from libc.string cimport memset
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
from libc.stdint cimport uint64_t
from libc.math cimport sqrt from libc.math cimport sqrt
from pathlib import Path from pathlib import Path
import bz2 import bz2
import io
import math
import ujson as json import ujson as json
import tempfile import re
from .lexeme cimport EMPTY_LEXEME from .lexeme cimport EMPTY_LEXEME
from .lexeme cimport Lexeme from .lexeme cimport Lexeme
from .strings cimport hash_string from .strings cimport hash_string
from .orth cimport word_shape
from .typedefs cimport attr_t from .typedefs cimport attr_t
from .cfile cimport CFile from .cfile cimport CFile
from .lemmatizer import Lemmatizer from .lemmatizer import Lemmatizer
@ -29,7 +24,6 @@ from . import symbols
from cymem.cymem cimport Address from cymem.cymem cimport Address
from .serialize.packer cimport Packer from .serialize.packer cimport Packer
from .attrs cimport PROB, LANG from .attrs cimport PROB, LANG
from . import deprecated
from . import util from . import util
@ -477,9 +471,12 @@ cdef class Vocab:
cdef attr_t orth cdef attr_t orth
cdef int32_t vec_len = -1 cdef int32_t vec_len = -1
cdef double norm = 0.0 cdef double norm = 0.0
whitespace_pattern = re.compile(r'\s')
for line_num, line in enumerate(file_): for line_num, line in enumerate(file_):
pieces = line.split() pieces = line.split()
word_str = " " if line.startswith(" ") else pieces.pop(0) word_str = " " if whitespace_pattern.match(line) else pieces.pop(0)
if vec_len == -1: if vec_len == -1:
vec_len = len(pieces) vec_len = len(pieces)
elif vec_len != len(pieces): elif vec_len != len(pieces):