From cde96f6c64220bf6a82cf4288f6e2bfbbc97eb0a Mon Sep 17 00:00:00 2001 From: Leander Fiedler Date: Mon, 6 Apr 2020 20:51:12 +0200 Subject: [PATCH] issue5230: optimized unit test a bit --- spacy/tests/regression/test_issue5230.py | 61 +++++++++--------------- 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/spacy/tests/regression/test_issue5230.py b/spacy/tests/regression/test_issue5230.py index 716a4624b..76d4d3e96 100644 --- a/spacy/tests/regression/test_issue5230.py +++ b/spacy/tests/regression/test_issue5230.py @@ -1,41 +1,28 @@ # coding: utf8 import warnings -import numpy import pytest import srsly - +from numpy import zeros from spacy.kb import KnowledgeBase from spacy.vectors import Vectors + from spacy.language import Language from spacy.pipeline import Pipe from spacy.tests.util import make_tempdir -def test_language_to_disk_resource_warning(): - nlp = Language() - with make_tempdir() as d: - with warnings.catch_warnings(record=True) as w: - # catch only warnings raised in spacy.language since there may be others from other components or pipelines - warnings.filterwarnings( - "always", module="spacy.language", category=ResourceWarning - ) - nlp.to_disk(d) - assert len(w) == 0 +def nlp(): + return Language() -def test_vectors_to_disk_resource_warning(): - data = numpy.zeros((3, 300), dtype="f") +def vectors(): + data = zeros((3, 1), dtype="f") keys = ["cat", "dog", "rat"] - vectors = Vectors(data=data, keys=keys) - with make_tempdir() as d: - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings("always", category=ResourceWarning) - vectors.to_disk(d) - assert len(w) == 0 + return Vectors(data=data, keys=keys) -def test_custom_pipes_to_disk_resource_warning(): +def custom_pipe(): # create dummy pipe partially implementing interface -- only want to test to_disk class SerializableDummy(object): def __init__(self, **cfg): @@ -66,15 +53,10 @@ def test_custom_pipes_to_disk_resource_warning(): self.model = SerializableDummy() self.vocab = SerializableDummy() - pipe = MyPipe(None) - with make_tempdir() as d: - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings("always", category=ResourceWarning) - pipe.to_disk(d) - assert len(w) == 0 + return MyPipe(None) -def test_tagger_to_disk_resource_warning(): +def tagger(): nlp = Language() nlp.add_pipe(nlp.create_pipe("tagger")) tagger = nlp.get_pipe("tagger") @@ -82,15 +64,10 @@ def test_tagger_to_disk_resource_warning(): # 1. no model leads to error in serialization, # 2. the affected line is the one for model serialization tagger.begin_training(pipeline=nlp.pipeline) - - with make_tempdir() as d: - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings("always", category=ResourceWarning) - tagger.to_disk(d) - assert len(w) == 0 + return tagger -def test_entity_linker_to_disk_resource_warning(): +def entity_linker(): nlp = Language() nlp.add_pipe(nlp.create_pipe("entity_linker")) entity_linker = nlp.get_pipe("entity_linker") @@ -100,9 +77,17 @@ def test_entity_linker_to_disk_resource_warning(): kb = KnowledgeBase(nlp.vocab, entity_vector_length=1) entity_linker.set_kb(kb) entity_linker.begin_training(pipeline=nlp.pipeline) + return entity_linker + +@pytest.mark.parametrize( + "obj", + [nlp(), vectors(), custom_pipe(), tagger(), entity_linker()], + ids=["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"], +) +def test_to_disk_resource_warning(obj): with make_tempdir() as d: - with warnings.catch_warnings(record=True) as w: + with warnings.catch_warnings(record=True) as warnings_list: warnings.filterwarnings("always", category=ResourceWarning) - entity_linker.to_disk(d) - assert len(w) == 0 + obj.to_disk(d) + assert len(warnings_list) == 0