Refactor Language.end_training, making new save_to_directory method

This commit is contained in:
Matthew Honnibal 2017-04-14 23:51:24 +02:00
parent 49e2de900e
commit 33ba5066eb
1 changed files with 50 additions and 75 deletions

View File

@ -209,46 +209,34 @@ class Language(object):
lang = None
@classmethod
@contextmanager
def train(cls, path, gold_tuples, *configs):
if isinstance(path, basestring):
path = pathlib.Path(path)
tagger_cfg, parser_cfg, entity_cfg = configs
dep_model_dir = path / 'deps'
ner_model_dir = path / 'ner'
pos_model_dir = path / 'pos'
if dep_model_dir.exists():
shutil.rmtree(str(dep_model_dir))
if ner_model_dir.exists():
shutil.rmtree(str(ner_model_dir))
if pos_model_dir.exists():
shutil.rmtree(str(pos_model_dir))
dep_model_dir.mkdir()
ner_model_dir.mkdir()
pos_model_dir.mkdir()
def setup_directory(cls, path, **configs):
for name, config in configs.items():
directory = path / name
if directory.exists():
shutil.rmtree(str(directory))
directory.mkdir()
with (directory / 'config.json').open('wb') as file_:
data = ujson.dumps(config, indent=2)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
if not (path / 'vocab').exists():
(path / 'vocab').mkdir()
@classmethod
@contextmanager
def train(cls, path, gold_tuples, **configs):
if parser_cfg['pseudoprojective']:
# preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
for subdir in ('deps', 'ner', 'pos'):
if subdir not in configs:
configs[subdir] = {}
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
with (dep_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(parser_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
with (ner_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(entity_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
with (pos_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(tagger_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
cls.setup_directory(path, **configs)
self = cls(
path=path,
@ -269,7 +257,9 @@ class Language(object):
self.entity = self.Defaults.create_entity(self)
self.pipeline = self.Defaults.create_pipeline(self)
yield Trainer(self, gold_tuples)
self.end_training(path=path)
self.end_training()
self.save_to_directory(path, deps=self.parser.cfg, ner=self.entity.cfg,
pos=self.tagger.cfg)
def __init__(self, **overrides):
if 'data_dir' in overrides and 'path' not in overrides:
@ -373,51 +363,36 @@ class Language(object):
for doc in stream:
yield doc
def end_training(self, path=None):
if path is None:
path = self.path
elif isinstance(path, basestring):
path = pathlib.Path(path)
def save_to_directory(self, path):
configs = {
'pos': self.tagger.cfg if self.tagger else {},
'deps': self.parser.cfg if self.parser else {},
'ner': self.entity.cfg if self.entity else {},
}
if self.tagger:
self.tagger.model.end_training()
self.tagger.model.dump(str(path / 'pos' / 'model'))
if self.parser:
self.parser.model.end_training()
self.parser.model.dump(str(path / 'deps' / 'model'))
if self.entity:
self.entity.model.end_training()
self.entity.model.dump(str(path / 'ner' / 'model'))
self.setup_directory(path, **configs)
strings_loc = path / 'vocab' / 'strings.json'
with strings_loc.open('w', encoding='utf8') as file_:
self.vocab.strings.dump(file_)
self.vocab.dump(path / 'vocab' / 'lexemes.bin')
# TODO: Word vectors?
if self.tagger:
tagger_freqs = list(self.tagger.freqs[TAG].items())
else:
tagger_freqs = []
self.tagger.model.dump(str(path / 'pos' / 'model'))
if self.parser:
dep_freqs = list(self.parser.moves.freqs[DEP].items())
head_freqs = list(self.parser.moves.freqs[HEAD].items())
else:
dep_freqs = []
head_freqs = []
self.parser.model.dump(str(path / 'deps' / 'model'))
if self.entity:
entity_iob_freqs = list(self.entity.moves.freqs[ENT_IOB].items())
entity_type_freqs = list(self.entity.moves.freqs[ENT_TYPE].items())
else:
entity_iob_freqs = []
entity_type_freqs = []
with (path / 'vocab' / 'serializer.json').open('wb') as file_:
data = ujson.dumps([
(TAG, tagger_freqs),
(DEP, dep_freqs),
(ENT_IOB, entity_iob_freqs),
(ENT_TYPE, entity_type_freqs),
(HEAD, head_freqs)
])
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
self.entity.model.dump(str(path / 'ner' / 'model'))
def end_training(self, path=None):
if self.tagger:
self.tagger.model.end_training()
if self.parser:
self.parser.model.end_training()
if self.entity:
self.entity.model.end_training()
# NB: This is slightly different from before --- we no longer default
# to taking nlp.path
if path is not None:
self.save_to_directory(path)