diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py index a0edde45c..7e1b19c0e 100644 --- a/examples/training/train_new_entity_type.py +++ b/examples/training/train_new_entity_type.py @@ -81,7 +81,6 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=20): else: nlp = spacy.blank('en') # create blank Language class print("Created blank 'en' model") - # Add entity recognizer to model if it's not in the pipeline # nlp.create_pipe works for built-ins that are registered with spaCy if 'ner' not in nlp.pipe_names: @@ -92,11 +91,18 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=20): ner = nlp.get_pipe('ner') ner.add_label(LABEL) # add new entity label to entity recognizer + if model is None: + optimizer = nlp.begin_training() + else: + # Note that 'begin_training' initializes the models, so it'll zero out + # existing entity types. + optimizer = nlp.entity.create_optimizer() + + # get names of other pipes to disable them during training other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner'] with nlp.disable_pipes(*other_pipes): # only train NER - optimizer = nlp.begin_training() for itn in range(n_iter): random.shuffle(TRAIN_DATA) losses = {}