diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 4cdaf3d83..052bd2874 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -398,7 +398,9 @@ class SpanCategorizer(TrainablePipe): pass def _get_aligned_spans(self, eg: Example): - return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []), allow_overlap=True) + return eg.get_aligned_spans_y2x( + eg.reference.spans.get(self.key, []), allow_overlap=True + ) def _make_span_group( self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str] diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 3da5816ab..7b759f8f6 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -85,7 +85,12 @@ def test_doc_gc(): spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) spancat.add_label("PERSON") nlp.initialize() - texts = ["Just a sentence.", "I like London and Berlin", "I like Berlin", "I eat ham."] + texts = [ + "Just a sentence.", + "I like London and Berlin", + "I like Berlin", + "I eat ham.", + ] all_spans = [doc.spans for doc in nlp.pipe(texts)] for text, spangroups in zip(texts, all_spans): assert isinstance(spangroups, SpanGroups) @@ -338,7 +343,11 @@ def test_overfitting_IO_overlapping(): assert len(spans) == 3 assert len(spans.attrs["scores"]) == 3 assert min(spans.attrs["scores"]) > 0.9 - assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"} + assert set([span.text for span in spans]) == { + "London", + "Berlin", + "London and Berlin", + } assert set([span.label_ for span in spans]) == {"LOC", "DOUBLE_LOC"} # Also test the results are still the same after IO @@ -350,5 +359,9 @@ def test_overfitting_IO_overlapping(): assert len(spans2) == 3 assert len(spans2.attrs["scores"]) == 3 assert min(spans2.attrs["scores"]) > 0.9 - assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"} + assert set([span.text for span in spans2]) == { + "London", + "Berlin", + "London and Berlin", + } assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}