mirror of https://github.com/explosion/spaCy.git
Auto-format code with black (#11687)
Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>
This commit is contained in:
parent
fb280001cc
commit
84d9cb6b38
|
@ -231,7 +231,7 @@ def test_tok2vec_listener_callback():
|
||||||
|
|
||||||
|
|
||||||
def test_tok2vec_listener_overfitting():
|
def test_tok2vec_listener_overfitting():
|
||||||
""" Test that a pipeline with a listener properly overfits, even if 'tok2vec' is in the annotating components """
|
"""Test that a pipeline with a listener properly overfits, even if 'tok2vec' is in the annotating components"""
|
||||||
orig_config = Config().from_str(cfg_string)
|
orig_config = Config().from_str(cfg_string)
|
||||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
train_examples = []
|
train_examples = []
|
||||||
|
@ -264,7 +264,7 @@ def test_tok2vec_listener_overfitting():
|
||||||
|
|
||||||
|
|
||||||
def test_tok2vec_frozen_not_annotating():
|
def test_tok2vec_frozen_not_annotating():
|
||||||
""" Test that a pipeline with a frozen tok2vec raises an error when the tok2vec is not annotating """
|
"""Test that a pipeline with a frozen tok2vec raises an error when the tok2vec is not annotating"""
|
||||||
orig_config = Config().from_str(cfg_string)
|
orig_config = Config().from_str(cfg_string)
|
||||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
train_examples = []
|
train_examples = []
|
||||||
|
@ -274,12 +274,16 @@ def test_tok2vec_frozen_not_annotating():
|
||||||
|
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
losses = {}
|
losses = {}
|
||||||
with pytest.raises(ValueError, match=r"the tok2vec embedding layer is not updated"):
|
with pytest.raises(
|
||||||
nlp.update(train_examples, sgd=optimizer, losses=losses, exclude=["tok2vec"])
|
ValueError, match=r"the tok2vec embedding layer is not updated"
|
||||||
|
):
|
||||||
|
nlp.update(
|
||||||
|
train_examples, sgd=optimizer, losses=losses, exclude=["tok2vec"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_tok2vec_frozen_overfitting():
|
def test_tok2vec_frozen_overfitting():
|
||||||
""" Test that a pipeline with a frozen & annotating tok2vec can still overfit """
|
"""Test that a pipeline with a frozen & annotating tok2vec can still overfit"""
|
||||||
orig_config = Config().from_str(cfg_string)
|
orig_config = Config().from_str(cfg_string)
|
||||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
train_examples = []
|
train_examples = []
|
||||||
|
@ -289,7 +293,13 @@ def test_tok2vec_frozen_overfitting():
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
losses = {}
|
losses = {}
|
||||||
nlp.update(train_examples, sgd=optimizer, losses=losses, exclude=["tok2vec"], annotates=["tok2vec"])
|
nlp.update(
|
||||||
|
train_examples,
|
||||||
|
sgd=optimizer,
|
||||||
|
losses=losses,
|
||||||
|
exclude=["tok2vec"],
|
||||||
|
annotates=["tok2vec"],
|
||||||
|
)
|
||||||
assert losses["tagger"] < 0.0001
|
assert losses["tagger"] < 0.0001
|
||||||
|
|
||||||
# test the trained model
|
# test the trained model
|
||||||
|
|
Loading…
Reference in New Issue