Update train_new_entity_type example to use disable_pipes

This commit is contained in:
ines 2017-10-25 14:56:53 +02:00
parent 6a00de4f77
commit 615c315d70
1 changed files with 96 additions and 78 deletions

View File

@ -21,103 +21,121 @@ After training your model, you can save it to a directory. We recommend
wrapping models as Python packages, for ease of deployment. wrapping models as Python packages, for ease of deployment.
For more details, see the documentation: For more details, see the documentation:
* Training the Named Entity Recognizer: https://spacy.io/docs/usage/train-ner * Training: https://alpha.spacy.io/usage/training
* Saving and loading models: https://spacy.io/docs/usage/saving-loading * NER: https://alpha.spacy.io/usage/linguistic-features#named-entities
Developed for: spaCy 1.7.6 Developed for: spaCy 2.0.0a18
Last updated for: spaCy 2.0.0a13 Last updated for: spaCy 2.0.0a18
""" """
from __future__ import unicode_literals, print_function from __future__ import unicode_literals, print_function
import random import random
from pathlib import Path from pathlib import Path
import random
import spacy import spacy
from spacy.gold import GoldParse, minibatch from spacy.gold import GoldParse, minibatch
from spacy.pipeline import NeuralEntityRecognizer from spacy.pipeline import NeuralEntityRecognizer
from spacy.pipeline import TokenVectorEncoder
# new entity label
LABEL = 'ANIMAL'
# training data
TRAIN_DATA = [
("Horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')]),
("Do they bite?", []),
("horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')]),
("horses pretend to care about your feelings", [(0, 6, 'ANIMAL')]),
("they pretend to care about your feelings, those horses",
[(48, 54, 'ANIMAL')]),
("horses?", [(0, 6, 'ANIMAL')])
]
def main(model=None, new_model_name='animal', output_dir=None):
"""Set up the pipeline and entity recognizer, and train the new entity.
model (unicode): Model name to start off with. If None, a blank English
Language class is created.
new_model_name (unicode): Name of new model to create. Will be added to the
model meta and prefixed by the language code, e.g. 'en_animal'.
output_dir (unicode / Path): Optional output directory. If None, no model
will be saved.
"""
if model is not None:
nlp = spacy.load(model) # load existing spaCy model
print("Loaded model '%s'" % model)
else:
nlp = spacy.blank('en') # create blank Language class
print("Created blank 'en' model")
# Add entity recognizer to model if it's not in the pipeline
if 'ner' not in nlp.pipe_names:
nlp.add_pipe(NeuralEntityRecognizer(nlp.vocab))
ner = nlp.get_pipe('ner') # get entity recognizer
ner.add_label(LABEL) # add new entity label to entity recognizer
# get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
with nlp.disable_pipes(*other_pipes) as disabled: # only train NER
random.seed(0)
optimizer = nlp.begin_training(lambda: [])
for itn in range(50):
losses = {}
gold_parses = get_gold_parses(nlp.make_doc, TRAIN_DATA)
for batch in minibatch(gold_parses, size=3):
docs, golds = zip(*batch)
nlp.update(docs, golds, losses=losses, sgd=optimizer,
drop=0.35)
print(losses)
print(nlp.pipeline)
print(disabled.original_pipeline)
# test the trained model
test_text = 'Do you like horses?'
doc = nlp(test_text)
print("Entities in '%s'" % test_text)
for ent in doc.ents:
print(ent.label_, ent.text)
# save model to output directory
if output_dir is not None:
output_dir = Path(output_dir)
if not output_dir.exists():
output_dir.mkdir()
nlp.meta['name'] = new_model_name # rename model
nlp.to_disk(output_dir)
print("Saved model to", output_dir)
# test the saved model
print("Loading from", output_dir)
nlp2 = spacy.load(output_dir)
doc2 = nlp2(test_text)
for ent in doc2.ents:
print(ent.label_, ent.text)
def get_gold_parses(tokenizer, train_data): def get_gold_parses(tokenizer, train_data):
'''Shuffle and create GoldParse objects''' """Shuffle and create GoldParse objects.
tokenizer (Tokenizer): Tokenizer to processs the raw text.
train_data (list): The training data.
YIELDS (tuple): (doc, gold) tuples.
"""
random.shuffle(train_data) random.shuffle(train_data)
for raw_text, entity_offsets in train_data: for raw_text, entity_offsets in train_data:
doc = tokenizer(raw_text) doc = tokenizer(raw_text)
gold = GoldParse(doc, entities=entity_offsets) gold = GoldParse(doc, entities=entity_offsets)
yield doc, gold yield doc, gold
def train_ner(nlp, train_data, output_dir):
random.seed(0)
optimizer = nlp.begin_training(lambda: [])
nlp.meta['name'] = 'en_ent_animal'
for itn in range(50):
losses = {}
for batch in minibatch(get_gold_parses(nlp.make_doc, train_data), size=3):
docs, golds = zip(*batch)
nlp.update(docs, golds, losses=losses, sgd=optimizer, drop=0.35)
print(losses)
if not output_dir:
return
elif not output_dir.exists():
output_dir.mkdir()
nlp.to_disk(output_dir)
def main(model_name, output_directory=None):
print("Creating initial model", model_name)
nlp = spacy.blank(model_name)
if output_directory is not None:
output_directory = Path(output_directory)
train_data = [
(
"Horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')],
),
(
"Do they bite?",
[],
),
(
"horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')]
),
(
"horses pretend to care about your feelings",
[(0, 6, 'ANIMAL')]
),
(
"they pretend to care about your feelings, those horses",
[(48, 54, 'ANIMAL')]
),
(
"horses?",
[(0, 6, 'ANIMAL')]
)
]
nlp.add_pipe(TokenVectorEncoder(nlp.vocab))
ner = NeuralEntityRecognizer(nlp.vocab)
ner.add_label('ANIMAL')
nlp.add_pipe(ner)
train_ner(nlp, train_data, output_directory)
# Test that the entity is recognized
text = 'Do you like horses?'
print("Ents in 'Do you like horses?':")
doc = nlp(text)
for ent in doc.ents:
print(ent.label_, ent.text)
if output_directory:
print("Loading from", output_directory)
nlp2 = spacy.load(output_directory)
doc2 = nlp2('Do you like horses?')
for ent in doc2.ents:
print(ent.label_, ent.text)
if __name__ == '__main__': if __name__ == '__main__':
import plac import plac