mirror of https://github.com/explosion/spaCy.git
* Update train script, moving lots of functionality to the new GoldParse class
This commit is contained in:
parent
f0159ab4b6
commit
34215de61b
|
@ -61,13 +61,8 @@ def read_docparse_gold(file_):
|
||||||
tags = []
|
tags = []
|
||||||
ids = []
|
ids = []
|
||||||
lines = sent_str.strip().split('\n')
|
lines = sent_str.strip().split('\n')
|
||||||
<<<<<<< HEAD
|
|
||||||
raw_text = lines.pop(0).strip()
|
raw_text = lines.pop(0).strip()
|
||||||
tok_text = lines.pop(0).strip()
|
tok_text = lines.pop(0).strip()
|
||||||
=======
|
|
||||||
raw_text = lines.pop(0)
|
|
||||||
tok_text = lines.pop(0)
|
|
||||||
>>>>>>> master
|
|
||||||
for i, line in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
id_, word, pos_string, head_idx, label = _parse_line(line)
|
id_, word, pos_string, head_idx, label = _parse_line(line)
|
||||||
if label == 'root':
|
if label == 'root':
|
||||||
|
@ -200,9 +195,9 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
||||||
pos_model_dir)
|
pos_model_dir)
|
||||||
|
|
||||||
left_labels, right_labels = get_labels(paragraphs)
|
labels = Language.ParserTransitionSystem.get_labels(gold_sents)
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
left_labels=left_labels, right_labels=right_labels)
|
labels=labels)
|
||||||
|
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
||||||
|
@ -210,14 +205,12 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
heads_corr = 0
|
heads_corr = 0
|
||||||
pos_corr = 0
|
pos_corr = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
for gold_sent in gold_sents:
|
||||||
gold_preproc=gold_preproc):
|
tokens = nlp.tokenizer(gold_sent.raw)
|
||||||
|
gold_sent.align_to_tokens(tokens)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
try:
|
heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
|
||||||
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=force_gold)
|
pos_corr += nlp.tagger.train(tokens, gold_parse.tags)
|
||||||
except OracleError:
|
|
||||||
continue
|
|
||||||
pos_corr += nlp.tagger.train(tokens, tag_strs)
|
|
||||||
n_tokens += len(tokens)
|
n_tokens += len(tokens)
|
||||||
acc = float(heads_corr) / n_tokens
|
acc = float(heads_corr) / n_tokens
|
||||||
pos_acc = float(pos_corr) / n_tokens
|
pos_acc = float(pos_corr) / n_tokens
|
||||||
|
@ -265,10 +258,9 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
train(English, read_docparse_gold(train_loc), model_dir,
|
||||||
train_sents = read_docparse_gold(file_)
|
gold_preproc=False, force_gold=False)
|
||||||
train(English, train_sents, model_dir, gold_preproc=False, force_gold=False)
|
print evaluate(English, read_docparse_gold(dev_loc), model_dir, gold_preproc=False)
|
||||||
print evaluate(English, dev_loc, model_dir, gold_preproc=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue