Tidy up example and only save/test if output_directory is not None

This commit is contained in:
ines 2017-04-16 16:55:01 +02:00
parent d10bd0eaf9
commit c7adca58a9
1 changed files with 14 additions and 17 deletions

View File

@ -1,22 +1,16 @@
from __future__ import unicode_literals, print_function
import json
import pathlib
import random
from pathlib import Path
import spacy
from spacy.pipeline import EntityRecognizer
from spacy.gold import GoldParse
from spacy.tagger import Tagger
try:
unicode
except:
unicode = str
def train_ner(nlp, train_data, output_dir):
# Add new words to vocab.
# Add new words to vocab
for raw_text, _ in train_data:
doc = nlp.make_doc(raw_text)
for word in doc:
@ -30,11 +24,14 @@ def train_ner(nlp, train_data, output_dir):
nlp.tagger(doc)
loss = nlp.entity.update(doc, gold)
nlp.end_training()
nlp.save_to_directory(output_dir)
if output_dir:
nlp.save_to_directory(output_dir)
def main(model_name, output_directory=None):
nlp = spacy.load(model_name)
if output_directory is not None:
output_directory = Path(output_directory)
train_data = [
(
@ -55,18 +52,18 @@ def main(model_name, output_directory=None):
)
]
nlp.entity.add_label('ANIMAL')
if output_directory is not None:
output_directory = pathlib.Path(output_directory)
ner = train_ner(nlp, train_data, output_directory)
# Test that the entity is recognized
doc = nlp('Do you like horses?')
for ent in doc.ents:
print(ent.label_, ent.text)
nlp2 = spacy.load('en', path=output_directory)
nlp2.entity.add_label('ANIMAL')
doc2 = nlp2('Do you like horses?')
for ent in doc2.ents:
print(ent.label_, ent.text)
if output_directory:
nlp2 = spacy.load('en', path=output_directory)
nlp2.entity.add_label('ANIMAL')
doc2 = nlp2('Do you like horses?')
for ent in doc2.ents:
print(ent.label_, ent.text)
if __name__ == '__main__':