diff --git a/spacy/language.py b/spacy/language.py index 25bfb9e08..43bebd71d 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -209,46 +209,34 @@ class Language(object): lang = None @classmethod - @contextmanager - def train(cls, path, gold_tuples, *configs): - if isinstance(path, basestring): - path = pathlib.Path(path) - tagger_cfg, parser_cfg, entity_cfg = configs - dep_model_dir = path / 'deps' - ner_model_dir = path / 'ner' - pos_model_dir = path / 'pos' - if dep_model_dir.exists(): - shutil.rmtree(str(dep_model_dir)) - if ner_model_dir.exists(): - shutil.rmtree(str(ner_model_dir)) - if pos_model_dir.exists(): - shutil.rmtree(str(pos_model_dir)) - dep_model_dir.mkdir() - ner_model_dir.mkdir() - pos_model_dir.mkdir() + def setup_directory(cls, path, **configs): + for name, config in configs.items(): + directory = path / name + if directory.exists(): + shutil.rmtree(str(directory)) + directory.mkdir() + with (directory / 'config.json').open('wb') as file_: + data = ujson.dumps(config, indent=2) + if isinstance(data, unicode): + data = data.encode('utf8') + file_.write(data) + if not (path / 'vocab').exists(): + (path / 'vocab').mkdir() + @classmethod + @contextmanager + def train(cls, path, gold_tuples, **configs): if parser_cfg['pseudoprojective']: # preprocess training data here before ArcEager.get_labels() is called gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples) - parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) - entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) + for subdir in ('deps', 'ner', 'pos'): + if subdir not in configs: + configs[subdir] = {} + configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) + configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) - with (dep_model_dir / 'config.json').open('wb') as file_: - data = ujson.dumps(parser_cfg) - if isinstance(data, unicode): - data = data.encode('utf8') - file_.write(data) - with (ner_model_dir / 'config.json').open('wb') as file_: - data = ujson.dumps(entity_cfg) - if isinstance(data, unicode): - data = data.encode('utf8') - file_.write(data) - with (pos_model_dir / 'config.json').open('wb') as file_: - data = ujson.dumps(tagger_cfg) - if isinstance(data, unicode): - data = data.encode('utf8') - file_.write(data) + cls.setup_directory(path, **configs) self = cls( path=path, @@ -269,7 +257,9 @@ class Language(object): self.entity = self.Defaults.create_entity(self) self.pipeline = self.Defaults.create_pipeline(self) yield Trainer(self, gold_tuples) - self.end_training(path=path) + self.end_training() + self.save_to_directory(path, deps=self.parser.cfg, ner=self.entity.cfg, + pos=self.tagger.cfg) def __init__(self, **overrides): if 'data_dir' in overrides and 'path' not in overrides: @@ -373,51 +363,36 @@ class Language(object): for doc in stream: yield doc - def end_training(self, path=None): - if path is None: - path = self.path - elif isinstance(path, basestring): - path = pathlib.Path(path) - - if self.tagger: - self.tagger.model.end_training() - self.tagger.model.dump(str(path / 'pos' / 'model')) - if self.parser: - self.parser.model.end_training() - self.parser.model.dump(str(path / 'deps' / 'model')) - if self.entity: - self.entity.model.end_training() - self.entity.model.dump(str(path / 'ner' / 'model')) + def save_to_directory(self, path): + configs = { + 'pos': self.tagger.cfg if self.tagger else {}, + 'deps': self.parser.cfg if self.parser else {}, + 'ner': self.entity.cfg if self.entity else {}, + } + self.setup_directory(path, **configs) + strings_loc = path / 'vocab' / 'strings.json' with strings_loc.open('w', encoding='utf8') as file_: self.vocab.strings.dump(file_) self.vocab.dump(path / 'vocab' / 'lexemes.bin') - + # TODO: Word vectors? if self.tagger: - tagger_freqs = list(self.tagger.freqs[TAG].items()) - else: - tagger_freqs = [] + self.tagger.model.dump(str(path / 'pos' / 'model')) if self.parser: - dep_freqs = list(self.parser.moves.freqs[DEP].items()) - head_freqs = list(self.parser.moves.freqs[HEAD].items()) - else: - dep_freqs = [] - head_freqs = [] + self.parser.model.dump(str(path / 'deps' / 'model')) if self.entity: - entity_iob_freqs = list(self.entity.moves.freqs[ENT_IOB].items()) - entity_type_freqs = list(self.entity.moves.freqs[ENT_TYPE].items()) - else: - entity_iob_freqs = [] - entity_type_freqs = [] - with (path / 'vocab' / 'serializer.json').open('wb') as file_: - data = ujson.dumps([ - (TAG, tagger_freqs), - (DEP, dep_freqs), - (ENT_IOB, entity_iob_freqs), - (ENT_TYPE, entity_type_freqs), - (HEAD, head_freqs) - ]) - if isinstance(data, unicode): - data = data.encode('utf8') - file_.write(data) + self.entity.model.dump(str(path / 'ner' / 'model')) + + def end_training(self, path=None): + if self.tagger: + self.tagger.model.end_training() + if self.parser: + self.parser.model.end_training() + if self.entity: + self.entity.model.end_training() + # NB: This is slightly different from before --- we no longer default + # to taking nlp.path + if path is not None: + self.save_to_directory(path) +