From 68ad3669351448fc2f86e38f44406f04335dd72a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 29 Mar 2018 20:26:41 +0200 Subject: [PATCH] Improve train_new_entity_type example --- examples/training/train_new_entity_type.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py index a0edde45c..7e1b19c0e 100644 --- a/examples/training/train_new_entity_type.py +++ b/examples/training/train_new_entity_type.py @@ -81,7 +81,6 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=20): else: nlp = spacy.blank('en') # create blank Language class print("Created blank 'en' model") - # Add entity recognizer to model if it's not in the pipeline # nlp.create_pipe works for built-ins that are registered with spaCy if 'ner' not in nlp.pipe_names: @@ -92,11 +91,18 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=20): ner = nlp.get_pipe('ner') ner.add_label(LABEL) # add new entity label to entity recognizer + if model is None: + optimizer = nlp.begin_training() + else: + # Note that 'begin_training' initializes the models, so it'll zero out + # existing entity types. + optimizer = nlp.entity.create_optimizer() + + # get names of other pipes to disable them during training other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner'] with nlp.disable_pipes(*other_pipes): # only train NER - optimizer = nlp.begin_training() for itn in range(n_iter): random.shuffle(TRAIN_DATA) losses = {}