Ensure training doesn't crash with empty batches (#4360)

* unit test for previously resolved unflatten issue

* prevent a batch of empty docs from causing problems
This commit is contained in:
Sofie Van Landeghem 2019-10-02 12:50:48 +02:00 committed by Matthew Honnibal
parent 52b5912dbf
commit 9d3ce7cba2
3 changed files with 44 additions and 0 deletions

View File

@ -454,6 +454,10 @@ class Tagger(Pipe):
if losses is not None and self.name not in losses:
losses[self.name] = 0.
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
bp_tag_scores(d_tag_scores, sgd=sgd)
@ -467,6 +471,9 @@ class Tagger(Pipe):
"""
if self._rehearsal_model is None:
return
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return
guesses, backprop = self.model.begin_update(docs, drop=drop)
target = self._rehearsal_model(docs)
gradient = guesses - target
@ -968,6 +975,9 @@ class TextCategorizer(Pipe):
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
self.require_model()
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return
scores, bp_scores = self.model.begin_update(docs, drop=drop)
loss, d_scores = self.get_loss(docs, golds, scores)
bp_scores(d_scores, sgd=sgd)
@ -978,6 +988,9 @@ class TextCategorizer(Pipe):
def rehearse(self, docs, drop=0., sgd=None, losses=None):
if self._rehearsal_model is None:
return
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return
scores, bp_scores = self.model.begin_update(docs, drop=drop)
target = self._rehearsal_model(docs)
gradient = scores - target

View File

@ -318,6 +318,14 @@ def test_issue3449():
assert t3[5].text == "I"
def test_issue3456():
    # Regression test: piping texts that include an empty string used to
    # crash due to a padding error in layer.ops.unflatten in thinc.
    nlp = English()
    tagger = nlp.create_pipe("tagger")
    nlp.add_pipe(tagger)
    nlp.begin_training()
    # Consuming the generator is what triggers the (formerly broken) path.
    list(nlp.pipe(["hi", ""]))
def test_issue3468():
"""Test that sentence boundaries are set correctly so Doc.is_sentenced can
be restored after serialization."""

View File

@ -0,0 +1,23 @@
# coding: utf8
from __future__ import unicode_literals
from spacy.lang.en import English
from spacy.util import minibatch, compounding
def test_issue4348():
    """Training the tagger on examples with no tokens must not raise."""
    # Two examples consisting of empty texts with empty tag lists.
    train_data = [("", {"tags": []}), ("", {"tags": []})]
    nlp = English()
    nlp.add_pipe(nlp.create_pipe("tagger"))
    optimizer = nlp.begin_training()
    for _ in range(5):
        losses = {}
        batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, sgd=optimizer, losses=losses)