diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi
index 00c7a9d07..55222f8aa 100644
--- a/spacy/tokens/doc.pyi
+++ b/spacy/tokens/doc.pyi
@@ -8,6 +8,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Sequence,
     Tuple,
     Union,
     overload,
@@ -134,7 +135,12 @@ class Doc:
     def text(self) -> str: ...
     @property
     def text_with_ws(self) -> str: ...
-    ents: Tuple[Span]
+    # Ideally the getter would output Tuple[Span]
+    # see https://github.com/python/mypy/issues/3004
+    @property
+    def ents(self) -> Sequence[Span]: ...
+    @ents.setter
+    def ents(self, value: Sequence[Span]) -> None: ...
     def set_ents(
         self,
         entities: List[Span],
diff --git a/spacy/training/corpus.py b/spacy/training/corpus.py
index 6037c15e3..37af9e476 100644
--- a/spacy/training/corpus.py
+++ b/spacy/training/corpus.py
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional,
 import srsly
 
 from .. import util
+from ..compat import Protocol
 from ..errors import Errors, Warnings
 from ..tokens import Doc, DocBin
 from ..vocab import Vocab
@@ -19,6 +20,11 @@ if TYPE_CHECKING:
 FILE_TYPE = ".spacy"
 
 
+class ReaderProtocol(Protocol):
+    def __call__(self, nlp: "Language") -> Iterable[Example]:
+        pass
+
+
 @util.registry.readers("spacy.Corpus.v1")
 def create_docbin_reader(
     path: Optional[Path],
@@ -26,7 +32,7 @@ def create_docbin_reader(
     max_length: int = 0,
     limit: int = 0,
     augmenter: Optional[Callable] = None,
-) -> Callable[["Language"], Iterable[Example]]:
+) -> ReaderProtocol:
     if path is None:
         raise ValueError(Errors.E913)
     util.logger.debug("Loading corpus from path: %s", path)
@@ -45,7 +51,7 @@ def create_jsonl_reader(
     min_length: int = 0,
     max_length: int = 0,
     limit: int = 0,
-) -> Callable[["Language"], Iterable[Example]]:
+) -> ReaderProtocol:
     return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit)
 
 
@@ -63,7 +69,7 @@ def create_plain_text_reader(
     path: Optional[Path],
     min_length: int = 0,
     max_length: int = 0,
-) -> Callable[["Language"], Iterable[Doc]]:
+) -> ReaderProtocol:
     """Iterate Example objects from a file or directory of plain text
     UTF-8 files with one line per doc.
 
@@ -144,7 +150,7 @@ class Corpus:
         self.augmenter = augmenter if augmenter is not None else dont_augment
         self.shuffle = shuffle
 
-    def __call__(self, nlp: "Language") -> Iterator[Example]:
+    def __call__(self, nlp: "Language") -> Iterable[Example]:
         """Yield examples from the data.
 
         nlp (Language): The current nlp object.
@@ -182,7 +188,7 @@ class Corpus:
 
     def make_examples(
         self, nlp: "Language", reference_docs: Iterable[Doc]
-    ) -> Iterator[Example]:
+    ) -> Iterable[Example]:
         for reference in reference_docs:
             if len(reference) == 0:
                 continue
@@ -197,7 +203,7 @@ class Corpus:
 
     def make_examples_gold_preproc(
        self, nlp: "Language", reference_docs: Iterable[Doc]
-    ) -> Iterator[Example]:
+    ) -> Iterable[Example]:
         for reference in reference_docs:
             if reference.has_annotation("SENT_START"):
                 ref_sents = [sent.as_doc() for sent in reference.sents]
@@ -210,7 +216,7 @@ class Corpus:
 
     def read_docbin(
         self, vocab: Vocab, locs: Iterable[Union[str, Path]]
-    ) -> Iterator[Doc]:
+    ) -> Iterable[Doc]:
         """Yield training examples as example dicts"""
         i = 0
         for loc in locs:
@@ -257,7 +263,7 @@ class JsonlCorpus:
         self.max_length = max_length
         self.limit = limit
 
-    def __call__(self, nlp: "Language") -> Iterator[Example]:
+    def __call__(self, nlp: "Language") -> Iterable[Example]:
         """Yield examples from the data.
 
         nlp (Language): The current nlp object.
@@ -307,7 +313,7 @@ class PlainTextCorpus:
         self.min_length = min_length
         self.max_length = max_length
 
-    def __call__(self, nlp: "Language") -> Iterator[Example]:
+    def __call__(self, nlp: "Language") -> Iterable[Example]:
         """Yield examples from the data.
 
         nlp (Language): The current nlp object.
diff --git a/spacy/training/example.pyi b/spacy/training/example.pyi
new file mode 100644
index 000000000..9cd563465
--- /dev/null
+++ b/spacy/training/example.pyi
@@ -0,0 +1,59 @@
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
+
+from ..tokens import Doc, Span
+from ..vocab import Vocab
+from .alignment import Alignment
+
+def annotations_to_doc(
+    vocab: Vocab,
+    tok_annot: Dict[str, Any],
+    doc_annot: Dict[str, Any],
+) -> Doc: ...
+def validate_examples(
+    examples: Iterable[Example],
+    method: str,
+) -> None: ...
+def validate_get_examples(
+    get_examples: Callable[[], Iterable[Example]],
+    method: str,
+): ...
+
+class Example:
+    x: Doc
+    y: Doc
+
+    def __init__(
+        self,
+        predicted: Doc,
+        reference: Doc,
+        *,
+        alignment: Optional[Alignment] = None,
+    ): ...
+    def __len__(self) -> int: ...
+    @property
+    def predicted(self) -> Doc: ...
+    @predicted.setter
+    def predicted(self, doc: Doc) -> None: ...
+    @property
+    def reference(self) -> Doc: ...
+    @reference.setter
+    def reference(self, doc: Doc) -> None: ...
+    def copy(self) -> Example: ...
+    @classmethod
+    def from_dict(cls, predicted: Doc, example_dict: Dict[str, Any]) -> Example: ...
+    @property
+    def alignment(self) -> Alignment: ...
+    def get_aligned(self, field: str, as_string=False): ...
+    def get_aligned_parse(self, projectivize=True): ...
+    def get_aligned_sent_starts(self): ...
+    def get_aligned_spans_x2y(self, x_spans: Sequence[Span], allow_overlap=False) -> List[Span]: ...
+    def get_aligned_spans_y2x(self, y_spans: Sequence[Span], allow_overlap=False) -> List[Span]: ...
+    def get_aligned_ents_and_ner(self) -> Tuple[List[Span], List[str]]: ...
+    def get_aligned_ner(self) -> List[str]: ...
+    def get_matching_ents(self, check_label: bool = True) -> List[Span]: ...
+    def to_dict(self) -> Dict[str, Any]: ...
+    def split_sents(self) -> List[Example]: ...
+    @property
+    def text(self) -> str: ...
+    def __str__(self) -> str: ...
+    def __repr__(self) -> str: ...
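
Note on the Doc.ents change in spacy/tokens/doc.pyi: mypy cannot give a property a narrower getter type than its setter (see the linked issue, python/mypy#3004), which is why both the getter and the setter are declared as Sequence[Span] rather than Tuple[Span] for reads. A minimal sketch of what the new stub accepts, assuming a Doc that already has entity spans; the function name is illustrative and not part of the patch:

    from typing import Sequence

    from spacy.tokens import Doc, Span

    def normalize_ents(doc: Doc) -> None:
        # Reading: the getter is typed as Sequence[Span].
        ents: Sequence[Span] = doc.ents
        # Writing: the setter accepts any Sequence[Span], so both a
        # list and a tuple of spans now type-check.
        doc.ents = list(ents)
        doc.ents = tuple(ents)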
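The ReaderProtocol added to spacy/training/corpus.py types registered readers structurally: anything callable with an nlp argument that returns Iterable[Example] conforms, which covers plain functions as well as Corpus, JsonlCorpus, and PlainTextCorpus through their __call__ methods. A hedged sketch of a conforming reader; my_reader is hypothetical and not part of the patch:

    from typing import Iterable

    from spacy.language import Language
    from spacy.training import Example
    from spacy.training.corpus import ReaderProtocol

    def my_reader(nlp: Language) -> Iterable[Example]:
        # A trivial reader; a real one would load data and yield
        # Example objects built from it.
        return []

    # Structural typing: my_reader satisfies ReaderProtocol without
    # inheriting from it, so mypy accepts this assignment.
    reader: ReaderProtocol = my_reader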
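With the new spacy/training/example.pyi stub, calls into Example are type-checked instead of falling back to Any (the class itself is implemented in Cython). A small usage sketch, assuming the standard Example.from_dict API with per-token BILUO entity tags; the sentence and labels are illustrative:

    from spacy.lang.en import English
    from spacy.training import Example

    nlp = English()
    doc = nlp.make_doc("London is big")
    # from_dict is typed as (Doc, Dict[str, Any]) -> Example, and
    # to_dict as () -> Dict[str, Any], so mypy checks both calls.
    example = Example.from_dict(doc, {"entities": ["U-LOC", "O", "O"]})
    annotations = example.to_dict()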