Tidy up example and only save/test if output_directory is not None

This commit is contained in:
ines 2017-04-16 16:55:01 +02:00
parent d10bd0eaf9
commit c7adca58a9
1 changed files with 14 additions and 17 deletions

View File

@ -1,22 +1,16 @@
from __future__ import unicode_literals, print_function from __future__ import unicode_literals, print_function
import json
import pathlib
import random import random
from pathlib import Path
import spacy import spacy
from spacy.pipeline import EntityRecognizer from spacy.pipeline import EntityRecognizer
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.tagger import Tagger from spacy.tagger import Tagger
try:
unicode
except:
unicode = str
def train_ner(nlp, train_data, output_dir): def train_ner(nlp, train_data, output_dir):
# Add new words to vocab. # Add new words to vocab
for raw_text, _ in train_data: for raw_text, _ in train_data:
doc = nlp.make_doc(raw_text) doc = nlp.make_doc(raw_text)
for word in doc: for word in doc:
@ -30,11 +24,14 @@ def train_ner(nlp, train_data, output_dir):
nlp.tagger(doc) nlp.tagger(doc)
loss = nlp.entity.update(doc, gold) loss = nlp.entity.update(doc, gold)
nlp.end_training() nlp.end_training()
nlp.save_to_directory(output_dir) if output_dir:
nlp.save_to_directory(output_dir)
def main(model_name, output_directory=None): def main(model_name, output_directory=None):
nlp = spacy.load(model_name) nlp = spacy.load(model_name)
if output_directory is not None:
output_directory = Path(output_directory)
train_data = [ train_data = [
( (
@ -55,18 +52,18 @@ def main(model_name, output_directory=None):
) )
] ]
nlp.entity.add_label('ANIMAL') nlp.entity.add_label('ANIMAL')
if output_directory is not None:
output_directory = pathlib.Path(output_directory)
ner = train_ner(nlp, train_data, output_directory) ner = train_ner(nlp, train_data, output_directory)
# Test that the entity is recognized
doc = nlp('Do you like horses?') doc = nlp('Do you like horses?')
for ent in doc.ents: for ent in doc.ents:
print(ent.label_, ent.text) print(ent.label_, ent.text)
nlp2 = spacy.load('en', path=output_directory) if output_directory:
nlp2.entity.add_label('ANIMAL') nlp2 = spacy.load('en', path=output_directory)
doc2 = nlp2('Do you like horses?') nlp2.entity.add_label('ANIMAL')
for ent in doc2.ents: doc2 = nlp2('Do you like horses?')
print(ent.label_, ent.text) for ent in doc2.ents:
print(ent.label_, ent.text)
if __name__ == '__main__': if __name__ == '__main__':