2020-10-02 13:43:32 +00:00
|
|
|
import pytest
|
|
|
|
from spacy.tokens.doc import Underscore
|
|
|
|
|
2019-10-03 12:48:45 +00:00
|
|
|
import spacy
|
|
|
|
from spacy.lang.en import English
|
|
|
|
from spacy.tokens import Doc, DocBin
|
2017-11-09 01:29:03 +00:00
|
|
|
|
2018-07-24 21:38:44 +00:00
|
|
|
from ..util import make_tempdir
|
|
|
|
|
|
|
|
|
|
|
|
def test_serialize_empty_doc(en_vocab):
    # An empty Doc serialized with to_bytes() and loaded into a fresh Doc
    # must come back with the same length and token texts.
    empty = Doc(en_vocab)
    payload = empty.to_bytes()
    restored = Doc(en_vocab)
    restored.from_bytes(payload)
    assert len(empty) == len(restored)
    # Pairwise text comparison (a no-op when both docs are empty).
    for original_token, restored_token in zip(empty, restored):
        assert original_token.text == restored_token.text
|
2017-11-09 01:29:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_serialize_doc_roundtrip_bytes(en_vocab):
    """Check that a Doc, including its ``cats``, survives a bytes roundtrip.

    Comparing ``to_bytes()`` output alone is not sufficient: if a field such
    as ``cats`` were dropped by serialization, neither byte string would
    contain it and they would still compare equal. The explicit field
    assertions below catch that case.
    """
    doc = Doc(en_vocab, words=["hello", "world"])
    doc.cats = {"A": 0.5}
    doc_b = doc.to_bytes()
    new_doc = Doc(en_vocab).from_bytes(doc_b)
    assert new_doc.text == doc.text
    assert new_doc.cats == doc.cats
    assert new_doc.to_bytes() == doc_b
|
|
|
|
|
|
|
|
|
|
|
|
def test_serialize_doc_roundtrip_disk(en_vocab):
    # Save a Doc into a temporary directory (path built with the `/`
    # operator) and reload it; the reloaded Doc must serialize identically.
    original = Doc(en_vocab, words=["hello", "world"])
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "doc"
        original.to_disk(target)
        reloaded = Doc(en_vocab).from_disk(target)
        assert original.to_bytes() == reloaded.to_bytes()
|
|
|
|
|
|
|
|
|
|
|
|
def test_serialize_doc_roundtrip_disk_str_path(en_vocab):
    # Same disk roundtrip as above, but the location is passed as a plain
    # string instead of a path object, to exercise str-path handling.
    original = Doc(en_vocab, words=["hello", "world"])
    with make_tempdir() as tmp_dir:
        target = str(tmp_dir / "doc")
        original.to_disk(target)
        reloaded = Doc(en_vocab).from_disk(target)
        assert original.to_bytes() == reloaded.to_bytes()
|
2019-03-10 18:16:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_serialize_doc_exclude(en_vocab):
    # user_data roundtrips by default, and is dropped when "user_data" is
    # passed in `exclude` either when loading or when saving.
    source = Doc(en_vocab, words=["hello", "world"])
    source.user_data["foo"] = "bar"
    payload = source.to_bytes()

    # Default: nothing excluded.
    assert Doc(en_vocab).from_bytes(payload).user_data["foo"] == "bar"

    # Excluded on the loading side.
    loaded_without = Doc(en_vocab).from_bytes(payload, exclude=["user_data"])
    assert not loaded_without.user_data

    # Excluded on the saving side.
    saved_without = Doc(en_vocab).from_bytes(source.to_bytes(exclude=["user_data"]))
    assert not saved_without.user_data
|
2019-10-03 12:48:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_serialize_doc_bin():
    """Check that texts and cats survive a DocBin bytes roundtrip.

    The docs are reloaded with a freshly created blank English pipeline's
    vocab, which stands in for deserializing in another process.
    """
    doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
    texts = ["Some text", "Lots of texts...", "..."]
    cats = {"A": 0.5}
    nlp = English()
    for doc in nlp.pipe(texts):
        doc.cats = cats
        doc_bin.add(doc)
    bytes_data = doc_bin.to_bytes()

    # Deserialize later, e.g. in a new process
    nlp = spacy.blank("en")
    doc_bin = DocBin().from_bytes(bytes_data)
    reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
    # Check the count explicitly: zip() below stops at the shorter input, so
    # missing docs would otherwise go unnoticed.
    assert len(reloaded_docs) == len(texts)
    for doc, text in zip(reloaded_docs, texts):
        assert doc.text == text
        assert doc.cats == cats
|
2020-07-03 10:58:16 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_serialize_doc_bin_unknown_spaces(en_vocab):
    # A Doc built without `spaces` reports has_unknown_spaces (and renders a
    # trailing space after each word); one built with explicit spaces does
    # not. Both the flag and the rendered text must survive a DocBin
    # bytes roundtrip.
    implicit = Doc(en_vocab, words=["that", "'s"])
    assert implicit.has_unknown_spaces
    assert implicit.text == "that 's "

    explicit = Doc(en_vocab, words=["that", "'s"], spaces=[False, False])
    assert not explicit.has_unknown_spaces
    assert explicit.text == "that's"

    payload = DocBin(docs=[implicit, explicit]).to_bytes()
    restored = DocBin().from_bytes(payload)
    restored_implicit, restored_explicit = restored.get_docs(en_vocab)

    assert restored_implicit.has_unknown_spaces
    assert restored_implicit.text == "that 's "
    assert not restored_explicit.has_unknown_spaces
    assert restored_explicit.text == "that's"
|
2020-10-02 13:43:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "writer_flag,reader_flag,reader_value",
    [
        (True, True, "bar"),
        (True, False, "bar"),
        (False, True, "nothing"),
        (False, False, "nothing"),
    ],
)
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
    """Test that custom extensions are correctly serialized in DocBin.

    As the parametrized cases show, the extension value survives only when the
    writing DocBin uses ``store_user_data=True``; the reading DocBin's flag
    does not change the result.
    """
    # force=True: if a previous run leaked a "foo" extension, re-registering
    # must not raise and mask the real failure.
    Doc.set_extension("foo", default="nothing", force=True)
    try:
        doc = Doc(en_vocab, words=["hello", "world"])
        doc._.foo = "bar"
        doc_bin_1 = DocBin(store_user_data=writer_flag)
        doc_bin_1.add(doc)
        doc_bin_bytes = doc_bin_1.to_bytes()
        doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes)
        docs = list(doc_bin_2.get_docs(en_vocab))
        assert len(docs) == 1
        assert docs[0]._.foo == reader_value
    finally:
        # Reset the global Doc extension registry even when an assertion
        # fails, so later parametrized cases and other tests start clean.
        Underscore.doc_extensions = {}
|