From 6ed423c16c99206ff2b81176d9565d0e1c1b7071 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Sun, 7 Feb 2021 01:05:43 +0100 Subject: [PATCH] reduce memory load when reading all vectors from file (#6945) * reduce memory load when reading all vectors from file * one more small typo fix --- spacy/lexeme.pyx | 2 +- spacy/training/initialize.py | 16 ++++++++++------ website/docs/api/top-level.md | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/spacy/lexeme.pyx b/spacy/lexeme.pyx index 25461b4b7..c8e0f2965 100644 --- a/spacy/lexeme.pyx +++ b/spacy/lexeme.pyx @@ -451,7 +451,7 @@ cdef class Lexeme: Lexeme.c_set_flag(self.c, IS_QUOTE, x) property is_left_punct: - """RETURNS (bool): Whether the lexeme is left punctuation, e.g. ).""" + """RETURNS (bool): Whether the lexeme is left punctuation, e.g. (.""" def __get__(self): return Lexeme.c_check_flag(self.c, IS_LEFT_PUNCT) diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 7457dc359..25bb73c78 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -215,8 +215,7 @@ def convert_vectors( def read_vectors(vectors_loc: Path, truncate_vectors: int): - f = open_file(vectors_loc) - f = ensure_shape(f) + f = ensure_shape(vectors_loc) shape = tuple(int(size) for size in next(f).split()) if truncate_vectors >= 1: shape = (truncate_vectors, shape[1]) @@ -251,11 +250,12 @@ def open_file(loc: Union[str, Path]) -> IO: return loc.open("r", encoding="utf8") -def ensure_shape(lines): +def ensure_shape(vectors_loc): """Ensure that the first line of the data is the vectors shape. If it's not, we read in the data and output the shape as the first result, so that the reader doesn't have to deal with the problem. """ + lines = open_file(vectors_loc) first_line = next(lines) try: shape = tuple(int(size) for size in first_line.split()) @@ -269,7 +269,11 @@ def ensure_shape(lines): # Figure out the shape, make it the first value, and then give the # rest of the data. width = len(first_line.split()) - 1 - captured = [first_line] + list(lines) - length = len(captured) + length = 1 + for _ in lines: + length += 1 yield f"{length} {width}" - yield from captured + # Reading the lines in again from file. This to avoid having to + # store all the results in a list in memory + lines2 = open_file(vectors_loc) + yield from lines2 diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md index 3a2c65553..37f619f3e 100644 --- a/website/docs/api/top-level.md +++ b/website/docs/api/top-level.md @@ -727,7 +727,7 @@ capitalization by including a mix of capitalized and lowercase examples. See the Create a data augmentation callback that uses orth-variant replacement. The callback can be added to a corpus or other data iterator during training. It's -is especially useful for punctuation and case replacement, to help generalize +especially useful for punctuation and case replacement, to help generalize beyond corpora that don't have smart quotes, or only have smart quotes etc. | Name | Description |