mirror of https://github.com/explosion/spaCy.git
Tidy up example and only save/test if output_directory is not None
This commit is contained in:
parent
d10bd0eaf9
commit
c7adca58a9
|
@ -1,22 +1,16 @@
|
|||
from __future__ import unicode_literals, print_function
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import spacy
|
||||
from spacy.pipeline import EntityRecognizer
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.tagger import Tagger
|
||||
|
||||
|
||||
try:
|
||||
unicode
|
||||
except:
|
||||
unicode = str
|
||||
|
||||
|
||||
def train_ner(nlp, train_data, output_dir):
|
||||
# Add new words to vocab.
|
||||
# Add new words to vocab
|
||||
for raw_text, _ in train_data:
|
||||
doc = nlp.make_doc(raw_text)
|
||||
for word in doc:
|
||||
|
@ -30,11 +24,14 @@ def train_ner(nlp, train_data, output_dir):
|
|||
nlp.tagger(doc)
|
||||
loss = nlp.entity.update(doc, gold)
|
||||
nlp.end_training()
|
||||
nlp.save_to_directory(output_dir)
|
||||
if output_dir:
|
||||
nlp.save_to_directory(output_dir)
|
||||
|
||||
|
||||
def main(model_name, output_directory=None):
|
||||
nlp = spacy.load(model_name)
|
||||
if output_directory is not None:
|
||||
output_directory = Path(output_directory)
|
||||
|
||||
train_data = [
|
||||
(
|
||||
|
@ -55,18 +52,18 @@ def main(model_name, output_directory=None):
|
|||
)
|
||||
]
|
||||
nlp.entity.add_label('ANIMAL')
|
||||
if output_directory is not None:
|
||||
output_directory = pathlib.Path(output_directory)
|
||||
ner = train_ner(nlp, train_data, output_directory)
|
||||
|
||||
# Test that the entity is recognized
|
||||
doc = nlp('Do you like horses?')
|
||||
for ent in doc.ents:
|
||||
print(ent.label_, ent.text)
|
||||
nlp2 = spacy.load('en', path=output_directory)
|
||||
nlp2.entity.add_label('ANIMAL')
|
||||
doc2 = nlp2('Do you like horses?')
|
||||
for ent in doc2.ents:
|
||||
print(ent.label_, ent.text)
|
||||
if output_directory:
|
||||
nlp2 = spacy.load('en', path=output_directory)
|
||||
nlp2.entity.add_label('ANIMAL')
|
||||
doc2 = nlp2('Do you like horses?')
|
||||
for ent in doc2.ents:
|
||||
print(ent.label_, ent.text)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue