Improve printing in train_ud script

This commit is contained in:
Matthew Honnibal 2017-03-11 11:11:05 -06:00
parent ca9c8c57c0
commit a155482fda
1 changed files with 12 additions and 7 deletions

View File

@ -24,9 +24,10 @@ import io
def read_conllx(loc):
def read_conllx(loc, n=0):
with io.open(loc, 'r', encoding='utf8') as file_:
text = file_.read()
i = 0
for sent in text.strip().split('\n\n'):
lines = sent.strip().split('\n')
if lines:
@ -47,6 +48,9 @@ def read_conllx(loc):
raise
tuples = [list(t) for t in zip(*tokens)]
yield (None, [[tuples, []]])
i += 1
if n >= 1 and i >= n:
break
def score_model(vocab, tagger, parser, gold_docs, verbose=False):
@ -73,7 +77,7 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
actions = ArcEager.get_actions(gold_parses=train_sents)
features = get_templates('basic')
model_dir = pathlib.Path(model_dir)
if not (model_dir / 'deps').exists():
(model_dir / 'deps').mkdir()
@ -95,25 +99,26 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
for tag in tags:
assert tag in tag_map, repr(tag)
tagger = Tagger(vocab, tag_map=tag_map)
parser = DependencyParser(vocab, actions=actions, features=features)
parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
for itn in range(15):
loss = 0.
for _, doc_sents in train_sents:
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
doc = Doc(vocab, words=words)
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
tagger(doc)
parser.update(doc, gold)
loss += parser.update(doc, gold, itn=itn)
doc = Doc(vocab, words=words)
tagger.update(doc, gold)
random.shuffle(train_sents)
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
nlp.end_training(model_dir)
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
if __name__ == '__main__':
plac.call(main)