Refactor Chinese initialization

Adriane Boyd 2020-09-30 11:46:45 +02:00
parent 34f9c26c62
commit 6b7bb32834
4 changed files with 66 additions and 66 deletions


@@ -672,14 +672,22 @@ class Errors:
     E999 = ("Unable to merge the `Doc` objects because they do not all share "
             "the same `Vocab`.")
     E1000 = ("The Chinese word segmenter is pkuseg but no pkuseg model was "
-            "specified. Provide the name of a pretrained model or the path to "
-            "a model when initializing the pipeline:\n"
+            "loaded. Provide the name of a pretrained model or the path to "
+            "a model and initialize the pipeline:\n\n"
             'config = {\n'
-            '   "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
-            '   "segmenter": "pkuseg",\n'
-            '   "pkuseg_model": "default", # or "/path/to/pkuseg_model" \n'
+            '   "nlp": {\n'
+            '       "tokenizer": {\n'
+            '           "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
+            '           "segmenter": "pkuseg",\n'
+            '       }\n'
+            '   },\n'
+            '   "initialize": {"tokenizer": {\n'
+            '       "pkuseg_model": "default", # or /path/to/model\n'
+            '       }\n'
+            '   },\n'
             '}\n'
-            'nlp = Chinese.from_config({"nlp": {"tokenizer": config}})')
+            'nlp = Chinese.from_config(config)\n'
+            'nlp.initialize()')
     E1001 = ("Target token outside of matched span for match with tokens "
             "'{span}' and offset '{index}' matched by patterns '{patterns}'.")
     E1002 = ("Span index out of range.")


@@ -59,32 +59,13 @@ class ChineseTokenizer(DummyTokenizer):
         self,
         nlp: Language,
         segmenter: Segmenter = Segmenter.char,
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None,
     ):
         self.vocab = nlp.vocab
         if isinstance(segmenter, Segmenter):
             segmenter = segmenter.value
         self.segmenter = segmenter
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
         self.pkuseg_seg = None
         self.jieba_seg = None
-        self.configure_segmenter(segmenter)
-
-    def initialize(
-        self,
-        get_examples: Callable[[], Iterable[Example]],
-        *,
-        nlp: Optional[Language],
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None
-    ):
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
-        self.configure_segmenter(self.segmenter)
-
-    def configure_segmenter(self, segmenter: str):
         if segmenter not in Segmenter.values():
             warn_msg = Warnings.W103.format(
                 lang="Chinese",
@@ -94,12 +75,21 @@ class ChineseTokenizer(DummyTokenizer):
             )
             warnings.warn(warn_msg)
             self.segmenter = Segmenter.char
-        self.jieba_seg = try_jieba_import(self.segmenter)
-        self.pkuseg_seg = try_pkuseg_import(
-            self.segmenter,
-            pkuseg_model=self.pkuseg_model,
-            pkuseg_user_dict=self.pkuseg_user_dict,
-        )
+        if segmenter == Segmenter.jieba:
+            self.jieba_seg = try_jieba_import()
+
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language],
+        pkuseg_model: Optional[str] = None,
+        pkuseg_user_dict: str = "default",
+    ):
+        if self.segmenter == Segmenter.pkuseg:
+            self.pkuseg_seg = try_pkuseg_import(
+                pkuseg_model=pkuseg_model, pkuseg_user_dict=pkuseg_user_dict,
+            )
 
     def __call__(self, text: str) -> Doc:
         if self.segmenter == Segmenter.jieba:
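Note: with this change the pkuseg model is loaded lazily in initialize() instead of eagerly in __init__. When nlp.initialize() runs, the settings under [initialize.tokenizer] are forwarded to this method; calling it by hand is equivalent. A rough sketch (not part of the diff; it assumes pkuseg is installed and the "default" model can be fetched, and it passes an empty get_examples callable since the method does not use it):

from spacy.lang.zh import Chinese

nlp = Chinese.from_config(
    {"nlp": {"tokenizer": {"@tokenizers": "spacy.zh.ChineseTokenizer", "segmenter": "pkuseg"}}}
)
assert nlp.tokenizer.pkuseg_seg is None  # __init__ no longer loads a model

# Roughly what nlp.initialize() does for the tokenizer, using the [initialize.tokenizer] settings
nlp.tokenizer.initialize(lambda: [], nlp=nlp, pkuseg_model="default")
assert nlp.tokenizer.pkuseg_seg is not None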
@@ -154,14 +144,10 @@ class ChineseTokenizer(DummyTokenizer):
     def _get_config(self) -> Dict[str, Any]:
         return {
             "segmenter": self.segmenter,
-            "pkuseg_model": self.pkuseg_model,
-            "pkuseg_user_dict": self.pkuseg_user_dict,
         }
 
     def _set_config(self, config: Dict[str, Any] = {}) -> None:
         self.segmenter = config.get("segmenter", Segmenter.char)
-        self.pkuseg_model = config.get("pkuseg_model", None)
-        self.pkuseg_user_dict = config.get("pkuseg_user_dict", "default")
 
     def to_bytes(self, **kwargs):
         pkuseg_features_b = b""
@@ -339,42 +325,33 @@ class Chinese(Language):
     Defaults = ChineseDefaults
 
 
-def try_jieba_import(segmenter: str) -> None:
+def try_jieba_import() -> None:
     try:
         import jieba
 
-        if segmenter == Segmenter.jieba:
-            # segment a short text to have jieba initialize its cache in advance
-            list(jieba.cut("作为", cut_all=False))
+        # segment a short text to have jieba initialize its cache in advance
+        list(jieba.cut("作为", cut_all=False))
 
         return jieba
     except ImportError:
-        if segmenter == Segmenter.jieba:
-            msg = (
-                "Jieba not installed. To use jieba, install it with `pip "
-                " install jieba` or from https://github.com/fxsjy/jieba"
-            )
-            raise ImportError(msg) from None
+        msg = (
+            "Jieba not installed. To use jieba, install it with `pip "
+            " install jieba` or from https://github.com/fxsjy/jieba"
+        )
+        raise ImportError(msg) from None
 
 
-def try_pkuseg_import(
-    segmenter: str, pkuseg_model: Optional[str], pkuseg_user_dict: str
-) -> None:
+def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
     try:
         import pkuseg
 
-        if pkuseg_model is None:
-            return None
-        else:
-            return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
+        return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
     except ImportError:
-        if segmenter == Segmenter.pkuseg:
-            msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
-            raise ImportError(msg) from None
+        msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
+        raise ImportError(msg) from None
     except FileNotFoundError:
-        if segmenter == Segmenter.pkuseg:
-            msg = "Unable to load pkuseg model from: " + pkuseg_model
-            raise FileNotFoundError(msg) from None
+        msg = "Unable to load pkuseg model from: " + pkuseg_model
+        raise FileNotFoundError(msg) from None
 
 
 def _get_pkuseg_trie_data(node, path=""):
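Note: jieba, by contrast, is still set up eagerly in __init__ (see the hunk at -94/+75 above), and after this refactor the tokenizer's config carries only the segmenter name, not the pkuseg model settings. A small sketch of the jieba path (not part of the diff; it assumes the jieba package is installed and uses an arbitrary example sentence):

from spacy.lang.zh import Chinese

config = {
    "nlp": {
        "tokenizer": {
            "@tokenizers": "spacy.zh.ChineseTokenizer",
            "segmenter": "jieba",
        }
    }
}
nlp = Chinese.from_config(config)
assert nlp.tokenizer.jieba_seg is not None  # loaded in __init__, no initialize() step needed
assert "pkuseg_model" not in nlp.tokenizer._get_config()  # model settings no longer live in the config
doc = nlp("我们都是语言学家")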


@@ -272,10 +272,14 @@ def zh_tokenizer_char():
 def zh_tokenizer_jieba():
     pytest.importorskip("jieba")
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "jieba",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "jieba",
+            }
+        }
     }
-    nlp = get_lang_class("zh").from_config({"nlp": {"tokenizer": config}})
+    nlp = get_lang_class("zh").from_config(config)
     return nlp.tokenizer
@@ -290,7 +294,10 @@ def zh_tokenizer_pkuseg():
                 "segmenter": "pkuseg",
             }
         },
-        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
+        "initialize": {"tokenizer": {
+            "pkuseg_model": "default",
+        }
+        },
     }
     nlp = get_lang_class("zh").from_config(config)
     nlp.initialize()


@@ -28,9 +28,17 @@ def test_zh_tokenizer_serialize_jieba(zh_tokenizer_jieba):
 @pytest.mark.slow
 def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "pkuseg",
-        "pkuseg_model": "medicine",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "pkuseg",
+            }
+        },
+        "initialize": {"tokenizer": {
+            "pkuseg_model": "medicine",
+        }
+        },
     }
-    nlp = Chinese.from_config({"nlp": {"tokenizer": config}})
+    nlp = Chinese.from_config(config)
+    nlp.initialize()
     zh_tokenizer_serialize(nlp.tokenizer)