mirror of https://github.com/explosion/spaCy.git
Refactor Language.end_training, making new save_to_directory method
This commit is contained in:
parent
49e2de900e
commit
33ba5066eb
|
@ -209,46 +209,34 @@ class Language(object):
|
|||
lang = None
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def train(cls, path, gold_tuples, *configs):
|
||||
if isinstance(path, basestring):
|
||||
path = pathlib.Path(path)
|
||||
tagger_cfg, parser_cfg, entity_cfg = configs
|
||||
dep_model_dir = path / 'deps'
|
||||
ner_model_dir = path / 'ner'
|
||||
pos_model_dir = path / 'pos'
|
||||
if dep_model_dir.exists():
|
||||
shutil.rmtree(str(dep_model_dir))
|
||||
if ner_model_dir.exists():
|
||||
shutil.rmtree(str(ner_model_dir))
|
||||
if pos_model_dir.exists():
|
||||
shutil.rmtree(str(pos_model_dir))
|
||||
dep_model_dir.mkdir()
|
||||
ner_model_dir.mkdir()
|
||||
pos_model_dir.mkdir()
|
||||
def setup_directory(cls, path, **configs):
|
||||
for name, config in configs.items():
|
||||
directory = path / name
|
||||
if directory.exists():
|
||||
shutil.rmtree(str(directory))
|
||||
directory.mkdir()
|
||||
with (directory / 'config.json').open('wb') as file_:
|
||||
data = ujson.dumps(config, indent=2)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
if not (path / 'vocab').exists():
|
||||
(path / 'vocab').mkdir()
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def train(cls, path, gold_tuples, **configs):
|
||||
if parser_cfg['pseudoprojective']:
|
||||
# preprocess training data here before ArcEager.get_labels() is called
|
||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||
|
||||
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||
for subdir in ('deps', 'ner', 'pos'):
|
||||
if subdir not in configs:
|
||||
configs[subdir] = {}
|
||||
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||
|
||||
with (dep_model_dir / 'config.json').open('wb') as file_:
|
||||
data = ujson.dumps(parser_cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
with (ner_model_dir / 'config.json').open('wb') as file_:
|
||||
data = ujson.dumps(entity_cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
with (pos_model_dir / 'config.json').open('wb') as file_:
|
||||
data = ujson.dumps(tagger_cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
cls.setup_directory(path, **configs)
|
||||
|
||||
self = cls(
|
||||
path=path,
|
||||
|
@ -269,7 +257,9 @@ class Language(object):
|
|||
self.entity = self.Defaults.create_entity(self)
|
||||
self.pipeline = self.Defaults.create_pipeline(self)
|
||||
yield Trainer(self, gold_tuples)
|
||||
self.end_training(path=path)
|
||||
self.end_training()
|
||||
self.save_to_directory(path, deps=self.parser.cfg, ner=self.entity.cfg,
|
||||
pos=self.tagger.cfg)
|
||||
|
||||
def __init__(self, **overrides):
|
||||
if 'data_dir' in overrides and 'path' not in overrides:
|
||||
|
@ -373,51 +363,36 @@ class Language(object):
|
|||
for doc in stream:
|
||||
yield doc
|
||||
|
||||
def end_training(self, path=None):
|
||||
if path is None:
|
||||
path = self.path
|
||||
elif isinstance(path, basestring):
|
||||
path = pathlib.Path(path)
|
||||
def save_to_directory(self, path):
|
||||
configs = {
|
||||
'pos': self.tagger.cfg if self.tagger else {},
|
||||
'deps': self.parser.cfg if self.parser else {},
|
||||
'ner': self.entity.cfg if self.entity else {},
|
||||
}
|
||||
|
||||
if self.tagger:
|
||||
self.tagger.model.end_training()
|
||||
self.tagger.model.dump(str(path / 'pos' / 'model'))
|
||||
if self.parser:
|
||||
self.parser.model.end_training()
|
||||
self.parser.model.dump(str(path / 'deps' / 'model'))
|
||||
if self.entity:
|
||||
self.entity.model.end_training()
|
||||
self.entity.model.dump(str(path / 'ner' / 'model'))
|
||||
self.setup_directory(path, **configs)
|
||||
|
||||
strings_loc = path / 'vocab' / 'strings.json'
|
||||
with strings_loc.open('w', encoding='utf8') as file_:
|
||||
self.vocab.strings.dump(file_)
|
||||
self.vocab.dump(path / 'vocab' / 'lexemes.bin')
|
||||
|
||||
# TODO: Word vectors?
|
||||
if self.tagger:
|
||||
tagger_freqs = list(self.tagger.freqs[TAG].items())
|
||||
else:
|
||||
tagger_freqs = []
|
||||
self.tagger.model.dump(str(path / 'pos' / 'model'))
|
||||
if self.parser:
|
||||
dep_freqs = list(self.parser.moves.freqs[DEP].items())
|
||||
head_freqs = list(self.parser.moves.freqs[HEAD].items())
|
||||
else:
|
||||
dep_freqs = []
|
||||
head_freqs = []
|
||||
self.parser.model.dump(str(path / 'deps' / 'model'))
|
||||
if self.entity:
|
||||
entity_iob_freqs = list(self.entity.moves.freqs[ENT_IOB].items())
|
||||
entity_type_freqs = list(self.entity.moves.freqs[ENT_TYPE].items())
|
||||
else:
|
||||
entity_iob_freqs = []
|
||||
entity_type_freqs = []
|
||||
with (path / 'vocab' / 'serializer.json').open('wb') as file_:
|
||||
data = ujson.dumps([
|
||||
(TAG, tagger_freqs),
|
||||
(DEP, dep_freqs),
|
||||
(ENT_IOB, entity_iob_freqs),
|
||||
(ENT_TYPE, entity_type_freqs),
|
||||
(HEAD, head_freqs)
|
||||
])
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
self.entity.model.dump(str(path / 'ner' / 'model'))
|
||||
|
||||
def end_training(self, path=None):
|
||||
if self.tagger:
|
||||
self.tagger.model.end_training()
|
||||
if self.parser:
|
||||
self.parser.model.end_training()
|
||||
if self.entity:
|
||||
self.entity.model.end_training()
|
||||
# NB: This is slightly different from before --- we no longer default
|
||||
# to taking nlp.path
|
||||
if path is not None:
|
||||
self.save_to_directory(path)
|
||||
|
||||
|
|
Loading…
Reference in New Issue