mirror of https://github.com/explosion/spaCy.git
* Add support for tag dictionary, and fix error-code for predict method
This commit is contained in:
parent
f00afe12c4
commit
3819a88e1b
|
@ -3,6 +3,7 @@ from cymem.cymem cimport Pool
|
||||||
from thinc.learner cimport LinearModel
|
from thinc.learner cimport LinearModel
|
||||||
from thinc.features cimport Extractor
|
from thinc.features cimport Extractor
|
||||||
from thinc.typedefs cimport atom_t, feat_t, weight_t, class_t
|
from thinc.typedefs cimport atom_t, feat_t, weight_t, class_t
|
||||||
|
from preshed.maps cimport PreshMap
|
||||||
|
|
||||||
from .typedefs cimport hash_t
|
from .typedefs cimport hash_t
|
||||||
from .tokens cimport Tokens
|
from .tokens cimport Tokens
|
||||||
|
@ -15,7 +16,7 @@ cpdef enum TagType:
|
||||||
|
|
||||||
cdef class Tagger:
|
cdef class Tagger:
|
||||||
cpdef int set_tags(self, Tokens tokens) except -1
|
cpdef int set_tags(self, Tokens tokens) except -1
|
||||||
cpdef class_t predict(self, int i, Tokens tokens, object golds=*) except 0
|
cpdef class_t predict(self, int i, Tokens tokens, object golds=*) except *
|
||||||
|
|
||||||
cpdef readonly Pool mem
|
cpdef readonly Pool mem
|
||||||
cpdef readonly Extractor extractor
|
cpdef readonly Extractor extractor
|
||||||
|
@ -23,3 +24,4 @@ cdef class Tagger:
|
||||||
|
|
||||||
cpdef readonly TagType tag_type
|
cpdef readonly TagType tag_type
|
||||||
cpdef readonly list tag_names
|
cpdef readonly list tag_names
|
||||||
|
cdef dict tagdict
|
||||||
|
|
|
@ -18,7 +18,7 @@ from thinc.features cimport Feature, count_feats
|
||||||
NULL_TAG = 0
|
NULL_TAG = 0
|
||||||
|
|
||||||
|
|
||||||
def setup_model_dir(tag_type, tag_names, templates, model_dir):
|
def setup_model_dir(tag_type, tag_names, tag_counts, templates, model_dir):
|
||||||
if path.exists(model_dir):
|
if path.exists(model_dir):
|
||||||
shutil.rmtree(model_dir)
|
shutil.rmtree(model_dir)
|
||||||
os.mkdir(model_dir)
|
os.mkdir(model_dir)
|
||||||
|
@ -26,6 +26,7 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir):
|
||||||
'tag_type': tag_type,
|
'tag_type': tag_type,
|
||||||
'templates': templates,
|
'templates': templates,
|
||||||
'tag_names': tag_names,
|
'tag_names': tag_names,
|
||||||
|
'tag_counts': tag_counts,
|
||||||
}
|
}
|
||||||
with open(path.join(model_dir, 'config.json'), 'w') as file_:
|
with open(path.join(model_dir, 'config.json'), 'w') as file_:
|
||||||
json.dump(config, file_)
|
json.dump(config, file_)
|
||||||
|
@ -35,24 +36,19 @@ def train(train_sents, model_dir, nr_iter=10):
|
||||||
cdef Tokens tokens
|
cdef Tokens tokens
|
||||||
cdef Tagger tagger = Tagger(model_dir)
|
cdef Tagger tagger = Tagger(model_dir)
|
||||||
cdef int i
|
cdef int i
|
||||||
|
cdef class_t guess = 0
|
||||||
|
cdef class_t gold
|
||||||
for _ in range(nr_iter):
|
for _ in range(nr_iter):
|
||||||
n_corr = 0
|
n_corr = 0
|
||||||
total = 0
|
total = 0
|
||||||
for tokens, golds in train_sents:
|
for tokens, golds in train_sents:
|
||||||
assert len(tokens) == len(golds), [t.string for t in tokens]
|
assert len(tokens) == len(golds), [t.string for t in tokens]
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
if tagger.tag_type == POS:
|
gold = golds[i]
|
||||||
gold = _get_gold_pos(i, golds)
|
guess = tagger.predict(i, tokens, [gold])
|
||||||
else:
|
|
||||||
raise StandardError
|
|
||||||
|
|
||||||
guess = tagger.predict(i, tokens)
|
|
||||||
tokens.set_tag(i, tagger.tag_type, guess)
|
tokens.set_tag(i, tagger.tag_type, guess)
|
||||||
if gold is not None:
|
total += 1
|
||||||
tagger.tell_answer(gold)
|
n_corr += guess == gold
|
||||||
total += 1
|
|
||||||
n_corr += guess in gold
|
|
||||||
#print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
|
|
||||||
print('%.4f' % ((n_corr / total) * 100))
|
print('%.4f' % ((n_corr / total) * 100))
|
||||||
random.shuffle(train_sents)
|
random.shuffle(train_sents)
|
||||||
tagger.model.end_training()
|
tagger.model.end_training()
|
||||||
|
@ -96,8 +92,9 @@ cdef class Tagger:
|
||||||
templates = cfg['templates']
|
templates = cfg['templates']
|
||||||
self.tag_names = cfg['tag_names']
|
self.tag_names = cfg['tag_names']
|
||||||
self.tag_type = cfg['tag_type']
|
self.tag_type = cfg['tag_type']
|
||||||
|
self.tagdict = _make_tag_dict(cfg['tag_counts'])
|
||||||
self.extractor = Extractor(templates)
|
self.extractor = Extractor(templates)
|
||||||
self.model = LinearModel(len(self.tag_names))
|
self.model = LinearModel(len(self.tag_names), self.extractor.n_templ+2)
|
||||||
if path.exists(path.join(model_dir, 'model')):
|
if path.exists(path.join(model_dir, 'model')):
|
||||||
self.model.load(path.join(model_dir, 'model'))
|
self.model.load(path.join(model_dir, 'model'))
|
||||||
|
|
||||||
|
@ -113,7 +110,7 @@ cdef class Tagger:
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
tokens.set_tag(i, self.tag_type, self.predict(i, tokens))
|
tokens.set_tag(i, self.tag_type, self.predict(i, tokens))
|
||||||
|
|
||||||
cpdef class_t predict(self, int i, Tokens tokens, object golds=None) except 0:
|
cpdef class_t predict(self, int i, Tokens tokens, object golds=None) except *:
|
||||||
"""Predict the tag of tokens[i]. The tagger remembers the features and
|
"""Predict the tag of tokens[i]. The tagger remembers the features and
|
||||||
prediction, in case you later call tell_answer.
|
prediction, in case you later call tell_answer.
|
||||||
|
|
||||||
|
@ -121,16 +118,18 @@ cdef class Tagger:
|
||||||
>>> tag = EN.pos_tagger.predict(0, tokens)
|
>>> tag = EN.pos_tagger.predict(0, tokens)
|
||||||
>>> assert tag == EN.pos_tagger.tag_id('DT') == 5
|
>>> assert tag == EN.pos_tagger.tag_id('DT') == 5
|
||||||
"""
|
"""
|
||||||
cdef int n_feats
|
cdef atom_t sic = tokens.data[i].lex.sic
|
||||||
|
if sic in self.tagdict:
|
||||||
|
return self.tagdict[sic]
|
||||||
cdef atom_t[N_FIELDS] context
|
cdef atom_t[N_FIELDS] context
|
||||||
print sizeof(context)
|
|
||||||
fill_context(context, i, tokens.data)
|
fill_context(context, i, tokens.data)
|
||||||
|
cdef int n_feats
|
||||||
cdef Feature* feats = self.extractor.get_feats(context, &n_feats)
|
cdef Feature* feats = self.extractor.get_feats(context, &n_feats)
|
||||||
cdef weight_t* scores = self.model.get_scores(feats, n_feats)
|
cdef weight_t* scores = self.model.get_scores(feats, n_feats)
|
||||||
cdef class_t guess = _arg_max(scores, self.nr_class)
|
guess = _arg_max(scores, self.model.nr_class)
|
||||||
if golds is not None and guess not in golds:
|
if golds is not None and guess not in golds:
|
||||||
best = _arg_max_among(scores, golds)
|
best = _arg_max_among(scores, golds)
|
||||||
counts = {}
|
counts = {guess: {}, best: {}}
|
||||||
count_feats(counts[guess], feats, n_feats, -1)
|
count_feats(counts[guess], feats, n_feats, -1)
|
||||||
count_feats(counts[best], feats, n_feats, 1)
|
count_feats(counts[best], feats, n_feats, 1)
|
||||||
self.model.update(counts)
|
self.model.update(counts)
|
||||||
|
@ -145,12 +144,28 @@ cdef class Tagger:
|
||||||
return tag_id
|
return tag_id
|
||||||
|
|
||||||
|
|
||||||
cdef class_t _arg_max(weight_t* scores, int n_classes):
|
def _make_tag_dict(counts):
|
||||||
|
freq_thresh = 50
|
||||||
|
ambiguity_thresh = 0.98
|
||||||
|
tagdict = {}
|
||||||
|
cdef atom_t word
|
||||||
|
cdef atom_t tag
|
||||||
|
for word_str, tag_freqs in counts.items():
|
||||||
|
tag_str, mode = max(tag_freqs.items(), key=lambda item: item[1])
|
||||||
|
n = sum(tag_freqs.values())
|
||||||
|
word = int(word_str)
|
||||||
|
tag = int(tag_str)
|
||||||
|
if n >= freq_thresh and (float(mode) / n) >= ambiguity_thresh:
|
||||||
|
tagdict[word] = tag
|
||||||
|
return tagdict
|
||||||
|
|
||||||
|
|
||||||
|
cdef class_t _arg_max(weight_t* scores, int n_classes) except 9000:
|
||||||
cdef int best = 0
|
cdef int best = 0
|
||||||
cdef weight_t score = scores[best]
|
cdef weight_t score = scores[best]
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(1, n_classes):
|
for i in range(1, n_classes):
|
||||||
if scores[i] > score:
|
if scores[i] >= score:
|
||||||
score = scores[i]
|
score = scores[i]
|
||||||
best = i
|
best = i
|
||||||
return best
|
return best
|
||||||
|
|
Loading…
Reference in New Issue