mirror of https://github.com/explosion/spaCy.git
v2.1 introduced a regression when deserializing the parser after parser.add_label() had been called. The code around the class mapping is pretty confusing currently, as it was written to accommodate backwards model compatibility. It needs to be revised when the models are next retrained. Closes #3433
This commit is contained in:
parent
444a3abfe5
commit
d9a07a7f6e
|
@ -574,11 +574,12 @@ cdef class Parser:
|
|||
cfg.setdefault('min_action_freq', 30)
|
||||
actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
|
||||
min_freq=cfg.get('min_action_freq', 30))
|
||||
previous_labels = dict(self.moves.labels)
|
||||
for action, labels in self.moves.labels.items():
|
||||
actions.setdefault(action, {})
|
||||
for label, freq in labels.items():
|
||||
if label not in actions[action]:
|
||||
actions[action][label] = freq
|
||||
self.moves.initialize_actions(actions)
|
||||
for action, label_freqs in previous_labels.items():
|
||||
for label in label_freqs:
|
||||
self.moves.add_action(action, label)
|
||||
cfg.setdefault('token_vector_width', 96)
|
||||
if self.model is True:
|
||||
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
|
||||
|
|
|
@ -33,7 +33,7 @@ def _train_parser(parser):
|
|||
parser.begin_training([], **parser.cfg)
|
||||
sgd = Adam(NumpyOps(), 0.001)
|
||||
|
||||
for i in range(10):
|
||||
for i in range(5):
|
||||
losses = {}
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
||||
|
@ -43,21 +43,7 @@ def _train_parser(parser):
|
|||
|
||||
def test_add_label(parser):
|
||||
parser = _train_parser(parser)
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
doc = parser(doc)
|
||||
assert doc[0].head.i == 1
|
||||
assert doc[0].dep_ == "left"
|
||||
assert doc[1].head.i == 1
|
||||
assert doc[2].head.i == 3
|
||||
assert doc[2].head.i == 3
|
||||
parser.add_label("right")
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
doc = parser(doc)
|
||||
assert doc[0].head.i == 1
|
||||
assert doc[0].dep_ == "left"
|
||||
assert doc[1].head.i == 1
|
||||
assert doc[2].head.i == 3
|
||||
assert doc[2].head.i == 3
|
||||
sgd = Adam(NumpyOps(), 0.001)
|
||||
for i in range(10):
|
||||
losses = {}
|
||||
|
@ -72,7 +58,6 @@ def test_add_label(parser):
|
|||
assert doc[2].dep_ == "left"
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_add_label_deserializes_correctly():
|
||||
ner1 = EntityRecognizer(Vocab())
|
||||
ner1.add_label("C")
|
||||
|
|
Loading…
Reference in New Issue