From a88a7deffed5fcfd95bcbd10efd35c0a95c01a30 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 23 Jul 2017 00:33:43 +0200 Subject: [PATCH] Five save/load of textcat config --- spacy/pipeline.pyx | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index e7e6dcdfc..8c3ecacd6 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -109,7 +109,8 @@ class BaseThincComponent(object): def to_disk(self, path, **exclude): serialize = OrderedDict(( ('model', lambda p: p.open('wb').write(self.model.to_bytes())), - ('vocab', lambda p: self.vocab.to_disk(p)) + ('vocab', lambda p: self.vocab.to_disk(p)), + ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))) )) util.to_disk(path, serialize, exclude) @@ -118,7 +119,8 @@ class BaseThincComponent(object): self.model = self.Model() deserialize = OrderedDict(( ('model', lambda p: self.model.from_bytes(p.open('rb').read())), - ('vocab', lambda p: self.vocab.from_disk(p)) + ('vocab', lambda p: self.vocab.from_disk(p)), + ('cfg', lambda p: self.cfg.update(ujson.load(p.open()))), )) util.from_disk(path, deserialize, exclude) return self @@ -383,6 +385,7 @@ class NeuralTagger(BaseThincComponent): use_bin_type=True, encoding='utf8'))), ('model', lambda p: p.open('wb').write(self.model.to_bytes())), + ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))) )) util.to_disk(path, serialize, exclude) @@ -405,6 +408,7 @@ class NeuralTagger(BaseThincComponent): ('vocab', lambda p: self.vocab.from_disk(p)), ('tag_map', load_tag_map), ('model', load_model), + ('cfg', lambda p: self.cfg.update(ujson.load(p.open()))), )) util.from_disk(path, deserialize, exclude) return self @@ -523,7 +527,15 @@ class TextCategorizer(BaseThincComponent): def __init__(self, vocab, model=True, **cfg): self.vocab = vocab self.model = model - self.labels = cfg.get('labels', ['LABEL']) + self.cfg = cfg + + @property + def labels(self): + return self.cfg.get('labels', ['LABEL']) + + @labels.setter + def labels(self, value): + self.cfg['labels'] = value def __call__(self, doc): scores = self.predict([doc])