Make vocab update in get_docs deterministic (#7603)

* Make vocab update in get_docs deterministic

The attribute `DocBin.strings` is a set. In `DocBin.get_docs`
a given vocab is updated by iterating over this set.
Iteration over a python set produces an arbitrary ordering,
therefore vocab is updated non-deterministically.

When training (fine-tuning) a spacy model, the base model's
vocabulary will be updated with the new vocabulary in the
training data in exactly the way described above. After
serialization, the file `model/vocab/strings.json` will
be sorted in an arbitrary way. This prevents reproducible
model training.

* Revert "Make vocab update in get_docs deterministic"

This reverts commit d6b87a2f55.

* Sort strings in StringStore serialization

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Stanislav Schmidt 2021-04-09 11:53:13 +02:00 committed by GitHub
parent 8008e2f75b
commit 2516896849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 12 deletions

View File

@ -223,7 +223,7 @@ cdef class StringStore:
it doesn't exist. Paths may be either strings or Path-like objects.
"""
path = util.ensure_path(path)
strings = list(self)
strings = sorted(self)
srsly.write_json(path, strings)
def from_disk(self, path):
@ -247,7 +247,7 @@ cdef class StringStore:
RETURNS (bytes): The serialized form of the `StringStore` object.
"""
return srsly.json_dumps(list(self))
return srsly.json_dumps(sorted(self))
def from_bytes(self, bytes_data, **kwargs):
"""Load state from a binary string.

View File

@ -49,9 +49,9 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
vocab1_d = Vocab().from_disk(file_path1)
vocab2_d = Vocab().from_disk(file_path2)
# check strings rather than lexemes, which are only reloaded on demand
assert strings1 == [s for s in vocab1_d.strings]
assert strings2 == [s for s in vocab2_d.strings]
if strings1 == strings2:
assert set(strings1) == set([s for s in vocab1_d.strings])
assert set(strings2) == set([s for s in vocab2_d.strings])
if set(strings1) == set(strings2):
assert [s for s in vocab1_d.strings] == [s for s in vocab2_d.strings]
else:
assert [s for s in vocab1_d.strings] != [s for s in vocab2_d.strings]
@ -96,7 +96,7 @@ def test_serialize_stringstore_roundtrip_bytes(strings1, strings2):
sstore2 = StringStore(strings=strings2)
sstore1_b = sstore1.to_bytes()
sstore2_b = sstore2.to_bytes()
if strings1 == strings2:
if set(strings1) == set(strings2):
assert sstore1_b == sstore2_b
else:
assert sstore1_b != sstore2_b
@ -104,7 +104,7 @@ def test_serialize_stringstore_roundtrip_bytes(strings1, strings2):
assert sstore1.to_bytes() == sstore1_b
new_sstore1 = StringStore().from_bytes(sstore1_b)
assert new_sstore1.to_bytes() == sstore1_b
assert list(new_sstore1) == strings1
assert set(new_sstore1) == set(strings1)
@pytest.mark.parametrize("strings1,strings2", test_strings)
@ -118,12 +118,12 @@ def test_serialize_stringstore_roundtrip_disk(strings1, strings2):
sstore2.to_disk(file_path2)
sstore1_d = StringStore().from_disk(file_path1)
sstore2_d = StringStore().from_disk(file_path2)
assert list(sstore1_d) == list(sstore1)
assert list(sstore2_d) == list(sstore2)
if strings1 == strings2:
assert list(sstore1_d) == list(sstore2_d)
assert set(sstore1_d) == set(sstore1)
assert set(sstore2_d) == set(sstore2)
if set(strings1) == set(strings2):
assert set(sstore1_d) == set(sstore2_d)
else:
assert list(sstore1_d) != list(sstore2_d)
assert set(sstore1_d) != set(sstore2_d)
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)