Improve train_new_entity_type example

This commit is contained in:
Matthew Honnibal 2018-03-29 20:26:41 +02:00
parent ea2af94cd9
commit 68ad366935
1 changed file with 8 additions and 2 deletions

View File

@ -81,7 +81,6 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=20):
else: else:
nlp = spacy.blank('en') # create blank Language class nlp = spacy.blank('en') # create blank Language class
print("Created blank 'en' model") print("Created blank 'en' model")
# Add entity recognizer to model if it's not in the pipeline # Add entity recognizer to model if it's not in the pipeline
# nlp.create_pipe works for built-ins that are registered with spaCy # nlp.create_pipe works for built-ins that are registered with spaCy
if 'ner' not in nlp.pipe_names: if 'ner' not in nlp.pipe_names:
@ -92,11 +91,18 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=20):
ner = nlp.get_pipe('ner') ner = nlp.get_pipe('ner')
ner.add_label(LABEL) # add new entity label to entity recognizer ner.add_label(LABEL) # add new entity label to entity recognizer
if model is None:
optimizer = nlp.begin_training()
else:
# Note that 'begin_training' initializes the models, so it'll zero out
# existing entity types.
optimizer = nlp.entity.create_optimizer()
# get names of other pipes to disable them during training # get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner'] other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
with nlp.disable_pipes(*other_pipes): # only train NER with nlp.disable_pipes(*other_pipes): # only train NER
optimizer = nlp.begin_training()
for itn in range(n_iter): for itn in range(n_iter):
random.shuffle(TRAIN_DATA) random.shuffle(TRAIN_DATA)
losses = {} losses = {}