Tidy up and add types

This commit is contained in:
Ines Montani 2020-09-23 10:14:34 +02:00
parent 6ca06cb62c
commit ae5dacf75f
1 changed files with 24 additions and 30 deletions

View File

@ -1,9 +1,11 @@
from typing import List, Tuple, Iterable, Union, Iterator
import warnings
from ..errors import Errors, Warnings
from ..tokens import Span
from ..tokens import Span, Doc
def iob_to_biluo(tags):
def iob_to_biluo(tags: Iterable[str]) -> List[str]:
out = []
tags = list(tags)
while tags:
@ -12,7 +14,7 @@ def iob_to_biluo(tags):
return out
def biluo_to_iob(tags):
def biluo_to_iob(tags: Iterable[str]) -> List[str]:
out = []
for tag in tags:
if tag is None:
@ -23,12 +25,12 @@ def biluo_to_iob(tags):
return out
def _consume_os(tags):
def _consume_os(tags: List[str]) -> Iterator[str]:
while tags and tags[0] == "O":
yield tags.pop(0)
def _consume_ent(tags):
def _consume_ent(tags: List[str]) -> List[str]:
if not tags:
return []
tag = tags.pop(0)
@ -50,11 +52,7 @@ def _consume_ent(tags):
return [start] + middle + [end]
def biluo_tags_from_doc(doc, missing="O"):
return doc_to_biluo_tags(doc, missing)
def doc_to_biluo_tags(doc, missing="O"):
def doc_to_biluo_tags(doc: Doc, missing: str = "O"):
return offsets_to_biluo_tags(
doc,
[(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents],
@ -62,11 +60,9 @@ def doc_to_biluo_tags(doc, missing="O"):
)
def biluo_tags_from_offsets(doc, entities, missing="O"):
return offsets_to_biluo_tags(doc, entities, missing)
def offsets_to_biluo_tags(doc, entities, missing="O"):
def offsets_to_biluo_tags(
doc: Doc, entities: Iterable[Tuple[int, int, Union[str, int]]], missing: str = "O"
) -> List[str]:
"""Encode labelled spans into per-token tags, using the
Begin/In/Last/Unit/Out scheme (BILUO).
@ -77,7 +73,7 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
the original string.
RETURNS (list): A list of unicode strings, describing the tags. Each tag
string will be of the form either "", "O" or "{action}-{label}", where
action is one of "B", "I", "L", "U". The string "-" is used where the
action is one of "B", "I", "L", "U". The missing label is used where the
entity offsets don't align with the tokenization in the `Doc` object.
The training algorithm will view these as missing values. "O" denotes a
non-entity token. "B" denotes the beginning of a multi-token entity,
@ -93,7 +89,6 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
"""
# Ensure no overlapping entity labels exist
tokens_in_ents = {}
starts = {token.idx: token.i for token in doc}
ends = {token.idx + len(token): token.i for token in doc}
biluo = ["-" for _ in doc]
@ -117,7 +112,6 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
)
)
tokens_in_ents[token_index] = (start_char, end_char, label)
start_token = starts.get(start_char)
end_token = ends.get(end_char)
# Only interested if the tokenization is correct
@ -151,11 +145,7 @@ def offsets_to_biluo_tags(doc, entities, missing="O"):
return biluo
def spans_from_biluo_tags(doc, tags):
return biluo_tags_to_spans(doc, tags)
def biluo_tags_to_spans(doc, tags):
def biluo_tags_to_spans(doc: Doc, tags: Iterable[str]) -> List[Span]:
"""Encode per-token tags following the BILUO scheme into Span object, e.g.
to overwrite the doc.ents.
@ -173,11 +163,9 @@ def biluo_tags_to_spans(doc, tags):
return spans
def offsets_from_biluo_tags(doc, tags):
return biluo_tags_to_offsets(doc, tags)
def biluo_tags_to_offsets(doc, tags):
def biluo_tags_to_offsets(
doc: Doc, tags: Iterable[str]
) -> List[Tuple[int, int, Union[str, int]]]:
"""Encode per-token tags following the BILUO scheme into entity offsets.
doc (Doc): The document that the BILUO tags refer to.
@ -192,8 +180,8 @@ def biluo_tags_to_offsets(doc, tags):
return [(span.start_char, span.end_char, span.label_) for span in spans]
def tags_to_entities(tags):
""" Note that the end index returned by this function is inclusive.
def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
"""Note that the end index returned by this function is inclusive.
To use it for Span creation, increment the end by 1."""
entities = []
start = None
@ -225,3 +213,9 @@ def tags_to_entities(tags):
else:
raise ValueError(Errors.E068.format(tag=tag))
return entities
# Fallbacks to make backwards-compat easier
offsets_from_biluo_tags = biluo_tags_to_offsets
spans_from_biluo_tags = biluo_tags_to_spans
biluo_tags_from_offsets = offsets_to_biluo_tags