Support negative tag annotation in the tagger (#8731)

* unit test to unlearn tag via negative annotation

* bump thinc to 8.0.8
This commit is contained in:
Sofie Van Landeghem 2021-07-19 14:39:11 +02:00 committed by GitHub
parent 0e4b96c97e
commit 83e27d262e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 5 deletions

View File

@ -5,7 +5,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.7,<8.1.0",
"thinc>=8.0.8,<8.1.0",
"blis>=0.4.0,<0.8.0",
"pathy",
"numpy>=1.15.0",

View File

@ -2,7 +2,7 @@
spacy-legacy>=3.0.7,<3.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.7,<8.1.0
thinc>=8.0.8,<8.1.0
blis>=0.4.0,<0.8.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0

View File

@ -37,14 +37,14 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc>=8.0.7,<8.1.0
thinc>=8.0.8,<8.1.0
install_requires =
# Our libraries
spacy-legacy>=3.0.7,<3.1.0
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.7,<8.1.0
thinc>=8.0.8,<8.1.0
blis>=0.4.0,<0.8.0
wasabi>=0.8.1,<1.1.0
srsly>=2.4.1,<3.0.0

View File

@ -222,7 +222,7 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#get_loss
"""
validate_examples(examples, "Tagger.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix="!")
# Convert empty tag "" to missing value None so that both misaligned
# tokens and tokens with missing annotation have the default missing
# value None.

View File

@ -182,6 +182,17 @@ def test_overfitting_IO():
assert_equal(batch_deps_1, batch_deps_2)
assert_equal(batch_deps_1, no_batch_deps)
# Try to unlearn the first 'N' tag with negative annotation
neg_ex = Example.from_dict(nlp.make_doc(test_text), {"tags": ["!N", "V", "J", "N"]})
for i in range(20):
losses = {}
nlp.update([neg_ex], sgd=optimizer, losses=losses)
# check that the first token is no longer predicted as the unlearned tag 'N'
doc3 = nlp(test_text)
assert doc3[0].tag_ != "N"
def test_tagger_requires_labels():
nlp = English()