mirror of https://github.com/explosion/spaCy.git
Update train_new_entity_type example to use disable_pipes
parent 6a00de4f77
commit 615c315d70
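The change replaces the old train_ner()/TokenVectorEncoder setup with spaCy's nlp.disable_pipes() context manager, so that only the entity recognizer is updated during training. As a rough illustration of that pattern (my own sketch, assuming the spaCy 2.0 alpha API this example targets; not code from the diff):

import spacy

nlp = spacy.blank('en')  # placeholder pipeline for the sketch
# collect the names of every pipe except the entity recognizer
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
with nlp.disable_pipes(*other_pipes):
    # inside this block the named pipes are removed from nlp.pipeline,
    # so nlp.update() only trains the components that remain (the NER)
    pass
# on exit, the disabled pipes are put back into the pipeline

The diff below also prints disabled.original_pipeline inside the training block, to show the full pipeline that is restored when the block exits.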
@@ -21,103 +21,121 @@ After training your model, you can save it to a directory. We recommend
 wrapping models as Python packages, for ease of deployment.
 
 For more details, see the documentation:
-* Training the Named Entity Recognizer: https://spacy.io/docs/usage/train-ner
-* Saving and loading models: https://spacy.io/docs/usage/saving-loading
+* Training: https://alpha.spacy.io/usage/training
+* NER: https://alpha.spacy.io/usage/linguistic-features#named-entities
 
-Developed for: spaCy 1.7.6
-Last updated for: spaCy 2.0.0a13
+Developed for: spaCy 2.0.0a18
+Last updated for: spaCy 2.0.0a18
 """
 from __future__ import unicode_literals, print_function
 
 import random
 from pathlib import Path
-import random
 
 import spacy
 from spacy.gold import GoldParse, minibatch
 from spacy.pipeline import NeuralEntityRecognizer
-from spacy.pipeline import TokenVectorEncoder
 
 
+# new entity label
+LABEL = 'ANIMAL'
+
+# training data
+TRAIN_DATA = [
+    ("Horses are too tall and they pretend to care about your feelings",
+     [(0, 6, 'ANIMAL')]),
+
+    ("Do they bite?", []),
+
+    ("horses are too tall and they pretend to care about your feelings",
+     [(0, 6, 'ANIMAL')]),
+
+    ("horses pretend to care about your feelings", [(0, 6, 'ANIMAL')]),
+
+    ("they pretend to care about your feelings, those horses",
+     [(48, 54, 'ANIMAL')]),
+
+    ("horses?", [(0, 6, 'ANIMAL')])
+]
+
+
+def main(model=None, new_model_name='animal', output_dir=None):
+    """Set up the pipeline and entity recognizer, and train the new entity.
+
+    model (unicode): Model name to start off with. If None, a blank English
+        Language class is created.
+    new_model_name (unicode): Name of new model to create. Will be added to the
+        model meta and prefixed by the language code, e.g. 'en_animal'.
+    output_dir (unicode / Path): Optional output directory. If None, no model
+        will be saved.
+    """
+    if model is not None:
+        nlp = spacy.load(model)  # load existing spaCy model
+        print("Loaded model '%s'" % model)
+    else:
+        nlp = spacy.blank('en')  # create blank Language class
+        print("Created blank 'en' model")
+
+    # Add entity recognizer to model if it's not in the pipeline
+    if 'ner' not in nlp.pipe_names:
+        nlp.add_pipe(NeuralEntityRecognizer(nlp.vocab))
+
+    ner = nlp.get_pipe('ner')  # get entity recognizer
+    ner.add_label(LABEL)  # add new entity label to entity recognizer
+
+    # get names of other pipes to disable them during training
+    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
+    with nlp.disable_pipes(*other_pipes) as disabled:  # only train NER
+        random.seed(0)
+        optimizer = nlp.begin_training(lambda: [])
+        for itn in range(50):
+            losses = {}
+            gold_parses = get_gold_parses(nlp.make_doc, TRAIN_DATA)
+            for batch in minibatch(gold_parses, size=3):
+                docs, golds = zip(*batch)
+                nlp.update(docs, golds, losses=losses, sgd=optimizer,
+                           drop=0.35)
+            print(losses)
+        print(nlp.pipeline)
+        print(disabled.original_pipeline)
+
+    # test the trained model
+    test_text = 'Do you like horses?'
+    doc = nlp(test_text)
+    print("Entities in '%s'" % test_text)
+    for ent in doc.ents:
+        print(ent.label_, ent.text)
+
+    # save model to output directory
+    if output_dir is not None:
+        output_dir = Path(output_dir)
+        if not output_dir.exists():
+            output_dir.mkdir()
+        nlp.meta['name'] = new_model_name  # rename model
+        nlp.to_disk(output_dir)
+        print("Saved model to", output_dir)
+
+        # test the saved model
+        print("Loading from", output_dir)
+        nlp2 = spacy.load(output_dir)
+        doc2 = nlp2(test_text)
+        for ent in doc2.ents:
+            print(ent.label_, ent.text)
+
+
 def get_gold_parses(tokenizer, train_data):
-    '''Shuffle and create GoldParse objects'''
+    """Shuffle and create GoldParse objects.
+
+    tokenizer (Tokenizer): Tokenizer to process the raw text.
+    train_data (list): The training data.
+    YIELDS (tuple): (doc, gold) tuples.
+    """
     random.shuffle(train_data)
     for raw_text, entity_offsets in train_data:
         doc = tokenizer(raw_text)
         gold = GoldParse(doc, entities=entity_offsets)
         yield doc, gold
 
 
-def train_ner(nlp, train_data, output_dir):
-    random.seed(0)
-    optimizer = nlp.begin_training(lambda: [])
-    nlp.meta['name'] = 'en_ent_animal'
-    for itn in range(50):
-        losses = {}
-        for batch in minibatch(get_gold_parses(nlp.make_doc, train_data), size=3):
-            docs, golds = zip(*batch)
-            nlp.update(docs, golds, losses=losses, sgd=optimizer, drop=0.35)
-        print(losses)
-    if not output_dir:
-        return
-    elif not output_dir.exists():
-        output_dir.mkdir()
-    nlp.to_disk(output_dir)
-
-
-def main(model_name, output_directory=None):
-    print("Creating initial model", model_name)
-    nlp = spacy.blank(model_name)
-    if output_directory is not None:
-        output_directory = Path(output_directory)
-
-    train_data = [
-        (
-            "Horses are too tall and they pretend to care about your feelings",
-            [(0, 6, 'ANIMAL')],
-        ),
-        (
-            "Do they bite?",
-            [],
-        ),
-        (
-            "horses are too tall and they pretend to care about your feelings",
-            [(0, 6, 'ANIMAL')]
-        ),
-        (
-            "horses pretend to care about your feelings",
-            [(0, 6, 'ANIMAL')]
-        ),
-        (
-            "they pretend to care about your feelings, those horses",
-            [(48, 54, 'ANIMAL')]
-        ),
-        (
-            "horses?",
-            [(0, 6, 'ANIMAL')]
-        )
-    ]
-    nlp.add_pipe(TokenVectorEncoder(nlp.vocab))
-    ner = NeuralEntityRecognizer(nlp.vocab)
-    ner.add_label('ANIMAL')
-    nlp.add_pipe(ner)
-    train_ner(nlp, train_data, output_directory)
-
-    # Test that the entity is recognized
-    text = 'Do you like horses?'
-    print("Ents in 'Do you like horses?':")
-    doc = nlp(text)
-    for ent in doc.ents:
-        print(ent.label_, ent.text)
-    if output_directory:
-        print("Loading from", output_directory)
-        nlp2 = spacy.load(output_directory)
-        doc2 = nlp2('Do you like horses?')
-        for ent in doc2.ents:
-            print(ent.label_, ent.text)
-
-
 if __name__ == '__main__':
     import plac
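The plac command-line wiring at the end of the script falls outside this hunk. Assuming the definitions from the new version of the file are in scope, a direct call to the example's main() might look roughly like this (the output path is a placeholder, not a value from the diff):

# Hypothetical invocation of the example's main(); '/tmp/animal_model' is a
# placeholder path. model=None starts from a blank 'en' pipeline, trains the
# new ANIMAL label on TRAIN_DATA, and saves the result to output_dir.
main(model=None, new_model_name='animal', output_dir='/tmp/animal_model')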