Refactor defaults

This commit is contained in:
Matthew Honnibal 2016-10-18 16:18:25 +02:00
parent a45a9d5092
commit 7d5212f131
1 changed files with 97 additions and 143 deletions

View File

@ -34,97 +34,79 @@ from .pipeline import DependencyParser, EntityRecognizer
class BaseDefaults(object): class BaseDefaults(object):
def __init__(self, lang, path): @classmethod
self.path = path def create_lemmatizer(cls, nlp=None):
self.lang = lang if nlp is None or nlp.path is None:
self.lex_attr_getters = dict(self.__class__.lex_attr_getters) return Lemmatizer({}, {}, {})
if self.path and (self.path / 'vocab' / 'oov_prob').exists(): else:
with (self.path / 'vocab' / 'oov_prob').open() as file_: return Lemmatizer.load(nlp.path)
oov_prob = file_.read().strip()
self.lex_attr_getters[PROB] = lambda string: oov_prob
self.lex_attr_getters[LANG] = lambda string: lang
self.lex_attr_getters[IS_STOP] = lambda string: string in self.stop_words
def Lemmatizer(self): @classmethod
return Lemmatizer.load(self.path) if self.path else Lemmatizer({}, {}, {}) def create_vocab(cls, nlp=None):
lemmatizer = cls.create_lemmatizer(nlp)
def Vectors(self): if nlp is None or nlp.path is None:
return Vocab(lex_attr_getters=cls.lex_attr_getters, tag_map=cls.tag_map,
lemmatizer=lemmatizer)
else:
return Vocab.load(nlp.path, lex_attr_getters=cls.lex_attr_getters,
tag_map=cls.tag_map, lemmatizer=lemmatizer)
@classmethod
def add_vectors(cls, nlp=None):
return True return True
def Vocab(self, lex_attr_getters=True, tag_map=True, @classmethod
lemmatizer=True, serializer_freqs=True, vectors=True): def create_tokenizer(cls, nlp=None):
if lex_attr_getters is True: rules = cls.tokenizer_exceptions
lex_attr_getters = self.lex_attr_getters prefix_search = util.compile_prefix_regex(cls.prefixes).search
if tag_map is True: suffix_search = util.compile_suffix_regex(cls.suffixes).search
tag_map = self.tag_map infix_finditer = util.compile_infix_regex(cls.infixes).finditer
if lemmatizer is True: vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
lemmatizer = self.Lemmatizer() return Tokenizer(nlp.vocab, rules=rules,
if vectors is True: prefix_search=prefix_search, suffix_search=suffix_search,
vectors = self.Vectors() infix_finditer=infix_finditer)
if self.path:
return Vocab.load(self.path, lex_attr_getters=lex_attr_getters, @classmethod
tag_map=tag_map, lemmatizer=lemmatizer, def create_tagger(cls, nlp=None):
serializer_freqs=serializer_freqs) if nlp is None:
return Tagger(cls.create_vocab(), features=cls.tagger_features)
elif nlp.path is None or not (nlp.path / 'ner').exists():
return Tagger(nlp.vocab, features=cls.tagger_features)
else: else:
return Vocab(lex_attr_getters=lex_attr_getters, tag_map=tag_map, return Tagger.load(nlp.path / 'ner', nlp.vocab)
lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
def Tokenizer(self, vocab, rules=None, prefix_search=None, suffix_search=None, @classmethod
infix_finditer=None): def create_parser(cls, nlp=None):
if rules is None: if nlp is None:
rules = self.tokenizer_exceptions return DependencyParser(cls.create_vocab(), features=cls.parser_features)
if prefix_search is None: elif nlp.path is None or not (nlp.path / 'deps').exists():
prefix_search = util.compile_prefix_regex(self.prefixes).search return DependencyParser(nlp.vocab, features=cls.parser_features)
if suffix_search is None:
suffix_search = util.compile_suffix_regex(self.suffixes).search
if infix_finditer is None:
infix_finditer = util.compile_infix_regex(self.infixes).finditer
if self.path:
return Tokenizer.load(self.path, vocab, rules=rules,
prefix_search=prefix_search,
suffix_search=suffix_search,
infix_finditer=infix_finditer)
else: else:
tokenizer = Tokenizer(vocab, rules=rules, return DependencyParser.load(nlp.path / 'deps', nlp.vocab)
prefix_search=prefix_search, suffix_search=suffix_search,
infix_finditer=infix_finditer)
return tokenizer
def Tagger(self, vocab, **cfg): @classmethod
if self.path: def create_entity(cls, nlp=None):
return Tagger.load(self.path / 'pos', vocab) if nlp is None:
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features)
elif nlp.path is None or not (nlp.path / 'ner').exists():
return EntityRecognizer(nlp.vocab, features=cls.entity_features)
else: else:
if 'features' not in cfg: return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab)
cfg['features'] = self.parser_features
return Tagger(vocab, **cfg)
def Parser(self, vocab, **cfg): @classmethod
if self.path and (self.path / 'deps').exists(): def create_matcher(cls, nlp=None):
return DependencyParser.load(self.path / 'deps', vocab) if nlp is None:
return Matcher(cls.create_vocab())
elif nlp.path is None or not (nlp.path / 'vocab').exists():
return Matcher(nlp.vocab)
else: else:
if 'features' not in cfg: return Matcher.load(nlp.path / 'vocab', nlp.vocab)
cfg['features'] = self.parser_features
return DependencyParser(vocab, **cfg)
def Entity(self, vocab, **cfg): @classmethod
if self.path and (self.path / 'ner').exists(): def create_pipeline(self, nlp=None):
return EntityRecognizer.load(self.path / 'ner', vocab)
else:
if 'features' not in cfg:
cfg['features'] = self.entity_features
return EntityRecognizer(vocab, **cfg)
def Matcher(self, vocab, **cfg):
if self.path:
return Matcher.load(self.path, vocab)
else:
return Matcher(vocab)
def MakeDoc(self, nlp, **cfg):
return lambda text: nlp.tokenizer(text)
def Pipeline(self, nlp, **cfg):
pipeline = [] pipeline = []
if nlp is None:
return []
if nlp.tagger: if nlp.tagger:
pipeline.append(nlp.tagger) pipeline.append(nlp.tagger)
if nlp.parser: if nlp.parser:
@ -147,6 +129,8 @@ class BaseDefaults(object):
entity_features = get_templates('ner') entity_features = get_templates('ner')
tagger_features = Tagger.feature_templates # TODO -- fix this
stop_words = set() stop_words = set()
lex_attr_getters = { lex_attr_getters = {
@ -240,78 +224,48 @@ class Language(object):
yield Trainer(self, gold_tuples) yield Trainer(self, gold_tuples)
self.end_training() self.end_training()
def __init__(self, def __init__(self, path=True, **overrides):
path=True, if 'data_dir' in overrides and 'path' not in overrides:
vocab=True, raise ValueError("The argument 'data_dir' has been renamed to 'path'")
tokenizer=True, path = overrides.get('path', True)
tagger=True,
parser=True,
entity=True,
matcher=True,
serializer=True,
vectors=True,
make_doc=True,
pipeline=True,
defaults=True,
data_dir=None):
"""
A model can be specified:
1) by calling a Language subclass
- spacy.en.English()
2) by calling a Language subclass with data_dir
- spacy.en.English('my/model/root')
- spacy.en.English(data_dir='my/model/root')
3) by package name
- spacy.load('en_default')
- spacy.load('en_default==1.0.0')
4) by package name with a relocated package base
- spacy.load('en_default', via='/my/package/root')
- spacy.load('en_default==1.0.0', via='/my/package/root')
"""
if data_dir is not None and path is None:
warn("'data_dir' argument now named 'path'. Doing what you mean.")
path = data_dir
if isinstance(path, basestring): if isinstance(path, basestring):
path = pathlib.Path(path) path = pathlib.Path(path)
if path is True: if path is True:
path = util.match_best_version(self.lang, '', util.get_data_path()) path = util.match_best_version(self.lang, '', util.get_data_path())
self.path = path self.path = path
defaults = defaults if defaults is not True else self.get_defaults(self.path)
self.vocab = self.Defaults.create_vocab(self) \
self.defaults = defaults if 'vocab' not in overrides \
self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors) else overrides['vocab']
self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab) self.tokenizer = self.Defaults.create_tokenizer(self) \
self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab) if 'tokenizer' not in overrides \
self.entity = entity if entity is not True else defaults.Entity(self.vocab) else overrides['tokenizer']
self.parser = parser if parser is not True else defaults.Parser(self.vocab) self.tagger = self.Defaults.create_tagger(self) \
self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab) if 'tagger' not in overrides \
else overrides['tagger']
self.parser = self.Defaults.create_tagger(self) \
if 'parser' not in overrides \
else overrides['parser']
self.entity = self.Defaults.create_entity(self) \
if 'entity' not in overrides \
else overrides['entity']
self.matcher = self.Defaults.create_matcher(self) \
if 'matcher' not in overrides \
else overrides['matcher']
if make_doc in (None, True, False): if 'make_doc' in overrides:
self.make_doc = defaults.MakeDoc(self) self.make_doc = overrides['make_doc']
elif 'create_make_doc' in overrides:
self.make_doc = overrides['create_make_doc']
else: else:
self.make_doc = make_doc self.make_doc = lambda text: self.tokenizer(text)
if pipeline in (None, False): if 'pipeline' in overrides:
self.pipeline = [] self.pipeline = overrides['pipeline']
elif pipeline is True: elif 'create_pipeline' in overrides:
self.pipeline = defaults.Pipeline(self) self.pipeline = overrides['create_pipeline']
else: else:
self.pipeline = pipeline(self) self.pipeline = [self.tagger, self.parser, self.matcher, self.entity]
def __reduce__(self):
args = (
self.path,
self.vocab,
self.tokenizer,
self.tagger,
self.parser,
self.entity,
self.matcher
)
return (self.__class__, args, None, None)
def __call__(self, text, tag=True, parse=True, entity=True): def __call__(self, text, tag=True, parse=True, entity=True):
"""Apply the pipeline to some text. The text can span multiple sentences, """Apply the pipeline to some text. The text can span multiple sentences,