diff --git a/spacy/training/iob_utils.py b/spacy/training/iob_utils.py index 03a502912..91fc40205 100644 --- a/spacy/training/iob_utils.py +++ b/spacy/training/iob_utils.py @@ -1,9 +1,11 @@ +from typing import List, Tuple, Iterable, Union, Iterator import warnings + from ..errors import Errors, Warnings -from ..tokens import Span +from ..tokens import Span, Doc -def iob_to_biluo(tags): +def iob_to_biluo(tags: Iterable[str]) -> List[str]: out = [] tags = list(tags) while tags: @@ -12,7 +14,7 @@ def iob_to_biluo(tags): return out -def biluo_to_iob(tags): +def biluo_to_iob(tags: Iterable[str]) -> List[str]: out = [] for tag in tags: if tag is None: @@ -23,12 +25,12 @@ def biluo_to_iob(tags): return out -def _consume_os(tags): +def _consume_os(tags: List[str]) -> Iterator[str]: while tags and tags[0] == "O": yield tags.pop(0) -def _consume_ent(tags): +def _consume_ent(tags: List[str]) -> List[str]: if not tags: return [] tag = tags.pop(0) @@ -50,11 +52,7 @@ def _consume_ent(tags): return [start] + middle + [end] -def biluo_tags_from_doc(doc, missing="O"): - return doc_to_biluo_tags(doc, missing) - - -def doc_to_biluo_tags(doc, missing="O"): +def doc_to_biluo_tags(doc: Doc, missing: str = "O"): return offsets_to_biluo_tags( doc, [(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents], @@ -62,11 +60,9 @@ def doc_to_biluo_tags(doc, missing="O"): ) -def biluo_tags_from_offsets(doc, entities, missing="O"): - return offsets_to_biluo_tags(doc, entities, missing) - - -def offsets_to_biluo_tags(doc, entities, missing="O"): +def offsets_to_biluo_tags( + doc: Doc, entities: Iterable[Tuple[int, int, Union[str, int]]], missing: str = "O" +) -> List[str]: """Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out scheme (BILUO). @@ -77,7 +73,7 @@ def offsets_to_biluo_tags(doc, entities, missing="O"): the original string. RETURNS (list): A list of unicode strings, describing the tags. Each tag string will be of the form either "", "O" or "{action}-{label}", where - action is one of "B", "I", "L", "U". The string "-" is used where the + action is one of "B", "I", "L", "U". The missing label is used where the entity offsets don't align with the tokenization in the `Doc` object. The training algorithm will view these as missing values. "O" denotes a non-entity token. "B" denotes the beginning of a multi-token entity, @@ -93,7 +89,6 @@ def offsets_to_biluo_tags(doc, entities, missing="O"): """ # Ensure no overlapping entity labels exist tokens_in_ents = {} - starts = {token.idx: token.i for token in doc} ends = {token.idx + len(token): token.i for token in doc} biluo = ["-" for _ in doc] @@ -117,7 +112,6 @@ def offsets_to_biluo_tags(doc, entities, missing="O"): ) ) tokens_in_ents[token_index] = (start_char, end_char, label) - start_token = starts.get(start_char) end_token = ends.get(end_char) # Only interested if the tokenization is correct @@ -151,11 +145,7 @@ def offsets_to_biluo_tags(doc, entities, missing="O"): return biluo -def spans_from_biluo_tags(doc, tags): - return biluo_tags_to_spans(doc, tags) - - -def biluo_tags_to_spans(doc, tags): +def biluo_tags_to_spans(doc: Doc, tags: Iterable[str]) -> List[Span]: """Encode per-token tags following the BILUO scheme into Span object, e.g. to overwrite the doc.ents. @@ -173,11 +163,9 @@ def biluo_tags_to_spans(doc, tags): return spans -def offsets_from_biluo_tags(doc, tags): - return biluo_tags_to_offsets(doc, tags) - - -def biluo_tags_to_offsets(doc, tags): +def biluo_tags_to_offsets( + doc: Doc, tags: Iterable[str] +) -> List[Tuple[int, int, Union[str, int]]]: """Encode per-token tags following the BILUO scheme into entity offsets. doc (Doc): The document that the BILUO tags refer to. @@ -192,8 +180,8 @@ def biluo_tags_to_offsets(doc, tags): return [(span.start_char, span.end_char, span.label_) for span in spans] -def tags_to_entities(tags): - """ Note that the end index returned by this function is inclusive. +def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]: + """Note that the end index returned by this function is inclusive. To use it for Span creation, increment the end by 1.""" entities = [] start = None @@ -225,3 +213,9 @@ def tags_to_entities(tags): else: raise ValueError(Errors.E068.format(tag=tag)) return entities + + +# Fallbacks to make backwards-compat easier +offsets_from_biluo_tags = biluo_tags_to_offsets +spans_from_biluo_tags = biluo_tags_to_spans +biluo_tags_from_offsets = offsets_to_biluo_tags