mirror of https://github.com/explosion/spaCy.git
Allow more components to use labels
This commit is contained in:
parent
99bff78617
commit
1fd002180e
|
@ -160,16 +160,12 @@ class TextCategorizer(Pipe):
|
|||
self.cfg["labels"] = tuple(value)
|
||||
|
||||
@property
|
||||
def label_data(self) -> Dict:
|
||||
"""RETURNS (Dict): Information about the component's labels.
|
||||
def label_data(self) -> List[str]:
|
||||
"""RETURNS (List[str]): Information about the component's labels.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/textcategorizer#labels
|
||||
"""
|
||||
return {
|
||||
"labels": self.labels,
|
||||
"positive": self.cfg["positive_label"],
|
||||
"threshold": self.cfg["threshold"]
|
||||
}
|
||||
return self.labels
|
||||
|
||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
|
||||
"""Apply the pipe to a stream of documents. This usually happens under
|
||||
|
@ -354,6 +350,7 @@ class TextCategorizer(Pipe):
|
|||
get_examples: Callable[[], Iterable[Example]],
|
||||
*,
|
||||
nlp: Optional[Language] = None,
|
||||
labels: Optional[Dict] = None
|
||||
):
|
||||
"""Initialize the pipe for training, using a representative set
|
||||
of data examples.
|
||||
|
@ -365,12 +362,14 @@ class TextCategorizer(Pipe):
|
|||
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
|
||||
"""
|
||||
self._ensure_examples(get_examples)
|
||||
subbatch = [] # Select a subbatch of examples to initialize the model
|
||||
for example in islice(get_examples(), 10):
|
||||
if len(subbatch) < 2:
|
||||
subbatch.append(example)
|
||||
for cat in example.y.cats:
|
||||
self.add_label(cat)
|
||||
if labels is None:
|
||||
for example in get_examples():
|
||||
for cat in example.y.cats:
|
||||
self.add_label(cat)
|
||||
else:
|
||||
for label in labels:
|
||||
self.add_label(label)
|
||||
subbatch = list(islice(get_examples(), 10))
|
||||
doc_sample = [eg.reference for eg in subbatch]
|
||||
label_sample, _ = self._examples_to_truth(subbatch)
|
||||
self._require_labels()
|
||||
|
|
|
@ -409,17 +409,20 @@ cdef class Parser(Pipe):
|
|||
def set_output(self, nO):
|
||||
self.model.attrs["resize_output"](self.model, nO)
|
||||
|
||||
def initialize(self, get_examples, nlp=None):
|
||||
def initialize(self, get_examples, *, nlp=None, labels=None):
|
||||
self._ensure_examples(get_examples)
|
||||
lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
|
||||
if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
|
||||
langs = ", ".join(util.LEXEME_NORM_LANGS)
|
||||
util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs))
|
||||
actions = self.moves.get_actions(
|
||||
examples=get_examples(),
|
||||
min_freq=self.cfg['min_action_freq'],
|
||||
learn_tokens=self.cfg["learn_tokens"]
|
||||
)
|
||||
if labels is not None:
|
||||
actions = dict(labels)
|
||||
else:
|
||||
actions = self.moves.get_actions(
|
||||
examples=get_examples(),
|
||||
min_freq=self.cfg['min_action_freq'],
|
||||
learn_tokens=self.cfg["learn_tokens"]
|
||||
)
|
||||
for action, labels in self.moves.labels.items():
|
||||
actions.setdefault(action, {})
|
||||
for label, freq in labels.items():
|
||||
|
|
Loading…
Reference in New Issue