mirror of https://github.com/explosion/spaCy.git
Refactor Language.end_training, making new save_to_directory method
This commit is contained in:
parent
49e2de900e
commit
33ba5066eb
|
@ -209,46 +209,34 @@ class Language(object):
|
||||||
lang = None
|
lang = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
def setup_directory(cls, path, **configs):
|
||||||
def train(cls, path, gold_tuples, *configs):
|
for name, config in configs.items():
|
||||||
if isinstance(path, basestring):
|
directory = path / name
|
||||||
path = pathlib.Path(path)
|
if directory.exists():
|
||||||
tagger_cfg, parser_cfg, entity_cfg = configs
|
shutil.rmtree(str(directory))
|
||||||
dep_model_dir = path / 'deps'
|
directory.mkdir()
|
||||||
ner_model_dir = path / 'ner'
|
with (directory / 'config.json').open('wb') as file_:
|
||||||
pos_model_dir = path / 'pos'
|
data = ujson.dumps(config, indent=2)
|
||||||
if dep_model_dir.exists():
|
if isinstance(data, unicode):
|
||||||
shutil.rmtree(str(dep_model_dir))
|
data = data.encode('utf8')
|
||||||
if ner_model_dir.exists():
|
file_.write(data)
|
||||||
shutil.rmtree(str(ner_model_dir))
|
if not (path / 'vocab').exists():
|
||||||
if pos_model_dir.exists():
|
(path / 'vocab').mkdir()
|
||||||
shutil.rmtree(str(pos_model_dir))
|
|
||||||
dep_model_dir.mkdir()
|
|
||||||
ner_model_dir.mkdir()
|
|
||||||
pos_model_dir.mkdir()
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def train(cls, path, gold_tuples, **configs):
|
||||||
if parser_cfg['pseudoprojective']:
|
if parser_cfg['pseudoprojective']:
|
||||||
# preprocess training data here before ArcEager.get_labels() is called
|
# preprocess training data here before ArcEager.get_labels() is called
|
||||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||||
|
|
||||||
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
for subdir in ('deps', 'ner', 'pos'):
|
||||||
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
if subdir not in configs:
|
||||||
|
configs[subdir] = {}
|
||||||
|
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||||
|
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||||
|
|
||||||
with (dep_model_dir / 'config.json').open('wb') as file_:
|
cls.setup_directory(path, **configs)
|
||||||
data = ujson.dumps(parser_cfg)
|
|
||||||
if isinstance(data, unicode):
|
|
||||||
data = data.encode('utf8')
|
|
||||||
file_.write(data)
|
|
||||||
with (ner_model_dir / 'config.json').open('wb') as file_:
|
|
||||||
data = ujson.dumps(entity_cfg)
|
|
||||||
if isinstance(data, unicode):
|
|
||||||
data = data.encode('utf8')
|
|
||||||
file_.write(data)
|
|
||||||
with (pos_model_dir / 'config.json').open('wb') as file_:
|
|
||||||
data = ujson.dumps(tagger_cfg)
|
|
||||||
if isinstance(data, unicode):
|
|
||||||
data = data.encode('utf8')
|
|
||||||
file_.write(data)
|
|
||||||
|
|
||||||
self = cls(
|
self = cls(
|
||||||
path=path,
|
path=path,
|
||||||
|
@ -269,7 +257,9 @@ class Language(object):
|
||||||
self.entity = self.Defaults.create_entity(self)
|
self.entity = self.Defaults.create_entity(self)
|
||||||
self.pipeline = self.Defaults.create_pipeline(self)
|
self.pipeline = self.Defaults.create_pipeline(self)
|
||||||
yield Trainer(self, gold_tuples)
|
yield Trainer(self, gold_tuples)
|
||||||
self.end_training(path=path)
|
self.end_training()
|
||||||
|
self.save_to_directory(path, deps=self.parser.cfg, ner=self.entity.cfg,
|
||||||
|
pos=self.tagger.cfg)
|
||||||
|
|
||||||
def __init__(self, **overrides):
|
def __init__(self, **overrides):
|
||||||
if 'data_dir' in overrides and 'path' not in overrides:
|
if 'data_dir' in overrides and 'path' not in overrides:
|
||||||
|
@ -373,51 +363,36 @@ class Language(object):
|
||||||
for doc in stream:
|
for doc in stream:
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
def end_training(self, path=None):
|
def save_to_directory(self, path):
|
||||||
if path is None:
|
configs = {
|
||||||
path = self.path
|
'pos': self.tagger.cfg if self.tagger else {},
|
||||||
elif isinstance(path, basestring):
|
'deps': self.parser.cfg if self.parser else {},
|
||||||
path = pathlib.Path(path)
|
'ner': self.entity.cfg if self.entity else {},
|
||||||
|
}
|
||||||
|
|
||||||
if self.tagger:
|
self.setup_directory(path, **configs)
|
||||||
self.tagger.model.end_training()
|
|
||||||
self.tagger.model.dump(str(path / 'pos' / 'model'))
|
|
||||||
if self.parser:
|
|
||||||
self.parser.model.end_training()
|
|
||||||
self.parser.model.dump(str(path / 'deps' / 'model'))
|
|
||||||
if self.entity:
|
|
||||||
self.entity.model.end_training()
|
|
||||||
self.entity.model.dump(str(path / 'ner' / 'model'))
|
|
||||||
|
|
||||||
strings_loc = path / 'vocab' / 'strings.json'
|
strings_loc = path / 'vocab' / 'strings.json'
|
||||||
with strings_loc.open('w', encoding='utf8') as file_:
|
with strings_loc.open('w', encoding='utf8') as file_:
|
||||||
self.vocab.strings.dump(file_)
|
self.vocab.strings.dump(file_)
|
||||||
self.vocab.dump(path / 'vocab' / 'lexemes.bin')
|
self.vocab.dump(path / 'vocab' / 'lexemes.bin')
|
||||||
|
# TODO: Word vectors?
|
||||||
if self.tagger:
|
if self.tagger:
|
||||||
tagger_freqs = list(self.tagger.freqs[TAG].items())
|
self.tagger.model.dump(str(path / 'pos' / 'model'))
|
||||||
else:
|
|
||||||
tagger_freqs = []
|
|
||||||
if self.parser:
|
if self.parser:
|
||||||
dep_freqs = list(self.parser.moves.freqs[DEP].items())
|
self.parser.model.dump(str(path / 'deps' / 'model'))
|
||||||
head_freqs = list(self.parser.moves.freqs[HEAD].items())
|
|
||||||
else:
|
|
||||||
dep_freqs = []
|
|
||||||
head_freqs = []
|
|
||||||
if self.entity:
|
if self.entity:
|
||||||
entity_iob_freqs = list(self.entity.moves.freqs[ENT_IOB].items())
|
self.entity.model.dump(str(path / 'ner' / 'model'))
|
||||||
entity_type_freqs = list(self.entity.moves.freqs[ENT_TYPE].items())
|
|
||||||
else:
|
def end_training(self, path=None):
|
||||||
entity_iob_freqs = []
|
if self.tagger:
|
||||||
entity_type_freqs = []
|
self.tagger.model.end_training()
|
||||||
with (path / 'vocab' / 'serializer.json').open('wb') as file_:
|
if self.parser:
|
||||||
data = ujson.dumps([
|
self.parser.model.end_training()
|
||||||
(TAG, tagger_freqs),
|
if self.entity:
|
||||||
(DEP, dep_freqs),
|
self.entity.model.end_training()
|
||||||
(ENT_IOB, entity_iob_freqs),
|
# NB: This is slightly different from before --- we no longer default
|
||||||
(ENT_TYPE, entity_type_freqs),
|
# to taking nlp.path
|
||||||
(HEAD, head_freqs)
|
if path is not None:
|
||||||
])
|
self.save_to_directory(path)
|
||||||
if isinstance(data, unicode):
|
|
||||||
data = data.encode('utf8')
|
|
||||||
file_.write(data)
|
|
||||||
|
|
Loading…
Reference in New Issue