Auto-format code with black (#9065)

Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>
This commit is contained in:
github-actions[bot] 2021-08-27 11:42:27 +02:00 committed by GitHub
parent 4d39430b82
commit fb9c31fbda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 4 deletions

View File

@ -398,7 +398,9 @@ class SpanCategorizer(TrainablePipe):
pass
def _get_aligned_spans(self, eg: Example):
return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []), allow_overlap=True)
return eg.get_aligned_spans_y2x(
eg.reference.spans.get(self.key, []), allow_overlap=True
)
def _make_span_group(
self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]

View File

@ -85,7 +85,12 @@ def test_doc_gc():
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
spancat.add_label("PERSON")
nlp.initialize()
texts = ["Just a sentence.", "I like London and Berlin", "I like Berlin", "I eat ham."]
texts = [
"Just a sentence.",
"I like London and Berlin",
"I like Berlin",
"I eat ham.",
]
all_spans = [doc.spans for doc in nlp.pipe(texts)]
for text, spangroups in zip(texts, all_spans):
assert isinstance(spangroups, SpanGroups)
@ -338,7 +343,11 @@ def test_overfitting_IO_overlapping():
assert len(spans) == 3
assert len(spans.attrs["scores"]) == 3
assert min(spans.attrs["scores"]) > 0.9
assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
assert set([span.text for span in spans]) == {
"London",
"Berlin",
"London and Berlin",
}
assert set([span.label_ for span in spans]) == {"LOC", "DOUBLE_LOC"}
# Also test the results are still the same after IO
@ -350,5 +359,9 @@ def test_overfitting_IO_overlapping():
assert len(spans2) == 3
assert len(spans2.attrs["scores"]) == 3
assert min(spans2.attrs["scores"]) > 0.9
assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
assert set([span.text for span in spans2]) == {
"London",
"Berlin",
"London and Berlin",
}
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}