diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py
index af98ef583..cbe2963d3 100644
--- a/examples/training/train_new_entity_type.py
+++ b/examples/training/train_new_entity_type.py
@@ -1,22 +1,16 @@
 from __future__ import unicode_literals, print_function
-import json
-import pathlib
+
 import random
+from pathlib import Path
 
 import spacy
 from spacy.pipeline import EntityRecognizer
 from spacy.gold import GoldParse
 from spacy.tagger import Tagger
 
- 
-try:
-    unicode
-except:
-    unicode = str
-
 
 def train_ner(nlp, train_data, output_dir):
-    # Add new words to vocab.
+    # Add new words to vocab
     for raw_text, _ in train_data:
         doc = nlp.make_doc(raw_text)
         for word in doc:
@@ -30,11 +24,14 @@ def train_ner(nlp, train_data, output_dir):
             nlp.tagger(doc)
             loss = nlp.entity.update(doc, gold)
     nlp.end_training()
-    nlp.save_to_directory(output_dir)
+    if output_dir:
+        nlp.save_to_directory(output_dir)
 
 
 def main(model_name, output_directory=None):
     nlp = spacy.load(model_name)
+    if output_directory is not None:
+        output_directory = Path(output_directory)
 
     train_data = [
         (
@@ -55,18 +52,18 @@ def main(model_name, output_directory=None):
         )
     ]
     nlp.entity.add_label('ANIMAL')
-    if output_directory is not None:
-        output_directory = pathlib.Path(output_directory)
     ner = train_ner(nlp, train_data, output_directory)
 
+    # Test that the entity is recognized
     doc = nlp('Do you like horses?')
     for ent in doc.ents:
         print(ent.label_, ent.text)
-    nlp2 = spacy.load('en', path=output_directory)
-    nlp2.entity.add_label('ANIMAL')
-    doc2 = nlp2('Do you like horses?')
-    for ent in doc2.ents:
-        print(ent.label_, ent.text)
+    if output_directory:
+        nlp2 = spacy.load('en', path=output_directory)
+        nlp2.entity.add_label('ANIMAL')
+        doc2 = nlp2('Do you like horses?')
+        for ent in doc2.ents:
+            print(ent.label_, ent.text)
 
 
 if __name__ == '__main__':