From 25168968495fd8ba7485e50e786434fa17fc97fa Mon Sep 17 00:00:00 2001 From: Stanislav Schmidt Date: Fri, 9 Apr 2021 11:53:13 +0200 Subject: [PATCH] Make vocab update in get_docs deterministic (#7603) * Make vocab update in get_docs deterministic The attribute `DocBin.strings` is a set. In `DocBin.get_docs` a given vocab is updated by iterating over this set. Iteration over a python set produces an arbitrary ordering, therefore vocab is updated non-deterministically. When training (fine-tuning) a spacy model, the base model's vocabulary will be updated with the new vocabulary in the training data in exactly the way described above. After serialization, the file `model/vocab/strings.json` will be sorted in an arbitrary way. This prevents reproducible model training. * Revert "Make vocab update in get_docs deterministic" This reverts commit d6b87a2f558b52d66549b6a66c0af00e283ad628. * Sort strings in StringStore serialization Co-authored-by: Adriane Boyd --- spacy/strings.pyx | 4 ++-- .../serialize/test_serialize_vocab_strings.py | 20 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/spacy/strings.pyx b/spacy/strings.pyx index 6a1d68221..4a20cb8af 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -223,7 +223,7 @@ cdef class StringStore: it doesn't exist. Paths may be either strings or Path-like objects. """ path = util.ensure_path(path) - strings = list(self) + strings = sorted(self) srsly.write_json(path, strings) def from_disk(self, path): @@ -247,7 +247,7 @@ cdef class StringStore: RETURNS (bytes): The serialized form of the `StringStore` object. """ - return srsly.json_dumps(list(self)) + return srsly.json_dumps(sorted(self)) def from_bytes(self, bytes_data, **kwargs): """Load state from a binary string. diff --git a/spacy/tests/serialize/test_serialize_vocab_strings.py b/spacy/tests/serialize/test_serialize_vocab_strings.py index 45a546203..3fe9363bf 100644 --- a/spacy/tests/serialize/test_serialize_vocab_strings.py +++ b/spacy/tests/serialize/test_serialize_vocab_strings.py @@ -49,9 +49,9 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2): vocab1_d = Vocab().from_disk(file_path1) vocab2_d = Vocab().from_disk(file_path2) # check strings rather than lexemes, which are only reloaded on demand - assert strings1 == [s for s in vocab1_d.strings] - assert strings2 == [s for s in vocab2_d.strings] - if strings1 == strings2: + assert set(strings1) == set([s for s in vocab1_d.strings]) + assert set(strings2) == set([s for s in vocab2_d.strings]) + if set(strings1) == set(strings2): assert [s for s in vocab1_d.strings] == [s for s in vocab2_d.strings] else: assert [s for s in vocab1_d.strings] != [s for s in vocab2_d.strings] @@ -96,7 +96,7 @@ def test_serialize_stringstore_roundtrip_bytes(strings1, strings2): sstore2 = StringStore(strings=strings2) sstore1_b = sstore1.to_bytes() sstore2_b = sstore2.to_bytes() - if strings1 == strings2: + if set(strings1) == set(strings2): assert sstore1_b == sstore2_b else: assert sstore1_b != sstore2_b @@ -104,7 +104,7 @@ def test_serialize_stringstore_roundtrip_bytes(strings1, strings2): assert sstore1.to_bytes() == sstore1_b new_sstore1 = StringStore().from_bytes(sstore1_b) assert new_sstore1.to_bytes() == sstore1_b - assert list(new_sstore1) == strings1 + assert set(new_sstore1) == set(strings1) @pytest.mark.parametrize("strings1,strings2", test_strings) @@ -118,12 +118,12 @@ def test_serialize_stringstore_roundtrip_disk(strings1, strings2): sstore2.to_disk(file_path2) sstore1_d = StringStore().from_disk(file_path1) sstore2_d = StringStore().from_disk(file_path2) - assert list(sstore1_d) == list(sstore1) - assert list(sstore2_d) == list(sstore2) - if strings1 == strings2: - assert list(sstore1_d) == list(sstore2_d) + assert set(sstore1_d) == set(sstore1) + assert set(sstore2_d) == set(sstore2) + if set(strings1) == set(strings2): + assert set(sstore1_d) == set(sstore2_d) else: - assert list(sstore1_d) != list(sstore2_d) + assert set(sstore1_d) != set(sstore2_d) @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)