From c7adca58a9c85423d97f859c06b6c92e8aee35ab Mon Sep 17 00:00:00 2001 From: ines Date: Sun, 16 Apr 2017 16:55:01 +0200 Subject: [PATCH] Tidy up example and only save/test if output_directory is not None --- examples/training/train_new_entity_type.py | 31 ++++++++++------------ 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py index af98ef583..cbe2963d3 100644 --- a/examples/training/train_new_entity_type.py +++ b/examples/training/train_new_entity_type.py @@ -1,22 +1,16 @@ from __future__ import unicode_literals, print_function -import json -import pathlib + import random +from pathlib import Path import spacy from spacy.pipeline import EntityRecognizer from spacy.gold import GoldParse from spacy.tagger import Tagger - -try: - unicode -except: - unicode = str - def train_ner(nlp, train_data, output_dir): - # Add new words to vocab. + # Add new words to vocab for raw_text, _ in train_data: doc = nlp.make_doc(raw_text) for word in doc: @@ -30,11 +24,14 @@ def train_ner(nlp, train_data, output_dir): nlp.tagger(doc) loss = nlp.entity.update(doc, gold) nlp.end_training() - nlp.save_to_directory(output_dir) + if output_dir: + nlp.save_to_directory(output_dir) def main(model_name, output_directory=None): nlp = spacy.load(model_name) + if output_directory is not None: + output_directory = Path(output_directory) train_data = [ ( @@ -55,18 +52,18 @@ def main(model_name, output_directory=None): ) ] nlp.entity.add_label('ANIMAL') - if output_directory is not None: - output_directory = pathlib.Path(output_directory) ner = train_ner(nlp, train_data, output_directory) + # Test that the entity is recognized doc = nlp('Do you like horses?') for ent in doc.ents: print(ent.label_, ent.text) - nlp2 = spacy.load('en', path=output_directory) - nlp2.entity.add_label('ANIMAL') - doc2 = nlp2('Do you like horses?') - for ent in doc2.ents: - print(ent.label_, ent.text) + if output_directory: + nlp2 = spacy.load('en', path=output_directory) + nlp2.entity.add_label('ANIMAL') + doc2 = nlp2('Do you like horses?') + for ent in doc2.ents: + print(ent.label_, ent.text) if __name__ == '__main__':