* Move spacy.syntax.conll to spacy.gold

This commit is contained in:
Matthew Honnibal 2015-05-24 21:35:02 +02:00
parent 765b61cac4
commit fc75210941
8 changed files with 26 additions and 16 deletions

View File

@ -20,8 +20,8 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
from spacy.syntax.parser import GreedyParser from spacy.syntax.parser import GreedyParser
from spacy.syntax.parser import OracleError from spacy.syntax.parser import OracleError
from spacy.syntax.util import Config from spacy.syntax.util import Config
from spacy.syntax.conll import read_json_file from spacy.gold import read_json_file
from spacy.syntax.conll import GoldParse from spacy.gold import GoldParse
from spacy.scorer import Scorer from spacy.scorer import Scorer
@ -65,11 +65,13 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
gold_tuples = gold_tuples[:n_sents] gold_tuples = gold_tuples[:n_sents]
nlp = Language(data_dir=model_dir) nlp = Language(data_dir=model_dir)
print "Itn.\tUAS\tNER F.\tTag %\tToken %" print "Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %"
for itn in range(n_iter): for itn in range(n_iter):
scorer = Scorer() scorer = Scorer()
loss = 0
for raw_text, annot_tuples, ctnt in gold_tuples: for raw_text, annot_tuples, ctnt in gold_tuples:
raw_text = ''.join(add_noise(c, corruption_level) for c in raw_text) if corruption_level != 0:
raw_text = ''.join(add_noise(c, corruption_level) for c in raw_text)
tokens = nlp(raw_text, merge_mwes=False) tokens = nlp(raw_text, merge_mwes=False)
gold = GoldParse(tokens, annot_tuples) gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=False) scorer.score(tokens, gold, verbose=False)
@ -79,7 +81,7 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
gold = GoldParse(tokens, annot_tuples) gold = GoldParse(tokens, annot_tuples)
nlp.tagger(tokens) nlp.tagger(tokens)
try: try:
nlp.parser.train(tokens, gold) loss += nlp.parser.train(tokens, gold)
except AssertionError: except AssertionError:
# TODO: Do something about non-projective sentences # TODO: Do something about non-projective sentences
continue continue
@ -87,7 +89,7 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
nlp.entity.train(tokens, gold) nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags) nlp.tagger.train(tokens, gold.tags)
print '%d:\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, print '%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
scorer.tags_acc, scorer.tags_acc,
scorer.token_acc) scorer.token_acc)
random.shuffle(gold_tuples) random.shuffle(gold_tuples)
@ -148,15 +150,16 @@ def get_sents(json_loc):
model_dir=("Location of output model directory",), model_dir=("Location of output model directory",),
out_loc=("Out location", "option", "o", str), out_loc=("Out location", "option", "o", str),
n_sents=("Number of training sentences", "option", "n", int), n_sents=("Number of training sentences", "option", "n", int),
n_iter=("Number of training iterations", "option", "i", int),
verbose=("Verbose error reporting", "flag", "v", bool), verbose=("Verbose error reporting", "flag", "v", bool),
debug=("Debug mode", "flag", "d", bool) debug=("Debug mode", "flag", "d", bool)
) )
def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False, def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
debug=False, corruption_level=0.0): debug=False, corruption_level=0.0):
train(English, read_json_file(train_loc), model_dir, train(English, read_json_file(train_loc), model_dir,
feat_set='basic' if not debug else 'debug', feat_set='basic' if not debug else 'debug',
gold_preproc=False, n_sents=n_sents, gold_preproc=False, n_sents=n_sents,
corruption_level=corruption_level) corruption_level=corruption_level, n_iter=n_iter)
if out_loc: if out_loc:
write_parses(English, dev_loc, model_dir, out_loc) write_parses(English, dev_loc, model_dir, out_loc)
scorer = evaluate(English, read_json_file(dev_loc), scorer = evaluate(English, read_json_file(dev_loc),

View File

@ -152,7 +152,7 @@ MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
'spacy.en.pos', 'spacy.syntax.parser', 'spacy.syntax._state', 'spacy.en.pos', 'spacy.syntax.parser', 'spacy.syntax._state',
'spacy.syntax.transition_system', 'spacy.syntax.transition_system',
'spacy.syntax.arc_eager', 'spacy.syntax._parse_features', 'spacy.syntax.arc_eager', 'spacy.syntax._parse_features',
'spacy.syntax.conll', 'spacy.orth', 'spacy.gold', 'spacy.orth',
'spacy.syntax.ner'] 'spacy.syntax.ner']

View File

@ -1,7 +1,7 @@
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ..structs cimport TokenC from .structs cimport TokenC
from .transition_system cimport Transition from .syntax.transition_system cimport Transition
cimport numpy cimport numpy

View File

@ -2,7 +2,7 @@ import numpy
import codecs import codecs
import json import json
import random import random
from spacy.munge.alignment import align from .munge.alignment import align
from libc.string cimport memset from libc.string cimport memset

View File

@ -10,7 +10,7 @@ from ._state cimport count_left_kids
from ..structs cimport TokenC from ..structs cimport TokenC
from .transition_system cimport do_func_t, get_cost_func_t from .transition_system cimport do_func_t, get_cost_func_t
from .conll cimport GoldParse from ..gold cimport GoldParse
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True

View File

@ -8,7 +8,7 @@ from .transition_system cimport do_func_t
from ..structs cimport TokenC, Entity from ..structs cimport TokenC, Entity
from thinc.typedefs cimport weight_t from thinc.typedefs cimport weight_t
from .conll cimport GoldParse from ..gold cimport GoldParse
cdef enum: cdef enum:

View File

@ -30,7 +30,7 @@ from .arc_eager cimport TransitionSystem, Transition
from .transition_system import OracleError from .transition_system import OracleError
from ._state cimport new_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1 from ._state cimport new_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
from .conll cimport GoldParse from ..gold cimport GoldParse
from . import _parse_features from . import _parse_features
from ._parse_features cimport fill_context, CONTEXT_SIZE from ._parse_features cimport fill_context, CONTEXT_SIZE
@ -107,14 +107,21 @@ cdef class GreedyParser:
cdef Transition guess cdef Transition guess
cdef Transition best cdef Transition best
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
loss = 0
while not is_final(state): while not is_final(state):
fill_context(context, state) fill_context(context, state)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
best = self.moves.best_gold(scores, state, gold) best = self.moves.best_gold(scores, state, gold)
#print self.moves.move_name(guess.move, guess.label),
#print self.moves.move_name(best.move, best.label),
#print print_state(state, py_words)
cost = guess.get_cost(&guess, state, gold) cost = guess.get_cost(&guess, state, gold)
self.model.update(context, guess.clas, best.clas, cost) self.model.update(context, guess.clas, best.clas, cost)
guess.do(&guess, state) guess.do(&guess, state)
loss += cost
self.moves.finalize_state(state) self.moves.finalize_state(state)
return loss

View File

@ -3,7 +3,7 @@ from thinc.typedefs cimport weight_t
from ..structs cimport TokenC from ..structs cimport TokenC
from ._state cimport State from ._state cimport State
from .conll cimport GoldParse from ..gold cimport GoldParse
from ..strings cimport StringStore from ..strings cimport StringStore