diff --git a/spacy/language.py b/spacy/language.py index e62431bf1..6eb2d150b 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -31,6 +31,8 @@ from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP from .syntax.parser import get_templates from .syntax.nonproj import PseudoProjectivity from .pipeline import DependencyParser, EntityRecognizer +from .syntax.arc_eager import ArcEager +from .syntax.ner import BiluoPushDown class BaseDefaults(object): @@ -65,7 +67,7 @@ class BaseDefaults(object): prefix_search = util.compile_prefix_regex(cls.prefixes).search suffix_search = util.compile_suffix_regex(cls.suffixes).search infix_finditer = util.compile_infix_regex(cls.infixes).finditer - vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp) + vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp) return Tokenizer(nlp.vocab, rules=rules, prefix_search=prefix_search, suffix_search=suffix_search, infix_finditer=infix_finditer) @@ -82,26 +84,27 @@ class BaseDefaults(object): return Tagger.load(nlp.path / 'pos', nlp.vocab) @classmethod - def create_parser(cls, nlp=None): + def create_parser(cls, nlp=None, **cfg): if nlp is None: - return DependencyParser(cls.create_vocab(), features=cls.parser_features) + return DependencyParser(cls.create_vocab(), features=cls.parser_features, + **cfg) elif nlp.path is False: - return DependencyParser(nlp.vocab, features=cls.parser_features) + return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg) elif nlp.path is None or not (nlp.path / 'deps').exists(): return None else: - return DependencyParser.load(nlp.path / 'deps', nlp.vocab) + return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg) @classmethod - def create_entity(cls, nlp=None): + def create_entity(cls, nlp=None, **cfg): if nlp is None: - return EntityRecognizer(cls.create_vocab(), features=cls.entity_features) + return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg) elif nlp.path 
is False: - return EntityRecognizer(nlp.vocab, features=cls.entity_features) + return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg) elif nlp.path is None or not (nlp.path / 'ner').exists(): return None else: - return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab) + return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg) @classmethod def create_matcher(cls, nlp=None): @@ -202,8 +205,8 @@ class Language(object): # preprocess training data here before ArcEager.get_labels() is called gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples) - parser_cfg['labels'] = ArcEager.get_labels(gold_tuples) - entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples) + parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) + entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) with (dep_model_dir / 'config.json').open('wb') as file_: json.dump(parser_cfg, file_) @@ -224,22 +227,18 @@ class Language(object): vectors=False, pipeline=False) - self.defaults.parser_labels = parser_cfg['labels'] - self.defaults.entity_labels = entity_cfg['labels'] - - self.vocab = self.defaults.Vocab() - self.tokenizer = self.defaults.Tokenizer(self.vocab) - self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg) - self.parser = self.defaults.Parser(self.vocab, **parser_cfg) - self.entity = self.defaults.Entity(self.vocab, **entity_cfg) - self.pipeline = self.defaults.Pipeline(self) + self.vocab = self.Defaults.create_vocab(self) + self.tokenizer = self.Defaults.create_tokenizer(self) + self.tagger = self.Defaults.create_tagger(self) + self.parser = self.Defaults.create_parser(self) + self.entity = self.Defaults.create_entity(self) + self.pipeline = self.Defaults.create_pipeline(self) yield Trainer(self, gold_tuples) self.end_training() def __init__(self, path=True, **overrides): - if 'data_dir' in overrides and 'path' not in overrides: + if 'data_dir' in overrides and path is True: raise ValueError("The 
argument 'data_dir' has been renamed to 'path'") - path = overrides.get('path', True) if isinstance(path, basestring): path = pathlib.Path(path) if path is True: @@ -253,7 +252,7 @@ class Language(object): add_vectors = self.Defaults.add_vectors(self) \ if 'add_vectors' not in overrides \ else overrides['add_vectors'] - if add_vectors: + if self.vocab and add_vectors: add_vectors(self.vocab) self.tokenizer = self.Defaults.create_tokenizer(self) \ if 'tokenizer' not in overrides \