From 1fd002180e98d830da26f4593ce6bc7a838e2131 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Tue, 29 Sep 2020 16:48:56 +0200
Subject: [PATCH] Allow more components to use labels

---
 spacy/pipeline/textcat.py            | 25 ++++++++++++-------------
 spacy/pipeline/transition_parser.pyx | 15 +++++++++------
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index 63b040333..d6dafa3f5 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -160,16 +160,12 @@ class TextCategorizer(Pipe):
         self.cfg["labels"] = tuple(value)
 
     @property
-    def label_data(self) -> Dict:
-        """RETURNS (Dict): Information about the component's labels.
+    def label_data(self) -> List[str]:
+        """RETURNS (List[str]): Information about the component's labels.
 
         DOCS: https://nightly.spacy.io/api/textcategorizer#labels
         """
-        return {
-            "labels": self.labels,
-            "positive": self.cfg["positive_label"],
-            "threshold": self.cfg["threshold"]
-        }
+        return self.labels
 
     def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
         """Apply the pipe to a stream of documents. This usually happens under
@@ -354,6 +350,7 @@ class TextCategorizer(Pipe):
         get_examples: Callable[[], Iterable[Example]],
         *,
         nlp: Optional[Language] = None,
+        labels: Optional[Dict] = None
     ):
         """Initialize the pipe for training, using a representative set
         of data examples.
@@ -365,12 +362,14 @@ class TextCategorizer(Pipe):
         DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
         """
         self._ensure_examples(get_examples)
-        subbatch = []  # Select a subbatch of examples to initialize the model
-        for example in islice(get_examples(), 10):
-            if len(subbatch) < 2:
-                subbatch.append(example)
-            for cat in example.y.cats:
-                self.add_label(cat)
+        if labels is None:
+            for example in get_examples():
+                for cat in example.y.cats:
+                    self.add_label(cat)
+        else:
+            for label in labels:
+                self.add_label(label)
+        subbatch = list(islice(get_examples(), 10))
         doc_sample = [eg.reference for eg in subbatch]
         label_sample, _ = self._examples_to_truth(subbatch)
         self._require_labels()
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 9f165cb15..11e0e5af8 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -409,17 +409,20 @@ cdef class Parser(Pipe):
     def set_output(self, nO):
         self.model.attrs["resize_output"](self.model, nO)
 
-    def initialize(self, get_examples, nlp=None):
+    def initialize(self, get_examples, *, nlp=None, labels=None):
         self._ensure_examples(get_examples)
         lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
         if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
             langs = ", ".join(util.LEXEME_NORM_LANGS)
             util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs))
-        actions = self.moves.get_actions(
-            examples=get_examples(),
-            min_freq=self.cfg['min_action_freq'],
-            learn_tokens=self.cfg["learn_tokens"]
-        )
+        if labels is not None:
+            actions = dict(labels)
+        else:
+            actions = self.moves.get_actions(
+                examples=get_examples(),
+                min_freq=self.cfg['min_action_freq'],
+                learn_tokens=self.cfg["learn_tokens"]
+            )
         for action, labels in self.moves.labels.items():
             actions.setdefault(action, {})
             for label, freq in labels.items():
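
Usage note (not part of the patch itself): a rough sketch of how the new `labels` argument could be passed when initializing a component, assuming a spaCy v3 nightly install. `make_examples` is a hypothetical helper that yields spacy.training.Example objects and is not defined in this patch.

    # Sketch: hand the textcat its label set up front via initialize(labels=...)
    # instead of having it scan every training example for categories.
    import spacy

    nlp = spacy.blank("en")
    textcat = nlp.add_pipe("textcat")
    textcat.initialize(
        lambda: make_examples(nlp),   # examples are still required for model init
        nlp=nlp,
        labels=["POSITIVE", "NEGATIVE"],
    )

    # Since label_data now returns just the label strings, the same labels can
    # be copied from an already trained pipeline:
    # labels = trained_nlp.get_pipe("textcat").label_data

For the parser and NER, the diff suggests `labels` is instead treated as the actions dict (it is passed through `dict(labels)` and merged with `self.moves.labels`), so a flat list of strings would not be appropriate there.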