mirror of https://github.com/explosion/spaCy.git
Fix beam NER resizing (#6834)
* move label check to sub methods
* add tests
This commit is contained in:
parent
5ed51c9dd2
commit
6b68ad027b
|
@ -205,7 +205,6 @@ cdef class Parser(TrainablePipe):
|
|||
def predict(self, docs):
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
self._ensure_labels_are_added(docs)
|
||||
if not any(len(doc) for doc in docs):
|
||||
result = self.moves.init_batch(docs)
|
||||
return result
|
||||
|
@ -222,6 +221,7 @@ cdef class Parser(TrainablePipe):
|
|||
def greedy_parse(self, docs, drop=0.):
|
||||
cdef vector[StateC*] states
|
||||
cdef StateClass state
|
||||
self._ensure_labels_are_added(docs)
|
||||
set_dropout_rate(self.model, drop)
|
||||
batch = self.moves.init_batch(docs)
|
||||
model = self.model.predict(docs)
|
||||
|
@ -240,6 +240,7 @@ cdef class Parser(TrainablePipe):
|
|||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||
cdef Beam beam
|
||||
cdef Doc doc
|
||||
self._ensure_labels_are_added(docs)
|
||||
batch = _beam_utils.BeamBatch(
|
||||
self.moves,
|
||||
self.moves.init_batch(docs),
|
||||
|
|
|
@ -138,6 +138,28 @@ def test_ner_labels_added_implicitly_on_predict():
|
|||
assert "D" in ner.labels
|
||||
|
||||
|
||||
def test_ner_labels_added_implicitly_on_beam_parse():
    """Check that ``beam_parse`` registers a label it was never told about.

    The ``beam_ner`` component is given the labels A, B and C and then
    initialized. A doc annotated with the extra label D is passed to
    ``beam_parse``; afterwards D must appear in the component's labels.
    """
    pipeline = Language()
    beam_ner = pipeline.add_pipe("beam_ner")
    # "D" is deliberately absent from the explicitly registered labels.
    initial_labels = ["A", "B", "C"]
    for name in initial_labels:
        beam_ner.add_label(name)
    pipeline.initialize()
    # The entity annotation on the first word uses the unregistered label "D".
    annotated = Doc(
        pipeline.vocab,
        words=["hello", "world"],
        ents=["B-D", "O"],
    )
    beam_ner.beam_parse([annotated], beam_width=32)
    registered = beam_ner.labels
    assert "D" in registered
|
||||
|
||||
|
||||
def test_ner_labels_added_implicitly_on_greedy_parse():
    """Check that ``greedy_parse`` registers a label it was never told about.

    The component is given the labels A, B and C and then initialized. A doc
    annotated with the extra label D is passed to ``greedy_parse``; afterwards
    D must appear in the component's labels.
    """
    pipeline = Language()
    # NOTE(review): this uses the "beam_ner" factory even though the greedy
    # code path is exercised - presumably intentional, confirm with the author.
    parser = pipeline.add_pipe("beam_ner")
    # "D" is deliberately absent from the explicitly registered labels.
    initial_labels = ["A", "B", "C"]
    for name in initial_labels:
        parser.add_label(name)
    pipeline.initialize()
    # The entity annotation on the first word uses the unregistered label "D".
    annotated = Doc(
        pipeline.vocab,
        words=["hello", "world"],
        ents=["B-D", "O"],
    )
    parser.greedy_parse([annotated])
    registered = parser.labels
    assert "D" in registered
|
||||
|
||||
|
||||
def test_ner_labels_added_implicitly_on_update():
|
||||
nlp = Language()
|
||||
ner = nlp.add_pipe("ner")
|
||||
|
|
|
@ -303,14 +303,14 @@ def test_issue4313():
|
|||
doc = nlp("What do you think about Apple ?")
|
||||
assert len(ner.labels) == 1
|
||||
assert "SOME_LABEL" in ner.labels
|
||||
ner.add_label("MY_ORG") # TODO: not sure if we want this to be necessary...
|
||||
apple_ent = Span(doc, 5, 6, label="MY_ORG")
|
||||
doc.ents = list(doc.ents) + [apple_ent]
|
||||
|
||||
# ensure the beam_parse still works with the new label
|
||||
docs = [doc]
|
||||
ner = nlp.get_pipe("beam_ner")
|
||||
ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density)
|
||||
assert len(ner.labels) == 2
|
||||
assert "MY_ORG" in ner.labels
|
||||
|
||||
|
||||
def test_issue4348():
|
||||
|
|
Loading…
Reference in New Issue