Refactor Language.end_training, making new save_to_directory method

This commit is contained in:
Matthew Honnibal 2017-04-14 23:51:24 +02:00
parent 49e2de900e
commit 33ba5066eb
1 changed file with 50 additions and 75 deletions

View File

@ -209,46 +209,34 @@ class Language(object):
lang = None lang = None
@classmethod @classmethod
@contextmanager def setup_directory(cls, path, **configs):
def train(cls, path, gold_tuples, *configs): for name, config in configs.items():
if isinstance(path, basestring): directory = path / name
path = pathlib.Path(path) if directory.exists():
tagger_cfg, parser_cfg, entity_cfg = configs shutil.rmtree(str(directory))
dep_model_dir = path / 'deps' directory.mkdir()
ner_model_dir = path / 'ner' with (directory / 'config.json').open('wb') as file_:
pos_model_dir = path / 'pos' data = ujson.dumps(config, indent=2)
if dep_model_dir.exists(): if isinstance(data, unicode):
shutil.rmtree(str(dep_model_dir)) data = data.encode('utf8')
if ner_model_dir.exists(): file_.write(data)
shutil.rmtree(str(ner_model_dir)) if not (path / 'vocab').exists():
if pos_model_dir.exists(): (path / 'vocab').mkdir()
shutil.rmtree(str(pos_model_dir))
dep_model_dir.mkdir()
ner_model_dir.mkdir()
pos_model_dir.mkdir()
@classmethod
@contextmanager
def train(cls, path, gold_tuples, **configs):
if parser_cfg['pseudoprojective']: if parser_cfg['pseudoprojective']:
# preprocess training data here before ArcEager.get_labels() is called # preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples) gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) for subdir in ('deps', 'ner', 'pos'):
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) if subdir not in configs:
configs[subdir] = {}
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
with (dep_model_dir / 'config.json').open('wb') as file_: cls.setup_directory(path, **configs)
data = ujson.dumps(parser_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
with (ner_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(entity_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
with (pos_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(tagger_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
self = cls( self = cls(
path=path, path=path,
@ -269,7 +257,9 @@ class Language(object):
self.entity = self.Defaults.create_entity(self) self.entity = self.Defaults.create_entity(self)
self.pipeline = self.Defaults.create_pipeline(self) self.pipeline = self.Defaults.create_pipeline(self)
yield Trainer(self, gold_tuples) yield Trainer(self, gold_tuples)
self.end_training(path=path) self.end_training()
self.save_to_directory(path, deps=self.parser.cfg, ner=self.entity.cfg,
pos=self.tagger.cfg)
def __init__(self, **overrides): def __init__(self, **overrides):
if 'data_dir' in overrides and 'path' not in overrides: if 'data_dir' in overrides and 'path' not in overrides:
@ -373,51 +363,36 @@ class Language(object):
for doc in stream: for doc in stream:
yield doc yield doc
def end_training(self, path=None): def save_to_directory(self, path):
if path is None: configs = {
path = self.path 'pos': self.tagger.cfg if self.tagger else {},
elif isinstance(path, basestring): 'deps': self.parser.cfg if self.parser else {},
path = pathlib.Path(path) 'ner': self.entity.cfg if self.entity else {},
}
if self.tagger: self.setup_directory(path, **configs)
self.tagger.model.end_training()
self.tagger.model.dump(str(path / 'pos' / 'model'))
if self.parser:
self.parser.model.end_training()
self.parser.model.dump(str(path / 'deps' / 'model'))
if self.entity:
self.entity.model.end_training()
self.entity.model.dump(str(path / 'ner' / 'model'))
strings_loc = path / 'vocab' / 'strings.json' strings_loc = path / 'vocab' / 'strings.json'
with strings_loc.open('w', encoding='utf8') as file_: with strings_loc.open('w', encoding='utf8') as file_:
self.vocab.strings.dump(file_) self.vocab.strings.dump(file_)
self.vocab.dump(path / 'vocab' / 'lexemes.bin') self.vocab.dump(path / 'vocab' / 'lexemes.bin')
# TODO: Word vectors?
if self.tagger: if self.tagger:
tagger_freqs = list(self.tagger.freqs[TAG].items()) self.tagger.model.dump(str(path / 'pos' / 'model'))
else:
tagger_freqs = []
if self.parser: if self.parser:
dep_freqs = list(self.parser.moves.freqs[DEP].items()) self.parser.model.dump(str(path / 'deps' / 'model'))
head_freqs = list(self.parser.moves.freqs[HEAD].items())
else:
dep_freqs = []
head_freqs = []
if self.entity: if self.entity:
entity_iob_freqs = list(self.entity.moves.freqs[ENT_IOB].items()) self.entity.model.dump(str(path / 'ner' / 'model'))
entity_type_freqs = list(self.entity.moves.freqs[ENT_TYPE].items())
else: def end_training(self, path=None):
entity_iob_freqs = [] if self.tagger:
entity_type_freqs = [] self.tagger.model.end_training()
with (path / 'vocab' / 'serializer.json').open('wb') as file_: if self.parser:
data = ujson.dumps([ self.parser.model.end_training()
(TAG, tagger_freqs), if self.entity:
(DEP, dep_freqs), self.entity.model.end_training()
(ENT_IOB, entity_iob_freqs), # NB: This is slightly different from before --- we no longer default
(ENT_TYPE, entity_type_freqs), # to taking nlp.path
(HEAD, head_freqs) if path is not None:
]) self.save_to_directory(path)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)