WIP: Test updating Chinese tokenizer

Ines Montani 2020-09-29 21:10:22 +02:00
parent 4f3102d09c
commit 6467a560e3
2 changed files with 29 additions and 17 deletions

View File

@@ -1,4 +1,4 @@
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Callable, Iterable
 from enum import Enum
 import tempfile
 import srsly
@@ -10,7 +10,7 @@ from ...errors import Warnings, Errors
 from ...language import Language
 from ...scorer import Scorer
 from ...tokens import Doc
-from ...training import validate_examples
+from ...training import validate_examples, Example
 from ...util import DummyTokenizer, registry
 from .lex_attrs import LEX_ATTRS
 from .stop_words import STOP_WORDS
@@ -28,6 +28,10 @@ DEFAULT_CONFIG = """
 [nlp.tokenizer]
 @tokenizers = "spacy.zh.ChineseTokenizer"
 segmenter = "char"
+
+[initialize]
+
+[initialize.tokenizer]
 pkuseg_model = null
 pkuseg_user_dict = "default"
 """
@@ -44,18 +48,9 @@ class Segmenter(str, Enum):
 @registry.tokenizers("spacy.zh.ChineseTokenizer")
-def create_chinese_tokenizer(
-    segmenter: Segmenter = Segmenter.char,
-    pkuseg_model: Optional[str] = None,
-    pkuseg_user_dict: Optional[str] = "default",
-):
+def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char,):
     def chinese_tokenizer_factory(nlp):
-        return ChineseTokenizer(
-            nlp,
-            segmenter=segmenter,
-            pkuseg_model=pkuseg_model,
-            pkuseg_user_dict=pkuseg_user_dict,
-        )
+        return ChineseTokenizer(nlp, segmenter=segmenter)
 
     return chinese_tokenizer_factory
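
With the init-only arguments gone, the registered factory carries nothing but the runtime segmenter choice. A sketch of what selecting a segmenter looks like under this split, assuming the same from_config pattern as the test below; no pkuseg model is loaded until initialization:

from spacy.lang.zh import Chinese

# Default character segmentation needs no initialization step
nlp = Chinese()

# Choosing pkuseg only records the segmenter choice; the actual model
# is loaded later by initialize()
cfg = {"segmenter": "pkuseg"}
nlp = Chinese.from_config({"nlp": {"tokenizer": cfg}})
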
@@ -78,6 +73,18 @@ class ChineseTokenizer(DummyTokenizer):
         self.jieba_seg = None
         self.configure_segmenter(segmenter)
 
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language],
+        pkuseg_model: Optional[str] = None,
+        pkuseg_user_dict: Optional[str] = None
+    ):
+        self.pkuseg_model = pkuseg_model
+        self.pkuseg_user_dict = pkuseg_user_dict
+        self.configure_segmenter(self.segmenter)
+
     def configure_segmenter(self, segmenter: str):
         if segmenter not in Segmenter.values():
             warn_msg = Warnings.W103.format(
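
The new initialize() has the same shape as component initialization in spaCy v3: nlp.initialize() calls it with get_examples plus whatever was configured under [initialize.tokenizer]. Note that in this WIP signature get_examples and nlp have no defaults, so a direct call must supply both; a hedged sketch:

# Direct call matching the WIP signature above (a no-op get_examples is
# enough here, since the Chinese tokenizer only uses the pkuseg settings)
nlp.tokenizer.initialize(lambda: [], nlp=nlp, pkuseg_model="default")
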

View File

@@ -284,11 +284,16 @@ def zh_tokenizer_pkuseg():
     pytest.importorskip("pkuseg")
     pytest.importorskip("pickle5")
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "pkuseg",
-        "pkuseg_model": "default",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "pkuseg",
+            }
+        },
+        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
     }
-    nlp = get_lang_class("zh").from_config({"nlp": {"tokenizer": config}})
+    nlp = get_lang_class("zh").from_config(config)
+    nlp.initialize()
     return nlp.tokenizer
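
The fixture now builds the full nested config and calls nlp.initialize(), so the pkuseg model is actually loaded before the tokenizer is handed to tests. A hypothetical test consuming it (the sentence and assertion are illustrative, not part of this commit):

def test_zh_tokenizer_pkuseg_words(zh_tokenizer_pkuseg):
    doc = zh_tokenizer_pkuseg("北京是中国的首都")
    assert all(token.text for token in doc)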