mirror of https://github.com/explosion/spaCy.git
negative tag annotation (#8731)
* unit test to unlearn tag via negative annotation * bump thinc to 8.0.8
This commit is contained in:
parent
0e4b96c97e
commit
83e27d262e
|
@ -5,7 +5,7 @@ requires = [
|
|||
"cymem>=2.0.2,<2.1.0",
|
||||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
"thinc>=8.0.7,<8.1.0",
|
||||
"thinc>=8.0.8,<8.1.0",
|
||||
"blis>=0.4.0,<0.8.0",
|
||||
"pathy",
|
||||
"numpy>=1.15.0",
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
spacy-legacy>=3.0.7,<3.1.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.7,<8.1.0
|
||||
thinc>=8.0.8,<8.1.0
|
||||
blis>=0.4.0,<0.8.0
|
||||
ml_datasets>=0.2.0,<0.3.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
|
|
|
@ -37,14 +37,14 @@ setup_requires =
|
|||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
thinc>=8.0.7,<8.1.0
|
||||
thinc>=8.0.8,<8.1.0
|
||||
install_requires =
|
||||
# Our libraries
|
||||
spacy-legacy>=3.0.7,<3.1.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.7,<8.1.0
|
||||
thinc>=8.0.8,<8.1.0
|
||||
blis>=0.4.0,<0.8.0
|
||||
wasabi>=0.8.1,<1.1.0
|
||||
srsly>=2.4.1,<3.0.0
|
||||
|
|
|
@ -222,7 +222,7 @@ class Tagger(TrainablePipe):
|
|||
DOCS: https://spacy.io/api/tagger#get_loss
|
||||
"""
|
||||
validate_examples(examples, "Tagger.get_loss")
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix="!")
|
||||
# Convert empty tag "" to missing value None so that both misaligned
|
||||
# tokens and tokens with missing annotation have the default missing
|
||||
# value None.
|
||||
|
|
|
@ -182,6 +182,17 @@ def test_overfitting_IO():
|
|||
assert_equal(batch_deps_1, batch_deps_2)
|
||||
assert_equal(batch_deps_1, no_batch_deps)
|
||||
|
||||
# Try to unlearn the first 'N' tag with negative annotation
|
||||
neg_ex = Example.from_dict(nlp.make_doc(test_text), {"tags": ["!N", "V", "J", "N"]})
|
||||
|
||||
for i in range(20):
|
||||
losses = {}
|
||||
nlp.update([neg_ex], sgd=optimizer, losses=losses)
|
||||
|
||||
# test the "untrained" tag
|
||||
doc3 = nlp(test_text)
|
||||
assert doc3[0].tag_ != "N"
|
||||
|
||||
|
||||
def test_tagger_requires_labels():
|
||||
nlp = English()
|
||||
|
|
Loading…
Reference in New Issue