diff --git a/spacy/errors.py b/spacy/errors.py index 75e24789f..a58a593dc 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -204,6 +204,11 @@ class Warnings(metaclass=ErrorsWithCodes): "for the corpora used to train the language. Please check " "`nlp.meta[\"sources\"]` for any relevant links.") W119 = ("Overriding pipe name in `config` is not supported. Ignoring override '{name_in_config}'.") + W120 = ("Unable to load all spans in Doc.spans: more than one span group " + "with the name '{group_name}' was found in the saved spans data. " + "Only the last span group will be loaded under " + "Doc.spans['{group_name}']. Skipping span group with values: " + "{group_values}") class Errors(metaclass=ErrorsWithCodes): diff --git a/spacy/tests/serialize/test_serialize_span_groups.py b/spacy/tests/serialize/test_serialize_span_groups.py new file mode 100644 index 000000000..85313fcdc --- /dev/null +++ b/spacy/tests/serialize/test_serialize_span_groups.py @@ -0,0 +1,161 @@ +import pytest + +from spacy.tokens import Span, SpanGroup +from spacy.tokens._dict_proxies import SpanGroups + + +@pytest.mark.issue(10685) +def test_issue10685(en_tokenizer): + """Test `SpanGroups` de/serialization""" + # Start with a Doc with no SpanGroups + doc = en_tokenizer("Will it blend?") + + # Test empty `SpanGroups` de/serialization: + assert len(doc.spans) == 0 + doc.spans.from_bytes(doc.spans.to_bytes()) + assert len(doc.spans) == 0 + + # Test non-empty `SpanGroups` de/serialization: + doc.spans["test"] = SpanGroup(doc, name="test", spans=[doc[0:1]]) + doc.spans["test2"] = SpanGroup(doc, name="test", spans=[doc[1:2]]) + + def assert_spangroups(): + assert len(doc.spans) == 2 + assert doc.spans["test"].name == "test" + assert doc.spans["test2"].name == "test" + assert list(doc.spans["test"]) == [doc[0:1]] + assert list(doc.spans["test2"]) == [doc[1:2]] + + # Sanity check the currently-expected behavior + assert_spangroups() + + # Now test serialization/deserialization: + doc.spans.from_bytes(doc.spans.to_bytes()) + + assert_spangroups() + + +def test_span_groups_serialization_mismatches(en_tokenizer): + """Test the serialization of multiple mismatching `SpanGroups` keys and `SpanGroup.name`s""" + doc = en_tokenizer("How now, brown cow?") + # Some variety: + # 1 SpanGroup where its name matches its key + # 2 SpanGroups that have the same name--which is not a key + # 2 SpanGroups that have the same name--which is a key + # 1 SpanGroup that is a value for 2 different keys (where its name is a key) + # 1 SpanGroup that is a value for 2 different keys (where its name is not a key) + groups = doc.spans + groups["key1"] = SpanGroup(doc, name="key1", spans=[doc[0:1], doc[1:2]]) + groups["key2"] = SpanGroup(doc, name="too", spans=[doc[3:4], doc[4:5]]) + groups["key3"] = SpanGroup(doc, name="too", spans=[doc[1:2], doc[0:1]]) + groups["key4"] = SpanGroup(doc, name="key4", spans=[doc[0:1]]) + groups["key5"] = SpanGroup(doc, name="key4", spans=[doc[0:1]]) + sg6 = SpanGroup(doc, name="key6", spans=[doc[0:1]]) + groups["key6"] = sg6 + groups["key7"] = sg6 + sg8 = SpanGroup(doc, name="also", spans=[doc[1:2]]) + groups["key8"] = sg8 + groups["key9"] = sg8 + + regroups = SpanGroups(doc).from_bytes(groups.to_bytes()) + + # Assert regroups == groups + assert regroups.keys() == groups.keys() + for key, regroup in regroups.items(): + # Assert regroup == groups[key] + assert regroup.name == groups[key].name + assert list(regroup) == list(groups[key]) + + +@pytest.mark.parametrize( + "spans_bytes,doc_text,expected_spangroups,expected_warning", + # The bytestrings below were generated from an earlier version of spaCy + # that serialized `SpanGroups` as a list of SpanGroup bytes (via SpanGroups.to_bytes). + # Comments preceding the bytestrings indicate from what Doc they were created. + [ + # Empty SpanGroups: + (b"\x90", "", {}, False), + # doc = nlp("Will it blend?") + # doc.spans['test'] = SpanGroup(doc, name='test', spans=[doc[0:1]]) + ( + b"\x91\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04", + "Will it blend?", + {"test": {"name": "test", "spans": [(0, 1)]}}, + False, + ), + # doc = nlp("Will it blend?") + # doc.spans['test'] = SpanGroup(doc, name='test', spans=[doc[0:1]]) + # doc.spans['test2'] = SpanGroup(doc, name='test', spans=[doc[1:2]]) + ( + b"\x92\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x07", + "Will it blend?", + # We expect only 1 SpanGroup to be in doc.spans in this example + # because there are 2 `SpanGroup`s that have the same .name. See #10685. + {"test": {"name": "test", "spans": [(1, 2)]}}, + True, + ), + # doc = nlp('How now, brown cow?') + # doc.spans['key1'] = SpanGroup(doc, name='key1', spans=[doc[0:1], doc[1:2]]) + # doc.spans['key2'] = SpanGroup(doc, name='too', spans=[doc[3:4], doc[4:5]]) + # doc.spans['key3'] = SpanGroup(doc, name='too', spans=[doc[1:2], doc[0:1]]) + # doc.spans['key4'] = SpanGroup(doc, name='key4', spans=[doc[0:1]]) + # doc.spans['key5'] = SpanGroup(doc, name='key4', spans=[doc[0:1]]) + ( + b"\x95\xc4m\x83\xa4name\xa4key1\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\t\x00\x00\x00\x0e\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x05\x00\x00\x00\x0f\x00\x00\x00\x12\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03", + "How now, brown cow?", + { + "key1": {"name": "key1", "spans": [(0, 1), (1, 2)]}, + "too": {"name": "too", "spans": [(1, 2), (0, 1)]}, + "key4": {"name": "key4", "spans": [(0, 1)]}, + }, + True, + ), + ], +) +def test_deserialize_span_groups_compat( + en_tokenizer, spans_bytes, doc_text, expected_spangroups, expected_warning +): + """Test backwards-compatibility of `SpanGroups` deserialization. + This uses serializations (bytes) from a prior version of spaCy (before 3.3.1). + + spans_bytes (bytes): Serialized `SpanGroups` object. + doc_text (str): Doc text. + expected_spangroups (dict): + Dict mapping every expected (after deserialization) `SpanGroups` key + to a SpanGroup's "args", where a SpanGroup's args are given as a dict: + {"name": span_group.name, + "spans": [(span0.start, span0.end), ...]} + expected_warning (bool): Whether a warning is to be expected from .from_bytes() + --i.e. if more than 1 SpanGroup has the same .name within the `SpanGroups`. + """ + doc = en_tokenizer(doc_text) + + if expected_warning: + with pytest.warns(UserWarning): + doc.spans.from_bytes(spans_bytes) + else: + # TODO: explicitly check for lack of a warning + doc.spans.from_bytes(spans_bytes) + + assert doc.spans.keys() == expected_spangroups.keys() + for name, spangroup_args in expected_spangroups.items(): + assert doc.spans[name].name == spangroup_args["name"] + spans = [Span(doc, start, end) for start, end in spangroup_args["spans"]] + assert list(doc.spans[name]) == spans + + +def test_span_groups_serialization(en_tokenizer): + doc = en_tokenizer("0 1 2 3 4 5 6") + span_groups = SpanGroups(doc) + spans = [doc[0:2], doc[1:3]] + sg1 = SpanGroup(doc, spans=spans) + span_groups["key1"] = sg1 + span_groups["key2"] = sg1 + span_groups["key3"] = [] + reloaded_span_groups = SpanGroups(doc).from_bytes(span_groups.to_bytes()) + assert span_groups.keys() == reloaded_span_groups.keys() + for key, value in span_groups.items(): + assert all( + span == reloaded_span + for span, reloaded_span in zip(span_groups[key], reloaded_span_groups[key]) + ) diff --git a/spacy/tokens/_dict_proxies.py b/spacy/tokens/_dict_proxies.py index d9506769b..9630da261 100644 --- a/spacy/tokens/_dict_proxies.py +++ b/spacy/tokens/_dict_proxies.py @@ -1,10 +1,11 @@ -from typing import Iterable, Tuple, Union, Optional, TYPE_CHECKING +from typing import Dict, Iterable, List, Tuple, Union, Optional, TYPE_CHECKING +import warnings import weakref from collections import UserDict import srsly from .span_group import SpanGroup -from ..errors import Errors +from ..errors import Errors, Warnings if TYPE_CHECKING: @@ -16,7 +17,7 @@ if TYPE_CHECKING: # Why inherit from UserDict instead of dict here? # Well, the 'dict' class doesn't necessarily delegate everything nicely, # for performance reasons. The UserDict is slower but better behaved. -# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww +# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/ class SpanGroups(UserDict): """A dict-like proxy held by the Doc, to control access to span groups.""" @@ -53,20 +54,52 @@ class SpanGroups(UserDict): return super().setdefault(key, default=default) def to_bytes(self) -> bytes: - # We don't need to serialize this as a dict, because the groups - # know their names. + # We serialize this as a dict in order to track the key(s) a SpanGroup + # is a value of (in a backward- and forward-compatible way), since + # a SpanGroup can have a key that doesn't match its `.name` (See #10685) if len(self) == 0: return self._EMPTY_BYTES - msg = [value.to_bytes() for value in self.values()] + msg: Dict[bytes, List[str]] = {} + for key, value in self.items(): + msg.setdefault(value.to_bytes(), []).append(key) return srsly.msgpack_dumps(msg) def from_bytes(self, bytes_data: bytes) -> "SpanGroups": - msg = [] if bytes_data == self._EMPTY_BYTES else srsly.msgpack_loads(bytes_data) + # backwards-compatibility: bytes_data may be one of: + # b'', a serialized empty list, a serialized list of SpanGroup bytes + # or a serialized dict of SpanGroup bytes -> keys + msg = ( + [] + if not bytes_data or bytes_data == self._EMPTY_BYTES + else srsly.msgpack_loads(bytes_data) + ) self.clear() doc = self._ensure_doc() - for value_bytes in msg: - group = SpanGroup(doc).from_bytes(value_bytes) - self[group.name] = group + if isinstance(msg, list): + # This is either the 1st version of `SpanGroups` serialization + # or there were no SpanGroups serialized + for value_bytes in msg: + group = SpanGroup(doc).from_bytes(value_bytes) + if group.name in self: + # Display a warning if `msg` contains `SpanGroup`s + # that have the same .name (attribute). + # Because, for `SpanGroups` serialized as lists, + # only 1 SpanGroup per .name is loaded. (See #10685) + warnings.warn( + Warnings.W120.format( + group_name=group.name, group_values=self[group.name] + ) + ) + self[group.name] = group + else: + for value_bytes, keys in msg.items(): + group = SpanGroup(doc).from_bytes(value_bytes) + # Deserialize `SpanGroup`s as copies because it's possible for two + # different `SpanGroup`s (pre-serialization) to have the same bytes + # (since they can have the same `.name`). + self[keys[0]] = group + for key in keys[1:]: + self[key] = group.copy() return self def _ensure_doc(self) -> "Doc": diff --git a/spacy/tokens/span_group.pyi b/spacy/tokens/span_group.pyi index 26efc3ba0..245eb4dbe 100644 --- a/spacy/tokens/span_group.pyi +++ b/spacy/tokens/span_group.pyi @@ -24,3 +24,4 @@ class SpanGroup: def __getitem__(self, i: int) -> Span: ... def to_bytes(self) -> bytes: ... def from_bytes(self, bytes_data: bytes) -> SpanGroup: ... + def copy(self) -> SpanGroup: ...