Update train_ud for Universal Dependencies 2

Matthew Honnibal 2017-03-16 17:08:15 -05:00
parent 890747d8ff
commit ef6bd08e6c
1 changed file with 34 additions and 14 deletions

@@ -14,7 +14,7 @@ from spacy.language import Language
 from spacy.gold import GoldParse
 from spacy.vocab import Vocab
 from spacy.tagger import Tagger
-from spacy.pipeline import DependencyParser
+from spacy.pipeline import DependencyParser, BeamDependencyParser
 from spacy.syntax.parser import get_templates
 from spacy.syntax.arc_eager import ArcEager
 from spacy.scorer import Scorer
@@ -35,8 +35,8 @@ def read_conllx(loc, n=0):
                 lines.pop(0)
             tokens = []
             for line in lines:
-                id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
-                if '-' in id_:
+                id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split()
+                if '-' in id_ or '.' in id_:
                     continue
                 try:
                     id_ = int(id_) - 1
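
Note on the read_conllx change: UD 2 treebanks ship in CoNLL-U format, where column 4 is the universal POS tag and column 5 the language-specific tag (hence the swapped pos/tag unpacking), IDs containing '-' are multi-word token ranges, and IDs containing '.' are empty nodes from the enhanced dependencies; neither of the latter carries a head to train on. A standalone sketch of the same filtering (the sample lines are invented, not taken from this commit):

# Standalone sketch of the new CoNLL-U handling; sample lines are invented.
# CoNLL-U columns: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC
sample = [
    "1-2\tvamonos\t_\t_\t_\t_\t_\t_\t_\t_",        # multi-word token range: skipped
    "1\tvamos\tir\tVERB\tVERB\t_\t0\troot\t_\t_",
    "2\tnos\tnosotros\tPRON\tPRON\t_\t1\tobj\t_\t_",
    "2.1\tnos\t_\tPRON\tPRON\t_\t_\t_\t1:obj\t_",  # empty node: skipped
]
for line in sample:
    id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split('\t')
    if '-' in id_ or '.' in id_:
        continue
    print(id_, word, pos, head, dep)   # only ids 1 and 2 survive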
@@ -66,12 +66,8 @@ def score_model(vocab, tagger, parser, gold_docs, verbose=False):
     return scorer


-def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
-    if tag_map_loc:
-        with open(tag_map_loc) as file_:
-            tag_map = json.loads(file_.read())
-    else:
-        tag_map = DEFAULT_TAG_MAP
+def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
+    LangClass = spacy.util.get_lang_class(lang_name)
     train_sents = list(read_conllx(train_loc))
     train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
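
The entry point now takes a language code instead of a tag-map JSON path: the tag map and lexical attributes come from the language subclass returned by spacy.util.get_lang_class, and the new optional argument points at a Brown-cluster file (used in the next hunk). A hypothetical invocation, with made-up paths and language code that do not appear in this commit:

# Hypothetical call of the updated entry point; values are examples only.
main('es',
     'UD_Spanish/es-ud-train.conllu',       # train_loc
     'UD_Spanish/es-ud-dev.conllu',         # dev_loc
     '/tmp/es_ud_model',                    # model_dir
     clusters_loc='es_brown_clusters.txt')  # optional cluster file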
@@ -79,13 +75,37 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
     features = get_templates('basic')

     model_dir = pathlib.Path(model_dir)
+    if not model_dir.exists():
+        model_dir.mkdir()
     if not (model_dir / 'deps').exists():
         (model_dir / 'deps').mkdir()
+    if not (model_dir / 'pos').exists():
+        (model_dir / 'pos').mkdir()
     with (model_dir / 'deps' / 'config.json').open('wb') as file_:
         file_.write(
             json.dumps(
                 {'pseudoprojective': True, 'labels': actions, 'features': features}).encode('utf8'))
-    vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
+    vocab = LangClass.Defaults.create_vocab()
+    if not (model_dir / 'vocab').exists():
+        (model_dir / 'vocab').mkdir()
+    else:
+        if (model_dir / 'vocab' / 'strings.json').exists():
+            with (model_dir / 'vocab' / 'strings.json').open() as file_:
+                vocab.strings.load(file_)
+            if (model_dir / 'vocab' / 'lexemes.bin').exists():
+                vocab.load_lexemes(model_dir / 'vocab' / 'lexemes.bin')
+    if clusters_loc is not None:
+        clusters_loc = pathlib.Path(clusters_loc)
+        with clusters_loc.open() as file_:
+            for line in file_:
+                try:
+                    cluster, word, freq = line.split()
+                except ValueError:
+                    continue
+                lex = vocab[word]
+                lex.cluster = int(cluster[::-1], 2)
     # Populate vocab
     for _, doc_sents in train_sents:
         for (ids, words, tags, heads, deps, ner), _ in doc_sents:
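
The cluster block expects the usual three-column Brown cluster format (bit-string, word, frequency; the frequency is read but unused here) and stores each bit-string reversed and parsed as a binary integer, presumably so that the top of the cluster hierarchy lands in the low-order bits. A quick worked example with an invented bit-string:

# Worked example of the cluster encoding above (values invented).
cluster = '11010'             # Brown cluster bit-string for some word
print(cluster[::-1])          # '01011'
print(int(cluster[::-1], 2))  # 0b01011 == 11, the value stored on lex.cluster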
@@ -95,13 +115,13 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
                 _ = vocab[dep]
             for tag in tags:
                 _ = vocab[tag]
-            if tag_map:
+            if vocab.morphology.tag_map:
                 for tag in tags:
-                    assert tag in tag_map, repr(tag)
-    tagger = Tagger(vocab, tag_map=tag_map)
+                    assert tag in vocab.morphology.tag_map, repr(tag)
+    tagger = Tagger(vocab)
     parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
-    for itn in range(15):
+    for itn in range(30):
         loss = 0.
         for _, doc_sents in train_sents:
             for (ids, words, tags, heads, deps, ner), _ in doc_sents:
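
With the tag map now owned by the vocab's morphology, the assertion guards against training tags the Tagger could never emit, and the training loop is doubled from 15 to 30 passes. A plain-dict illustration of the guard (tag names invented; in the script, vocab.morphology.tag_map plays the role of this dict):

# Plain-dict stand-in for vocab.morphology.tag_map (contents invented).
tag_map = {'NOUN': {'pos': 'NOUN'}, 'VERB': {'pos': 'VERB'}}
tags = ['NOUN', 'VERB', 'NOUN']
if tag_map:
    for tag in tags:
        assert tag in tag_map, repr(tag)   # fail fast on tags missing from the map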