mirror of https://github.com/explosion/spaCy.git
Fix bug in multi-task objective
This commit is contained in:
parent
2c9c8b8d72
commit
968dabdde4
|
@ -690,11 +690,7 @@ class MultitaskObjective(Tagger):
|
||||||
for i, gold in enumerate(golds):
|
for i, gold in enumerate(golds):
|
||||||
for j in range(len(docs[i])):
|
for j in range(len(docs[i])):
|
||||||
# Handes alignment for tokenization differences
|
# Handes alignment for tokenization differences
|
||||||
gold_idx = gold.cand_to_gold[j]
|
label = self.make_label(j, gold.words, gold.tags,
|
||||||
if gold_idx is None:
|
|
||||||
idx += 1
|
|
||||||
continue
|
|
||||||
label = self.make_label(gold_idx, gold.words, gold.tags,
|
|
||||||
gold.heads, gold.labels, gold.ents)
|
gold.heads, gold.labels, gold.ents)
|
||||||
if label is None or label not in self.labels:
|
if label is None or label not in self.labels:
|
||||||
correct[idx] = guesses[idx]
|
correct[idx] = guesses[idx]
|
||||||
|
@ -749,6 +745,8 @@ class MultitaskObjective(Tagger):
|
||||||
of gold data. You can pass cache=False if you know the cache will
|
of gold data. You can pass cache=False if you know the cache will
|
||||||
do the wrong thing.
|
do the wrong thing.
|
||||||
'''
|
'''
|
||||||
|
assert len(words) == len(heads)
|
||||||
|
assert target < len(words), (target, len(words))
|
||||||
if cache:
|
if cache:
|
||||||
if id(heads) in _cache:
|
if id(heads) in _cache:
|
||||||
return _cache[id(heads)][target]
|
return _cache[id(heads)][target]
|
||||||
|
@ -783,8 +781,6 @@ class MultitaskObjective(Tagger):
|
||||||
return sent_tags[target]
|
return sent_tags[target]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SimilarityHook(Pipe):
|
class SimilarityHook(Pipe):
|
||||||
"""
|
"""
|
||||||
Experimental: A pipeline component to install a hook for supervised
|
Experimental: A pipeline component to install a hook for supervised
|
||||||
|
|
Loading…
Reference in New Issue