mirror of https://github.com/explosion/spaCy.git
prevent None in gold fields (#5425)
* set gold fields to empty lists instead of keeping them as None
* add unit test
This commit is contained in:
parent 113e7981d0
commit b04738903e
spacy/gold.pyx
@@ -658,7 +658,15 @@ cdef class GoldParse:
         entdoc = None

         # avoid allocating memory if the doc does not contain any tokens
-        if self.length > 0:
+        if self.length == 0:
+            self.words = []
+            self.tags = []
+            self.heads = []
+            self.labels = []
+            self.ner = []
+            self.morphology = []
+
+        else:
             if words is None:
                 words = [token.text for token in doc]
             if tags is None:
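As a rough illustration of the behaviour this hunk introduces (a minimal sketch, not part of the commit, assuming spaCy v2.x where GoldParse is importable from spacy.gold and the fields set above are public attributes), a GoldParse built on a zero-token Doc should now expose empty lists rather than None:

from spacy.lang.en import English
from spacy.gold import GoldParse

nlp = English()
doc = nlp("")  # tokenizing an empty string gives a Doc with zero tokens

# With the change above, the gold fields are initialised to empty lists
# for a zero-length doc instead of being left as None.
gold = GoldParse(doc, entities=[])
assert gold.words == []
assert gold.ner == []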
spacy/tests/parser/test_ner.py
@@ -7,7 +7,7 @@ from spacy.lang.en import English
 from spacy.pipeline import EntityRecognizer, EntityRuler
 from spacy.vocab import Vocab
 from spacy.syntax.ner import BiluoPushDown
-from spacy.gold import GoldParse
+from spacy.gold import GoldParse, minibatch
 from spacy.tokens import Doc


@@ -174,6 +174,31 @@ def test_accept_blocked_token():
     assert ner2.moves.is_valid(state2, "U-")


+def test_train_empty():
+    """Test that training an empty text does not throw errors."""
+    train_data = [
+        ("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
+        ("", {"entities": []}),
+    ]
+
+    nlp = English()
+    ner = nlp.create_pipe("ner")
+    ner.add_label("PERSON")
+    nlp.add_pipe(ner, last=True)
+
+    nlp.begin_training()
+    for itn in range(2):
+        losses = {}
+        batches = minibatch(train_data)
+        for batch in batches:
+            texts, annotations = zip(*batch)
+            nlp.update(
+                texts,  # batch of texts
+                annotations,  # batch of annotations
+                losses=losses,
+            )
+
+
 def test_overwrite_token():
     nlp = English()
     ner1 = nlp.create_pipe("ner")
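The new test can be run on its own with pytest; the path below assumes the hunk lives in spacy/tests/parser/test_ner.py (as the surrounding tests suggest), so adjust it if the file differs:

python -m pytest spacy/tests/parser/test_ner.py::test_train_empty -v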