mirror of https://github.com/explosion/spaCy.git
Fix PhraseMatcher to remember attr on pickling (#4336)
* Fix PhraseMatcher to remember attr on pickling * Check for attr as int or long
This commit is contained in:
parent
089f44cc56
commit
ba5595c764
|
@ -49,7 +49,7 @@ cdef class PhraseMatcher:
|
||||||
self._terminal_hash = 826361138722620965
|
self._terminal_hash = 826361138722620965
|
||||||
map_init(self.mem, self.c_map, 8)
|
map_init(self.mem, self.c_map, 8)
|
||||||
|
|
||||||
if isinstance(attr, long):
|
if isinstance(attr, (int, long)):
|
||||||
self.attr = attr
|
self.attr = attr
|
||||||
else:
|
else:
|
||||||
attr = attr.upper()
|
attr = attr.upper()
|
||||||
|
@ -79,7 +79,7 @@ cdef class PhraseMatcher:
|
||||||
return key in self._callbacks
|
return key in self._callbacks
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
data = (self.vocab, self._docs, self._callbacks)
|
data = (self.vocab, self._docs, self._callbacks, self.attr)
|
||||||
return (unpickle_matcher, data, None, None)
|
return (unpickle_matcher, data, None, None)
|
||||||
|
|
||||||
def remove(self, key):
|
def remove(self, key):
|
||||||
|
@ -171,15 +171,15 @@ cdef class PhraseMatcher:
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
if len(doc) == 0:
|
if len(doc) == 0:
|
||||||
continue
|
continue
|
||||||
if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
|
|
||||||
raise ValueError(Errors.E155.format())
|
|
||||||
if self.attr == DEP and not doc.is_parsed:
|
|
||||||
raise ValueError(Errors.E156.format())
|
|
||||||
if self._validate and (doc.is_tagged or doc.is_parsed) \
|
|
||||||
and self.attr not in (DEP, POS, TAG, LEMMA):
|
|
||||||
string_attr = self.vocab.strings[self.attr]
|
|
||||||
user_warning(Warnings.W012.format(key=key, attr=string_attr))
|
|
||||||
if isinstance(doc, Doc):
|
if isinstance(doc, Doc):
|
||||||
|
if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
|
||||||
|
raise ValueError(Errors.E155.format())
|
||||||
|
if self.attr == DEP and not doc.is_parsed:
|
||||||
|
raise ValueError(Errors.E156.format())
|
||||||
|
if self._validate and (doc.is_tagged or doc.is_parsed) \
|
||||||
|
and self.attr not in (DEP, POS, TAG, LEMMA):
|
||||||
|
string_attr = self.vocab.strings[self.attr]
|
||||||
|
user_warning(Warnings.W012.format(key=key, attr=string_attr))
|
||||||
keyword = self._convert_to_array(doc)
|
keyword = self._convert_to_array(doc)
|
||||||
else:
|
else:
|
||||||
keyword = doc
|
keyword = doc
|
||||||
|
@ -310,8 +310,8 @@ cdef class PhraseMatcher:
|
||||||
return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
|
return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
|
||||||
|
|
||||||
|
|
||||||
def unpickle_matcher(vocab, docs, callbacks):
|
def unpickle_matcher(vocab, docs, callbacks, attr):
|
||||||
matcher = PhraseMatcher(vocab)
|
matcher = PhraseMatcher(vocab, attr=attr)
|
||||||
for key, specs in docs.items():
|
for key, specs in docs.items():
|
||||||
callback = callbacks.get(key, None)
|
callback = callbacks.get(key, None)
|
||||||
matcher.add(key, callback, *specs)
|
matcher.add(key, callback, *specs)
|
||||||
|
|
Loading…
Reference in New Issue