mirror of https://github.com/explosion/spaCy.git
Make vocab update in get_docs deterministic (#7603)
* Make vocab update in get_docs deterministic
The attribute `DocBin.strings` is a set. In `DocBin.get_docs`
a given vocab is updated by iterating over this set.
Iteration over a python set produces an arbitrary ordering,
therefore vocab is updated non-deterministically.
When training (fine-tuning) a spacy model, the base model's
vocabulary will be updated with the new vocabulary in the
training data in exactly the way described above. After
serialization, the file `model/vocab/strings.json` will
be sorted in an arbitrary way. This prevents reproducible
model training.
* Revert "Make vocab update in get_docs deterministic"
This reverts commit d6b87a2f55
.
* Sort strings in StringStore serialization
Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
8008e2f75b
commit
2516896849
|
@ -223,7 +223,7 @@ cdef class StringStore:
|
|||
it doesn't exist. Paths may be either strings or Path-like objects.
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
strings = list(self)
|
||||
strings = sorted(self)
|
||||
srsly.write_json(path, strings)
|
||||
|
||||
def from_disk(self, path):
|
||||
|
@ -247,7 +247,7 @@ cdef class StringStore:
|
|||
|
||||
RETURNS (bytes): The serialized form of the `StringStore` object.
|
||||
"""
|
||||
return srsly.json_dumps(list(self))
|
||||
return srsly.json_dumps(sorted(self))
|
||||
|
||||
def from_bytes(self, bytes_data, **kwargs):
|
||||
"""Load state from a binary string.
|
||||
|
|
|
@ -49,9 +49,9 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
|
|||
vocab1_d = Vocab().from_disk(file_path1)
|
||||
vocab2_d = Vocab().from_disk(file_path2)
|
||||
# check strings rather than lexemes, which are only reloaded on demand
|
||||
assert strings1 == [s for s in vocab1_d.strings]
|
||||
assert strings2 == [s for s in vocab2_d.strings]
|
||||
if strings1 == strings2:
|
||||
assert set(strings1) == set([s for s in vocab1_d.strings])
|
||||
assert set(strings2) == set([s for s in vocab2_d.strings])
|
||||
if set(strings1) == set(strings2):
|
||||
assert [s for s in vocab1_d.strings] == [s for s in vocab2_d.strings]
|
||||
else:
|
||||
assert [s for s in vocab1_d.strings] != [s for s in vocab2_d.strings]
|
||||
|
@ -96,7 +96,7 @@ def test_serialize_stringstore_roundtrip_bytes(strings1, strings2):
|
|||
sstore2 = StringStore(strings=strings2)
|
||||
sstore1_b = sstore1.to_bytes()
|
||||
sstore2_b = sstore2.to_bytes()
|
||||
if strings1 == strings2:
|
||||
if set(strings1) == set(strings2):
|
||||
assert sstore1_b == sstore2_b
|
||||
else:
|
||||
assert sstore1_b != sstore2_b
|
||||
|
@ -104,7 +104,7 @@ def test_serialize_stringstore_roundtrip_bytes(strings1, strings2):
|
|||
assert sstore1.to_bytes() == sstore1_b
|
||||
new_sstore1 = StringStore().from_bytes(sstore1_b)
|
||||
assert new_sstore1.to_bytes() == sstore1_b
|
||||
assert list(new_sstore1) == strings1
|
||||
assert set(new_sstore1) == set(strings1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strings1,strings2", test_strings)
|
||||
|
@ -118,12 +118,12 @@ def test_serialize_stringstore_roundtrip_disk(strings1, strings2):
|
|||
sstore2.to_disk(file_path2)
|
||||
sstore1_d = StringStore().from_disk(file_path1)
|
||||
sstore2_d = StringStore().from_disk(file_path2)
|
||||
assert list(sstore1_d) == list(sstore1)
|
||||
assert list(sstore2_d) == list(sstore2)
|
||||
if strings1 == strings2:
|
||||
assert list(sstore1_d) == list(sstore2_d)
|
||||
assert set(sstore1_d) == set(sstore1)
|
||||
assert set(sstore2_d) == set(sstore2)
|
||||
if set(strings1) == set(strings2):
|
||||
assert set(sstore1_d) == set(sstore2_d)
|
||||
else:
|
||||
assert list(sstore1_d) != list(sstore2_d)
|
||||
assert set(sstore1_d) != set(sstore2_d)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
|
||||
|
|
Loading…
Reference in New Issue