Refactor Chinese initialization

Adriane Boyd 2020-09-30 11:46:45 +02:00
parent 34f9c26c62
commit 6b7bb32834
4 changed files with 66 additions and 66 deletions

spacy/errors.py

@@ -672,14 +672,22 @@ class Errors:
     E999 = ("Unable to merge the `Doc` objects because they do not all share "
             "the same `Vocab`.")
     E1000 = ("The Chinese word segmenter is pkuseg but no pkuseg model was "
-            "specified. Provide the name of a pretrained model or the path to "
-            "a model when initializing the pipeline:\n"
+            "loaded. Provide the name of a pretrained model or the path to "
+            "a model and initialize the pipeline:\n\n"
             'config = {\n'
+            '   "nlp": {\n'
+            '       "tokenizer": {\n'
             '           "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
             '           "segmenter": "pkuseg",\n'
-            '   "pkuseg_model": "default", # or "/path/to/pkuseg_model" \n'
+            '       }\n'
+            '   },\n'
+            '   "initialize": {"tokenizer": {\n'
+            '       "pkuseg_model": "default", # or /path/to/model\n'
+            '       }\n'
+            '   },\n'
             '}\n'
-            'nlp = Chinese.from_config({"nlp": {"tokenizer": config}})')
+            'nlp = Chinese.from_config(config)\n'
+            'nlp.initialize()')
     E1001 = ("Target token outside of matched span for match with tokens "
             "'{span}' and offset '{index}' matched by patterns '{patterns}'.")
     E1002 = ("Span index out of range.")

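For reference, the workflow that the updated E1000 message describes looks roughly like the sketch below. It is not part of the diff; it assumes spaCy v3 with a pkuseg package installed, and "default" refers to pkuseg's default pretrained model.

    # Sketch of the usage recommended by the new error message.
    from spacy.lang.zh import Chinese

    config = {
        "nlp": {
            "tokenizer": {
                "@tokenizers": "spacy.zh.ChineseTokenizer",
                "segmenter": "pkuseg",
            }
        },
        # pkuseg settings now live under the initialize block, not in the
        # tokenizer config itself.
        "initialize": {
            "tokenizer": {
                "pkuseg_model": "default",  # or a path to a local pkuseg model
            }
        },
    }
    nlp = Chinese.from_config(config)
    nlp.initialize()  # loads the pkuseg model; required before tokenizing
    doc = nlp("我喜欢自然语言处理")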
spacy/lang/zh/__init__.py

@@ -59,32 +59,13 @@ class ChineseTokenizer(DummyTokenizer):
         self,
         nlp: Language,
         segmenter: Segmenter = Segmenter.char,
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None,
     ):
         self.vocab = nlp.vocab
         if isinstance(segmenter, Segmenter):
             segmenter = segmenter.value
         self.segmenter = segmenter
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
         self.pkuseg_seg = None
         self.jieba_seg = None
-        self.configure_segmenter(segmenter)
-
-    def initialize(
-        self,
-        get_examples: Callable[[], Iterable[Example]],
-        *,
-        nlp: Optional[Language],
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None
-    ):
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
-        self.configure_segmenter(self.segmenter)
-
-    def configure_segmenter(self, segmenter: str):
         if segmenter not in Segmenter.values():
             warn_msg = Warnings.W103.format(
                 lang="Chinese",
@@ -94,11 +75,20 @@ class ChineseTokenizer(DummyTokenizer):
             )
             warnings.warn(warn_msg)
             self.segmenter = Segmenter.char
-        self.jieba_seg = try_jieba_import(self.segmenter)
+        if segmenter == Segmenter.jieba:
+            self.jieba_seg = try_jieba_import()
+
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language],
+        pkuseg_model: Optional[str] = None,
+        pkuseg_user_dict: str = "default",
+    ):
+        if self.segmenter == Segmenter.pkuseg:
             self.pkuseg_seg = try_pkuseg_import(
-            self.segmenter,
-            pkuseg_model=self.pkuseg_model,
-            pkuseg_user_dict=self.pkuseg_user_dict,
+                pkuseg_model=pkuseg_model, pkuseg_user_dict=pkuseg_user_dict,
             )

     def __call__(self, text: str) -> Doc:
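Net effect of the two hunks above: the constructor now only records the segmenter (and eagerly imports jieba when requested), while loading a pkuseg model is deferred to the new initialize() method. Normally the [initialize.tokenizer] block supplies these arguments when nlp.initialize() runs, but the method can also be called directly. A minimal sketch, assuming the pkuseg package is installed; it is not part of the diff:

    from spacy.lang.zh import Chinese

    nlp = Chinese.from_config(
        {"nlp": {"tokenizer": {"@tokenizers": "spacy.zh.ChineseTokenizer",
                               "segmenter": "pkuseg"}}}
    )
    # At this point no pkuseg model is loaded yet; segmenting would trigger
    # the E1000 error shown above.
    # get_examples is unused by this tokenizer's initialize, so a stub is fine.
    nlp.tokenizer.initialize(lambda: [], nlp=nlp, pkuseg_model="default")
    doc = nlp("我喜欢自然语言处理")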
@@ -154,14 +144,10 @@ class ChineseTokenizer(DummyTokenizer):
     def _get_config(self) -> Dict[str, Any]:
         return {
             "segmenter": self.segmenter,
-            "pkuseg_model": self.pkuseg_model,
-            "pkuseg_user_dict": self.pkuseg_user_dict,
         }

     def _set_config(self, config: Dict[str, Any] = {}) -> None:
         self.segmenter = config.get("segmenter", Segmenter.char)
-        self.pkuseg_model = config.get("pkuseg_model", None)
-        self.pkuseg_user_dict = config.get("pkuseg_user_dict", "default")

     def to_bytes(self, **kwargs):
         pkuseg_features_b = b""
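With pkuseg_model and pkuseg_user_dict dropped from _get_config()/_set_config(), the tokenizer's lightweight config only records the segmenter; any loaded pkuseg model travels through to_bytes()/from_bytes() instead (the unchanged pkuseg_features_b machinery below). A hedged sketch of what that implies, using the default character segmenter and the private _get_config helper for illustration:

    from spacy.lang.zh import Chinese

    nlp = Chinese()  # defaults to the "char" segmenter
    tok = nlp.tokenizer
    # Only the segmenter remains in the config; no pkuseg_* keys anymore.
    assert tok._get_config() == {"segmenter": "char"}
    # A byte-level round trip still works without any pkuseg settings,
    # mirroring the zh_tokenizer_serialize helper used in the tests below.
    tok.from_bytes(tok.to_bytes())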
@@ -339,17 +325,15 @@ class Chinese(Language):
     Defaults = ChineseDefaults


-def try_jieba_import(segmenter: str) -> None:
+def try_jieba_import() -> None:
     try:
         import jieba

-        if segmenter == Segmenter.jieba:
         # segment a short text to have jieba initialize its cache in advance
         list(jieba.cut("作为", cut_all=False))
         return jieba
     except ImportError:
-        if segmenter == Segmenter.jieba:
         msg = (
             "Jieba not installed. To use jieba, install it with `pip "
             " install jieba` or from https://github.com/fxsjy/jieba"
@@ -357,22 +341,15 @@ def try_jieba_import(segmenter: str) -> None:
         raise ImportError(msg) from None


-def try_pkuseg_import(
-    segmenter: str, pkuseg_model: Optional[str], pkuseg_user_dict: str
-) -> None:
+def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
     try:
         import pkuseg

-        if pkuseg_model is None:
-            return None
-        else:
         return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
     except ImportError:
-        if segmenter == Segmenter.pkuseg:
         msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
         raise ImportError(msg) from None
     except FileNotFoundError:
-        if segmenter == Segmenter.pkuseg:
         msg = "Unable to load pkuseg model from: " + pkuseg_model
         raise FileNotFoundError(msg) from None

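The import helpers are now unconditional: try_jieba_import() takes no segmenter argument and always raises on a missing jieba install, and try_pkuseg_import() requires a model name or path and raises instead of silently returning None. Because jieba is still set up entirely in the constructor, a jieba pipeline needs no initialize() step. A small sketch, assuming the jieba package is installed; not part of the diff:

    from spacy.lang.zh import Chinese

    nlp = Chinese.from_config(
        {"nlp": {"tokenizer": {"@tokenizers": "spacy.zh.ChineseTokenizer",
                               "segmenter": "jieba"}}}
    )
    # jieba was imported (and its cache warmed up) during tokenizer construction.
    doc = nlp("我喜欢自然语言处理")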
spacy/tests/conftest.py

@@ -272,10 +272,14 @@ def zh_tokenizer_char():
 def zh_tokenizer_jieba():
     pytest.importorskip("jieba")
     config = {
+        "nlp": {
+            "tokenizer": {
                 "@tokenizers": "spacy.zh.ChineseTokenizer",
                 "segmenter": "jieba",
             }
-    nlp = get_lang_class("zh").from_config({"nlp": {"tokenizer": config}})
+        }
+    }
+    nlp = get_lang_class("zh").from_config(config)
     return nlp.tokenizer
@@ -290,7 +294,10 @@ def zh_tokenizer_pkuseg():
                 "segmenter": "pkuseg",
             }
         },
-        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
+        "initialize": {"tokenizer": {
+            "pkuseg_model": "default",
+            }
+        },
     }
     nlp = get_lang_class("zh").from_config(config)
     nlp.initialize()

spacy/tests/lang/zh/test_serialize.py

@@ -28,9 +28,17 @@ def test_zh_tokenizer_serialize_jieba(zh_tokenizer_jieba):
 @pytest.mark.slow
 def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
     config = {
+        "nlp": {
+            "tokenizer": {
                 "@tokenizers": "spacy.zh.ChineseTokenizer",
                 "segmenter": "pkuseg",
+            }
+        },
+        "initialize": {"tokenizer": {
             "pkuseg_model": "medicine",
             }
-    nlp = Chinese.from_config({"nlp": {"tokenizer": config}})
+        },
+    }
+    nlp = Chinese.from_config(config)
+    nlp.initialize()
     zh_tokenizer_serialize(nlp.tokenizer)