From 59000ee21dcacb091fd3493bdfe4ea57e664e110 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Fri, 13 Mar 2020 16:07:56 +0100
Subject: [PATCH] fix serialization of empty doc + unit test

---
 spacy/tests/regression/test_issue5141.py | 11 +++++++++++
 spacy/tokens/_serialize.py               |  7 +++++--
 2 files changed, 16 insertions(+), 2 deletions(-)
 create mode 100644 spacy/tests/regression/test_issue5141.py

diff --git a/spacy/tests/regression/test_issue5141.py b/spacy/tests/regression/test_issue5141.py
new file mode 100644
index 000000000..845454583
--- /dev/null
+++ b/spacy/tests/regression/test_issue5141.py
@@ -0,0 +1,11 @@
+from spacy.tokens import DocBin
+
+
+def test_issue5141(en_vocab):
+    """ Ensure an empty DocBin does not crash on serialization """
+    doc_bin = DocBin(attrs=["DEP", "HEAD"])
+    assert list(doc_bin.get_docs(en_vocab)) == []
+    doc_bin_bytes = doc_bin.to_bytes()
+
+    doc_bin_2 = DocBin().from_bytes(doc_bin_bytes)
+    assert list(doc_bin_2.get_docs(en_vocab)) == []
diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py
index 65b70d1b3..d3f49550c 100644
--- a/spacy/tokens/_serialize.py
+++ b/spacy/tokens/_serialize.py
@@ -135,10 +135,13 @@ class DocBin(object):
         for tokens in self.tokens:
             assert len(tokens.shape) == 2, tokens.shape  # this should never happen
         lengths = [len(tokens) for tokens in self.tokens]
+        tokens = numpy.vstack(self.tokens) if self.tokens else numpy.asarray([])
+        spaces = numpy.vstack(self.spaces) if self.spaces else numpy.asarray([])
+
         msg = {
            "attrs": self.attrs,
-            "tokens": numpy.vstack(self.tokens).tobytes("C"),
-            "spaces": numpy.vstack(self.spaces).tobytes("C"),
+            "tokens": tokens.tobytes("C"),
+            "spaces": spaces.tobytes("C"),
             "lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
             "strings": list(self.strings),
             "cats": self.cats,
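
A minimal sketch of why the guard in the _serialize.py hunk is needed: numpy.vstack raises a ValueError when given an empty sequence, which is exactly the situation of a DocBin holding no docs. The serialize_tokens() helper below is hypothetical (not spaCy's API) and only illustrates the same guarded pattern in isolation.

# Standalone illustration of the guarded vstack pattern; serialize_tokens()
# is a made-up helper, not part of spaCy.
import numpy


def serialize_tokens(token_arrays):
    # Fall back to an empty array instead of calling vstack on an empty list,
    # which would raise "ValueError: need at least one array to concatenate".
    stacked = numpy.vstack(token_arrays) if token_arrays else numpy.asarray([])
    return stacked.tobytes("C")


# An empty container now serializes to empty bytes instead of crashing...
assert serialize_tokens([]) == b""

# ...while a non-empty one still stacks its per-doc arrays as before:
# (3 + 5) rows x 2 columns x 8 bytes per uint64 element.
arrays = [numpy.zeros((3, 2), dtype="uint64"), numpy.zeros((5, 2), dtype="uint64")]
assert len(serialize_tokens(arrays)) == (3 + 5) * 2 * 8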