mirror of https://github.com/explosion/spaCy.git
* Report LAS in train script
This commit is contained in:
parent
b07632a9ef
commit
053814ffc8
|
@ -226,7 +226,8 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
global loss
|
global loss
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
n_corr = 0
|
uas_corr = 0
|
||||||
|
las_corr = 0
|
||||||
pos_corr = 0
|
pos_corr = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
@ -251,11 +252,14 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
continue
|
continue
|
||||||
if is_punct_label(labels[i]):
|
if is_punct_label(labels[i]):
|
||||||
continue
|
continue
|
||||||
n_corr += token.head.i == heads[i]
|
uas_corr += token.head.i == heads[i]
|
||||||
|
las_corr += token.head.i == heads[i] and token.dep_ == labels[i]
|
||||||
|
#print token.orth_, token.head.orth_, token.dep_, labels[i]
|
||||||
total += 1
|
total += 1
|
||||||
print loss, skipped, (loss+skipped + total)
|
print loss, skipped, (loss+skipped + total)
|
||||||
print pos_corr / n_tokens
|
print pos_corr / n_tokens
|
||||||
return float(n_corr) / (total + loss)
|
print float(las_corr) / (total + loss)
|
||||||
|
return float(uas_corr) / (total + loss)
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
|
|
Loading…
Reference in New Issue