# spaCy/spacy/tests/training/test_rehearse.py
# (repository viewer header: 169 lines, 5.5 KiB, Python)

import pytest
import spacy
from typing import List
from spacy.training import Example
# Examples used for the initial training phase. Each item is a
# ``(text, annotations)`` pair; ``annotations`` carries one key per component
# under test. Entity spans are ``(start_char, end_char, label)`` with an
# exclusive end offset into ``text``.
TRAIN_DATA = [
    (
        "Who is Kofi Annan?",
        {
            # End offset is 17 so the span is "Kofi Annan" without the
            # trailing question mark.
            "entities": [(7, 17, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PRON", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Who is Steve Jobs?",
        {
            "entities": [(7, 17, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PRON", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Bob is a nice person.",
        {
            "entities": [(0, 3, "PERSON")],
            "tags": ["PROPN", "AUX", "DET", "ADJ", "NOUN", "PUNCT"],
            "heads": [1, 1, 4, 4, 1, 1],
            "deps": ["nsubj", "ROOT", "det", "amod", "attr", "punct"],
            "morphs": [
                "Number=Sing",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Definite=Ind|PronType=Art",
                "Degree=Pos",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
    (
        "Hi Anil, how are you?",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "AUX", "PRON", "PUNCT"],
            "deps": ["intj", "npadvmod", "punct", "advmod", "ROOT", "nsubj", "punct"],
            "heads": [4, 0, 4, 4, 4, 4, 4],
            "morphs": [
                "",
                "Number=Sing",
                "PunctType=Comm",
                "",
                "Mood=Ind|Tense=Pres|VerbForm=Fin",
                "Case=Nom|Person=2|PronType=Prs",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    (
        "I like London and Berlin.",
        {
            "entities": [(7, 13, "LOC"), (18, 24, "LOC")],
            "tags": ["PROPN", "VERB", "PROPN", "CCONJ", "PROPN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
            "heads": [1, 1, 1, 2, 2, 1],
            "morphs": [
                "Case=Nom|Number=Sing|Person=1|PronType=Prs",
                "Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "ConjType=Cmp",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]
# Examples used for the rehearsal phase. Same ``(text, annotations)`` layout
# as the training examples: one annotation key per component under test, and
# entity spans given as ``(start_char, end_char, label)``.
REHEARSE_DATA = [
    (
        "Hi Anil",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN"],
            "deps": ["ROOT", "npadvmod"],
            "heads": [0, 0],
            "morphs": ["", "Number=Sing"],
            "cats": {"greeting": 1.0},
        },
    ),
    (
        "Hi Ravish, how you doing?",
        {
            "entities": [(3, 9, "PERSON")],
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "AUX", "PRON", "PUNCT"],
            "deps": ["intj", "ROOT", "punct", "advmod", "nsubj", "advcl", "punct"],
            "heads": [1, 1, 1, 5, 5, 1, 1],
            "morphs": [
                "",
                "VerbForm=Inf",
                "PunctType=Comm",
                "",
                "Case=Nom|Person=2|PronType=Prs",
                "Aspect=Prog|Tense=Pres|VerbForm=Part",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    # This example carries an entity label (UTENSIL) that does not occur in
    # the training examples.
    (
        "Natasha bought new forks.",
        {
            "entities": [(0, 7, "PERSON"), (19, 24, "UTENSIL")],
            "tags": ["PROPN", "VERB", "ADJ", "NOUN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "amod", "dobj", "punct"],
            "heads": [1, 1, 3, 1, 1],
            "morphs": [
                "Number=Sing",
                "Tense=Past|VerbForm=Fin",
                "Degree=Pos",
                "Number=Plur",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]
def _add_ner_label(ner, data):
    """Add the label of every annotated entity in ``data`` to ``ner``.

    ``data`` is an iterable of ``(text, annotations)`` pairs; the label is the
    third item of each entry under ``annotations['entities']``.
    """
    labels = (span[2] for _text, annots in data for span in annots['entities'])
    for label in labels:
        ner.add_label(label)
def _add_tagger_label(tagger, data):
    """Add every tag listed under ``annotations['tags']`` in ``data`` to ``tagger``.

    ``data`` is an iterable of ``(text, annotations)`` pairs.
    """
    tags = (tag for _text, annots in data for tag in annots['tags'])
    for tag in tags:
        tagger.add_label(tag)
def _add_parser_label(parser, data):
    """Add every label listed under ``annotations['deps']`` in ``data`` to ``parser``.

    ``data`` is an iterable of ``(text, annotations)`` pairs.
    """
    dep_labels = (dep for _text, annots in data for dep in annots['deps'])
    for dep_label in dep_labels:
        parser.add_label(dep_label)
def _add_textcat_label(textcat, data):
    """Add every category key under ``annotations['cats']`` in ``data`` to ``textcat``.

    ``data`` is an iterable of ``(text, annotations)`` pairs; only the keys of
    the ``cats`` mapping are used, not the scores.
    """
    categories = (cat for _text, annots in data for cat in annots['cats'])
    for category in categories:
        textcat.add_label(category)
def _optimize(
    nlp,
    component: str,
    data: List,
    rehearse: bool
):
    """Run either train or rehearse.

    Registers the labels found in ``data`` on the ``component`` pipe, obtains
    an optimizer, and then makes five passes over ``data``, feeding one
    example per step.

    nlp: Pipeline object that already contains the ``component`` pipe.
    component: Pipe name; one of 'ner', 'tagger', 'parser' or
        'textcat_multilabel'.
    data: ``(text, annotations)`` pairs.
    rehearse: If true, call ``nlp.resume_training()`` and step with
        ``nlp.rehearse``; otherwise call ``nlp.initialize()`` and step with
        ``nlp.update``.
    RETURNS: The ``nlp`` object that was passed in.
    RAISES: NotImplementedError for any other component name.
    """
    pipe = nlp.get_pipe(component)
    if component == 'ner':
        _add_ner_label(pipe, data)
    elif component == 'tagger':
        _add_tagger_label(pipe, data)
    elif component == 'parser':
        # The parser must receive the dependency labels ('deps'), not the
        # part-of-speech tags.
        _add_parser_label(pipe, data)
    elif component == 'textcat_multilabel':
        _add_textcat_label(pipe, data)
    else:
        raise NotImplementedError(f"Unsupported component: {component!r}")

    # Pick the optimizer and the matching step function once, up front.
    if rehearse:
        optimizer = nlp.resume_training()
        step = nlp.rehearse
    else:
        optimizer = nlp.initialize()
        step = nlp.update

    for _ in range(5):
        for text, annotation in data:
            doc = nlp.make_doc(text)
            example = Example.from_dict(doc, annotation)
            step([example], sgd=optimizer)
    return nlp
@pytest.mark.parametrize(
    "component",
    ['ner', 'tagger', 'parser', 'textcat_multilabel'],
)
def test_rehearse(component):
    """Train ``component`` on TRAIN_DATA, then rehearse it on REHEARSE_DATA.

    There are no explicit assertions: the test passes when both phases
    complete without raising.
    """
    pipeline = spacy.blank("en")
    pipeline.add_pipe(component)
    trained = _optimize(pipeline, component, TRAIN_DATA, False)
    _optimize(trained, component, REHEARSE_DATA, True)