Allow more components to use labels

Matthew Honnibal 2020-09-29 16:48:56 +02:00
parent 99bff78617
commit 1fd002180e
2 changed files with 21 additions and 19 deletions


@@ -160,16 +160,12 @@ class TextCategorizer(Pipe):
         self.cfg["labels"] = tuple(value)
 
     @property
-    def label_data(self) -> Dict:
-        """RETURNS (Dict): Information about the component's labels.
+    def label_data(self) -> List[str]:
+        """RETURNS (List[str]): Information about the component's labels.
 
         DOCS: https://nightly.spacy.io/api/textcategorizer#labels
         """
-        return {
-            "labels": self.labels,
-            "positive": self.cfg["positive_label"],
-            "threshold": self.cfg["threshold"]
-        }
+        return self.labels
 
     def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
         """Apply the pipe to a stream of documents. This usually happens under
@@ -354,6 +350,7 @@ class TextCategorizer(Pipe):
         get_examples: Callable[[], Iterable[Example]],
         *,
         nlp: Optional[Language] = None,
+        labels: Optional[Dict] = None
     ):
         """Initialize the pipe for training, using a representative set
         of data examples.
@@ -365,12 +362,14 @@ class TextCategorizer(Pipe):
         DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
         """
         self._ensure_examples(get_examples)
-        subbatch = []  # Select a subbatch of examples to initialize the model
-        for example in islice(get_examples(), 10):
-            if len(subbatch) < 2:
-                subbatch.append(example)
-            for cat in example.y.cats:
-                self.add_label(cat)
+        if labels is None:
+            for example in get_examples():
+                for cat in example.y.cats:
+                    self.add_label(cat)
+        else:
+            for label in labels:
+                self.add_label(label)
+        subbatch = list(islice(get_examples(), 10))
         doc_sample = [eg.reference for eg in subbatch]
         label_sample, _ = self._examples_to_truth(subbatch)
         self._require_labels()
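
A hedged usage sketch for the new labels argument to TextCategorizer.initialize; the train_examples name is an assumption, and note that the code only iterates the argument, so an iterable of label names works even though the annotation says Dict:

    # Hedged sketch: with labels passed explicitly, initialize() does not have
    # to infer the label set from example.y.cats, which helps when the label
    # inventory is known up front.
    textcat.initialize(
        lambda: train_examples,           # assumed list of Example objects
        nlp=nlp,
        labels=["POSITIVE", "NEGATIVE"],  # each name is passed to add_label()
    )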


@@ -409,17 +409,20 @@ cdef class Parser(Pipe):
     def set_output(self, nO):
         self.model.attrs["resize_output"](self.model, nO)
 
-    def initialize(self, get_examples, nlp=None):
+    def initialize(self, get_examples, *, nlp=None, labels=None):
         self._ensure_examples(get_examples)
         lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
         if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
             langs = ", ".join(util.LEXEME_NORM_LANGS)
             util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs))
-        actions = self.moves.get_actions(
-            examples=get_examples(),
-            min_freq=self.cfg['min_action_freq'],
-            learn_tokens=self.cfg["learn_tokens"]
-        )
+        if labels is not None:
+            actions = dict(labels)
+        else:
+            actions = self.moves.get_actions(
+                examples=get_examples(),
+                min_freq=self.cfg['min_action_freq'],
+                learn_tokens=self.cfg["learn_tokens"]
+            )
         for action, labels in self.moves.labels.items():
             actions.setdefault(action, {})
             for label, freq in labels.items():
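
A similar hedged sketch for the parser/NER side; source_nlp, nlp and train_examples are assumed names, and the labels value mirrors the action-to-label mapping the loop above iterates:

    # Hedged sketch: the labels argument lets a parser/NER component reuse a
    # precomputed action/label mapping instead of rebuilding it with
    # self.moves.get_actions() over the examples.
    source_ner = source_nlp.get_pipe("ner")   # assumed existing, trained pipe
    actions = dict(source_ner.moves.labels)   # action -> {label: freq}
    ner = nlp.add_pipe("ner")
    ner.initialize(lambda: train_examples, nlp=nlp, labels=actions)
    # Internally this takes the new `actions = dict(labels)` branch and skips
    # deriving actions from the examples.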