Fix calculation of L2-norm for Lexeme

This commit is contained in:
Matthew Honnibal 2016-10-23 14:44:45 +02:00
parent 7638f439e5
commit a0a4ada42a
1 changed files with 9 additions and 3 deletions

View File

@ -4,6 +4,7 @@ from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
from libc.string cimport memset
from libc.stdint cimport int32_t
from libc.stdint cimport uint64_t
from libc.math cimport sqrt
import bz2
from os import path
@ -386,6 +387,7 @@ cdef class Vocab:
cdef LexemeC* lexeme
cdef attr_t orth
cdef int32_t vec_len = -1
cdef double norm = 0.0
for line_num, line in enumerate(file_):
pieces = line.split()
word_str = pieces.pop(0)
@ -397,9 +399,12 @@ cdef class Vocab:
orth = self.strings[word_str]
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
for i, val_str in enumerate(pieces):
lexeme.vector[i] = float(val_str)
norm = 0.0
for i in range(vec_len):
norm += lexeme.vector[i] * lexeme.vector[i]
lex.l2_norm = sqrt(norm)
self.vectors_length = vec_len
return vec_len
@ -438,14 +443,15 @@ cdef class Vocab:
line_num += 1
cdef LexemeC* lex
cdef size_t lex_addr
cdef double norm = 0.0
cdef int i
for orth, lex_addr in self._by_orth.items():
lex = <LexemeC*>lex_addr
if lex.lower < vectors.size():
lex.vector = vectors[lex.lower]
for i in range(vec_len):
lex.l2_norm += (lex.vector[i] * lex.vector[i])
lex.l2_norm = math.sqrt(lex.l2_norm)
norm += lex.vector[i] * lex.vector[i]
lex.l2_norm = sqrt(norm)
else:
lex.vector = EMPTY_VEC
self.vectors_length = vec_len