Tidy up and add types

This commit is contained in:
Ines Montani 2020-09-23 10:14:34 +02:00
parent 6ca06cb62c
commit ae5dacf75f
1 changed files with 24 additions and 30 deletions

View File

@ -1,9 +1,11 @@
from typing import List, Tuple, Iterable, Union, Iterator
import warnings import warnings
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..tokens import Span from ..tokens import Span, Doc
def iob_to_biluo(tags): def iob_to_biluo(tags: Iterable[str]) -> List[str]:
out = [] out = []
tags = list(tags) tags = list(tags)
while tags: while tags:
@ -12,7 +14,7 @@ def iob_to_biluo(tags):
return out return out
def biluo_to_iob(tags): def biluo_to_iob(tags: Iterable[str]) -> List[str]:
out = [] out = []
for tag in tags: for tag in tags:
if tag is None: if tag is None:
@ -23,12 +25,12 @@ def biluo_to_iob(tags):
return out return out
def _consume_os(tags): def _consume_os(tags: List[str]) -> Iterator[str]:
while tags and tags[0] == "O": while tags and tags[0] == "O":
yield tags.pop(0) yield tags.pop(0)
def _consume_ent(tags): def _consume_ent(tags: List[str]) -> List[str]:
if not tags: if not tags:
return [] return []
tag = tags.pop(0) tag = tags.pop(0)
@ -50,11 +52,7 @@ def _consume_ent(tags):
return [start] + middle + [end] return [start] + middle + [end]
def biluo_tags_from_doc(doc, missing="O"): def doc_to_biluo_tags(doc: Doc, missing: str = "O"):
return doc_to_biluo_tags(doc, missing)
def doc_to_biluo_tags(doc, missing="O"):
return offsets_to_biluo_tags( return offsets_to_biluo_tags(
doc, doc,
[(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents], [(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents],
@ -62,11 +60,9 @@ def doc_to_biluo_tags(doc, missing="O"):
) )
def biluo_tags_from_offsets(doc, entities, missing="O"): def offsets_to_biluo_tags(
return offsets_to_biluo_tags(doc, entities, missing) doc: Doc, entities: Iterable[Tuple[int, int, Union[str, int]]], missing: str = "O"
) -> List[str]:
def offsets_to_biluo_tags(doc, entities, missing="O"):
"""Encode labelled spans into per-token tags, using the """Encode labelled spans into per-token tags, using the
Begin/In/Last/Unit/Out scheme (BILUO). Begin/In/Last/Unit/Out scheme (BILUO).
@ -77,7 +73,7 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
the original string. the original string.
RETURNS (list): A list of unicode strings, describing the tags. Each tag RETURNS (list): A list of unicode strings, describing the tags. Each tag
string will be of the form either "", "O" or "{action}-{label}", where string will be of the form either "", "O" or "{action}-{label}", where
action is one of "B", "I", "L", "U". The string "-" is used where the action is one of "B", "I", "L", "U". The missing label is used where the
entity offsets don't align with the tokenization in the `Doc` object. entity offsets don't align with the tokenization in the `Doc` object.
The training algorithm will view these as missing values. "O" denotes a The training algorithm will view these as missing values. "O" denotes a
non-entity token. "B" denotes the beginning of a multi-token entity, non-entity token. "B" denotes the beginning of a multi-token entity,
@ -93,7 +89,6 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
""" """
# Ensure no overlapping entity labels exist # Ensure no overlapping entity labels exist
tokens_in_ents = {} tokens_in_ents = {}
starts = {token.idx: token.i for token in doc} starts = {token.idx: token.i for token in doc}
ends = {token.idx + len(token): token.i for token in doc} ends = {token.idx + len(token): token.i for token in doc}
biluo = ["-" for _ in doc] biluo = ["-" for _ in doc]
@ -117,7 +112,6 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
) )
) )
tokens_in_ents[token_index] = (start_char, end_char, label) tokens_in_ents[token_index] = (start_char, end_char, label)
start_token = starts.get(start_char) start_token = starts.get(start_char)
end_token = ends.get(end_char) end_token = ends.get(end_char)
# Only interested if the tokenization is correct # Only interested if the tokenization is correct
@ -151,11 +145,7 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
return biluo return biluo
def spans_from_biluo_tags(doc, tags): def biluo_tags_to_spans(doc: Doc, tags: Iterable[str]) -> List[Span]:
return biluo_tags_to_spans(doc, tags)
def biluo_tags_to_spans(doc, tags):
"""Encode per-token tags following the BILUO scheme into Span object, e.g. """Encode per-token tags following the BILUO scheme into Span object, e.g.
to overwrite the doc.ents. to overwrite the doc.ents.
@ -173,11 +163,9 @@ def biluo_tags_to_spans(doc, tags):
return spans return spans
def offsets_from_biluo_tags(doc, tags): def biluo_tags_to_offsets(
return biluo_tags_to_offsets(doc, tags) doc: Doc, tags: Iterable[str]
) -> List[Tuple[int, int, Union[str, int]]]:
def biluo_tags_to_offsets(doc, tags):
"""Encode per-token tags following the BILUO scheme into entity offsets. """Encode per-token tags following the BILUO scheme into entity offsets.
doc (Doc): The document that the BILUO tags refer to. doc (Doc): The document that the BILUO tags refer to.
@ -192,8 +180,8 @@ def biluo_tags_to_offsets(doc, tags):
return [(span.start_char, span.end_char, span.label_) for span in spans] return [(span.start_char, span.end_char, span.label_) for span in spans]
def tags_to_entities(tags): def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
""" Note that the end index returned by this function is inclusive. """Note that the end index returned by this function is inclusive.
To use it for Span creation, increment the end by 1.""" To use it for Span creation, increment the end by 1."""
entities = [] entities = []
start = None start = None
@ -225,3 +213,9 @@ def tags_to_entities(tags):
else: else:
raise ValueError(Errors.E068.format(tag=tag)) raise ValueError(Errors.E068.format(tag=tag))
return entities return entities
# Fallbacks to make backwards-compat easier: this commit renames the
# public helpers (e.g. offsets_from_biluo_tags -> biluo_tags_to_offsets),
# so the pre-rename names are kept as plain aliases of the new functions
# to avoid breaking existing imports and call sites.
offsets_from_biluo_tags = biluo_tags_to_offsets
spans_from_biluo_tags = biluo_tags_to_spans
biluo_tags_from_offsets = offsets_to_biluo_tags