mirror of https://github.com/explosion/spaCy.git
Refactor defaults
This commit is contained in:
parent a45a9d5092
commit 7d5212f131
@@ -34,97 +34,79 @@ from .pipeline import DependencyParser, EntityRecognizer
 class BaseDefaults(object):
-    def __init__(self, lang, path):
-        self.path = path
-        self.lang = lang
-        self.lex_attr_getters = dict(self.__class__.lex_attr_getters)
-        if self.path and (self.path / 'vocab' / 'oov_prob').exists():
-            with (self.path / 'vocab' / 'oov_prob').open() as file_:
-                oov_prob = file_.read().strip()
-            self.lex_attr_getters[PROB] = lambda string: oov_prob
-        self.lex_attr_getters[LANG] = lambda string: lang
-        self.lex_attr_getters[IS_STOP] = lambda string: string in self.stop_words
+    @classmethod
+    def create_lemmatizer(cls, nlp=None):
+        if nlp is None or nlp.path is None:
+            return Lemmatizer({}, {}, {})
+        else:
+            return Lemmatizer.load(nlp.path)

-    def Lemmatizer(self):
-        return Lemmatizer.load(self.path) if self.path else Lemmatizer({}, {}, {})
+    @classmethod
+    def create_vocab(cls, nlp=None):
+        lemmatizer = cls.create_lemmatizer(nlp)
+        if nlp is None or nlp.path is None:
+            return Vocab(lex_attr_getters=cls.lex_attr_getters, tag_map=cls.tag_map,
+                         lemmatizer=lemmatizer)
+        else:
+            return Vocab.load(nlp.path, lex_attr_getters=cls.lex_attr_getters,
+                              tag_map=cls.tag_map, lemmatizer=lemmatizer)

-    def Vectors(self):
-        return True

+    @classmethod
+    def add_vectors(cls, nlp=None):
+        return True
-    def Vocab(self, lex_attr_getters=True, tag_map=True,
-              lemmatizer=True, serializer_freqs=True, vectors=True):
-        if lex_attr_getters is True:
-            lex_attr_getters = self.lex_attr_getters
-        if tag_map is True:
-            tag_map = self.tag_map
-        if lemmatizer is True:
-            lemmatizer = self.Lemmatizer()
-        if vectors is True:
-            vectors = self.Vectors()
-        if self.path:
-            return Vocab.load(self.path, lex_attr_getters=lex_attr_getters,
-                              tag_map=tag_map, lemmatizer=lemmatizer,
-                              serializer_freqs=serializer_freqs)
-        else:
-            return Vocab(lex_attr_getters=lex_attr_getters, tag_map=tag_map,
-                         lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
+    @classmethod
+    def create_tokenizer(cls, nlp=None):
+        rules = cls.tokenizer_exceptions
+        prefix_search = util.compile_prefix_regex(cls.prefixes).search
+        suffix_search = util.compile_suffix_regex(cls.suffixes).search
+        infix_finditer = util.compile_infix_regex(cls.infixes).finditer
+        vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
+        return Tokenizer(vocab, rules=rules,
+                         prefix_search=prefix_search, suffix_search=suffix_search,
+                         infix_finditer=infix_finditer)
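These factories are meant to compose: create_vocab builds its lemmatizer through create_lemmatizer, and create_tokenizer compiles the affix regexes and falls back to a freshly built vocab when it is not given an nlp object. A rough usage sketch, not part of the diff, assuming a Defaults subclass such as spacy.en.English.Defaults that supplies the class attributes these methods read (tag_map, prefixes, suffixes, infixes, tokenizer_exceptions):

from spacy.en import English

defaults = English.Defaults
vocab = defaults.create_vocab(None)          # no nlp/path: uses Lemmatizer({}, {}, {})
tokenizer = defaults.create_tokenizer(None)  # builds its own vocab via create_vocab(None)
doc = tokenizer(u'This is a sentence.')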
+    @classmethod
+    def create_tagger(cls, nlp=None):
+        if nlp is None:
+            return Tagger(cls.create_vocab(), features=cls.tagger_features)
+        elif nlp.path is None or not (nlp.path / 'pos').exists():
+            return Tagger(nlp.vocab, features=cls.tagger_features)
+        else:
+            return Tagger.load(nlp.path / 'pos', nlp.vocab)
-    def Tokenizer(self, vocab, rules=None, prefix_search=None, suffix_search=None,
-                  infix_finditer=None):
-        if rules is None:
-            rules = self.tokenizer_exceptions
-        if prefix_search is None:
-            prefix_search = util.compile_prefix_regex(self.prefixes).search
-        if suffix_search is None:
-            suffix_search = util.compile_suffix_regex(self.suffixes).search
-        if infix_finditer is None:
-            infix_finditer = util.compile_infix_regex(self.infixes).finditer
-        if self.path:
-            return Tokenizer.load(self.path, vocab, rules=rules,
-                                  prefix_search=prefix_search,
-                                  suffix_search=suffix_search,
-                                  infix_finditer=infix_finditer)
-        else:
-            tokenizer = Tokenizer(vocab, rules=rules,
-                                  prefix_search=prefix_search, suffix_search=suffix_search,
-                                  infix_finditer=infix_finditer)
-            return tokenizer
+    @classmethod
+    def create_parser(cls, nlp=None):
+        if nlp is None:
+            return DependencyParser(cls.create_vocab(), features=cls.parser_features)
+        elif nlp.path is None or not (nlp.path / 'deps').exists():
+            return DependencyParser(nlp.vocab, features=cls.parser_features)
+        else:
+            return DependencyParser.load(nlp.path / 'deps', nlp.vocab)
-    def Tagger(self, vocab, **cfg):
-        if self.path:
-            return Tagger.load(self.path / 'pos', vocab)
-        else:
-            if 'features' not in cfg:
-                cfg['features'] = self.parser_features
-            return Tagger(vocab, **cfg)
+    @classmethod
+    def create_entity(cls, nlp=None):
+        if nlp is None:
+            return EntityRecognizer(cls.create_vocab(), features=cls.entity_features)
+        elif nlp.path is None or not (nlp.path / 'ner').exists():
+            return EntityRecognizer(nlp.vocab, features=cls.entity_features)
+        else:
+            return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab)
-    def Parser(self, vocab, **cfg):
-        if self.path and (self.path / 'deps').exists():
-            return DependencyParser.load(self.path / 'deps', vocab)
-        else:
-            if 'features' not in cfg:
-                cfg['features'] = self.parser_features
-            return DependencyParser(vocab, **cfg)
+    @classmethod
+    def create_matcher(cls, nlp=None):
+        if nlp is None:
+            return Matcher(cls.create_vocab())
+        elif nlp.path is None or not (nlp.path / 'vocab').exists():
+            return Matcher(nlp.vocab)
+        else:
+            return Matcher.load(nlp.path / 'vocab', nlp.vocab)
-    def Entity(self, vocab, **cfg):
-        if self.path and (self.path / 'ner').exists():
-            return EntityRecognizer.load(self.path / 'ner', vocab)
-        else:
-            if 'features' not in cfg:
-                cfg['features'] = self.entity_features
-            return EntityRecognizer(vocab, **cfg)

-    def Matcher(self, vocab, **cfg):
-        if self.path:
-            return Matcher.load(self.path, vocab)
-        else:
-            return Matcher(vocab)

-    def MakeDoc(self, nlp, **cfg):
-        return lambda text: nlp.tokenizer(text)
-    def Pipeline(self, nlp, **cfg):

+    @classmethod
+    def create_pipeline(cls, nlp=None):
+        pipeline = []
+        if nlp is None:
+            return []
+        if nlp.tagger:
+            pipeline.append(nlp.tagger)
+        if nlp.parser:
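The effect of the refactor is that a language subclass customises pipeline construction by overriding classmethods on its Defaults class, rather than by instantiating a defaults object bound to a path. A minimal sketch of the pattern (hypothetical subclass names; assumes the spaCy 1.x API shown in this hunk):

from spacy.language import Language, BaseDefaults
from spacy.lemmatizer import Lemmatizer

class MyDefaults(BaseDefaults):
    @classmethod
    def create_lemmatizer(cls, nlp=None):
        # Hypothetical override: always use an empty rule-based lemmatizer,
        # regardless of whether a model path is available.
        return Lemmatizer({}, {}, {})

class MyLanguage(Language):
    lang = 'xx'
    Defaults = MyDefaults
    # MyLanguage(path=None) would now build its components through MyDefaults.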

@@ -147,6 +129,8 @@ class BaseDefaults(object):
     entity_features = get_templates('ner')

     tagger_features = Tagger.feature_templates  # TODO -- fix this

     stop_words = set()

     lex_attr_getters = {
@@ -240,78 +224,48 @@ class Language(object):
         yield Trainer(self, gold_tuples)
         self.end_training()

-    def __init__(self,
-                 path=True,
-                 vocab=True,
-                 tokenizer=True,
-                 tagger=True,
-                 parser=True,
-                 entity=True,
-                 matcher=True,
-                 serializer=True,
-                 vectors=True,
-                 make_doc=True,
-                 pipeline=True,
-                 defaults=True,
-                 data_dir=None):
-        """
-        A model can be specified:
-
-        1) by calling a Language subclass
-            - spacy.en.English()
-
-        2) by calling a Language subclass with data_dir
-            - spacy.en.English('my/model/root')
-            - spacy.en.English(data_dir='my/model/root')
-
-        3) by package name
-            - spacy.load('en_default')
-            - spacy.load('en_default==1.0.0')
-
-        4) by package name with a relocated package base
-            - spacy.load('en_default', via='/my/package/root')
-            - spacy.load('en_default==1.0.0', via='/my/package/root')
-        """
-        if data_dir is not None and path is None:
-            warn("'data_dir' argument now named 'path'. Doing what you mean.")
-            path = data_dir
+    def __init__(self, path=True, **overrides):
+        if 'data_dir' in overrides and 'path' not in overrides:
+            raise ValueError("The argument 'data_dir' has been renamed to 'path'")
+        path = overrides.get('path', path)
+        if isinstance(path, basestring):
+            path = pathlib.Path(path)
+        if path is True:
+            path = util.match_best_version(self.lang, '', util.get_data_path())

         self.path = path
-        defaults = defaults if defaults is not True else self.get_defaults(self.path)
-        self.defaults = defaults
-        self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors)
-        self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab)
-        self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab)
-        self.entity = entity if entity is not True else defaults.Entity(self.vocab)
-        self.parser = parser if parser is not True else defaults.Parser(self.vocab)
-        self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab)
+        self.vocab = self.Defaults.create_vocab(self) \
+                     if 'vocab' not in overrides \
+                     else overrides['vocab']
+        self.tokenizer = self.Defaults.create_tokenizer(self) \
+                         if 'tokenizer' not in overrides \
+                         else overrides['tokenizer']
+        self.tagger = self.Defaults.create_tagger(self) \
+                      if 'tagger' not in overrides \
+                      else overrides['tagger']
+        self.parser = self.Defaults.create_parser(self) \
+                      if 'parser' not in overrides \
+                      else overrides['parser']
+        self.entity = self.Defaults.create_entity(self) \
+                      if 'entity' not in overrides \
+                      else overrides['entity']
+        self.matcher = self.Defaults.create_matcher(self) \
+                       if 'matcher' not in overrides \
+                       else overrides['matcher']
-        if make_doc in (None, True, False):
-            self.make_doc = defaults.MakeDoc(self)
+        if 'make_doc' in overrides:
+            self.make_doc = overrides['make_doc']
+        elif 'create_make_doc' in overrides:
+            self.make_doc = overrides['create_make_doc'](self)
+        else:
+            self.make_doc = lambda text: self.tokenizer(text)
-        else:
-            self.make_doc = make_doc
-        if pipeline in (None, False):
-            self.pipeline = []
-        elif pipeline is True:
-            self.pipeline = defaults.Pipeline(self)
+        if 'pipeline' in overrides:
+            self.pipeline = overrides['pipeline']
+        elif 'create_pipeline' in overrides:
+            self.pipeline = overrides['create_pipeline'](self)
-        else:
-            self.pipeline = pipeline(self)
-    def __reduce__(self):
-        args = (
-            self.path,
-            self.vocab,
-            self.tokenizer,
-            self.tagger,
-            self.parser,
-            self.entity,
-            self.matcher
-        )
-        return (self.__class__, args, None, None)

+        else:
+            self.pipeline = [self.tagger, self.parser, self.matcher, self.entity]
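On the Language side, the long keyword-argument constructor gives way to a single **overrides dict: anything not overridden is produced by the matching Defaults.create_* factory. A usage sketch, not part of the diff, assuming the constructor above (English() with no arguments needs an installed model; the create_pipeline override is a hypothetical factory that receives the partially built nlp object):

from spacy.en import English

nlp = English()                 # path=True: resolve an installed model via util.match_best_version
nlp_blank = English(path=None)  # no model directory: every component comes from English.Defaults
nlp_small = English(path=None, create_pipeline=lambda nlp: [nlp.tagger])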

     def __call__(self, text, tag=True, parse=True, entity=True):
         """Apply the pipeline to some text. The text can span multiple sentences,