Allow `Closing` to detect dependent resources (#636)

This commit is contained in:
Jamie Stumme 2022-12-18 19:49:23 -07:00 committed by GitHub
parent a79ea1790c
commit 3b76a0d091
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 0 deletions

View File

@ -593,6 +593,22 @@ def _fetch_reference_injections( # noqa: C901
return injections, closing
def _locate_dependent_closing_args(provider: providers.Provider) -> dict[str, providers.Provider]:
if not hasattr(provider, "args"):
return {}
closing_deps = {}
for arg in provider.args:
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
continue
if not arg.args and isinstance(arg, providers.Resource):
return {str(id(arg)): arg}
else:
closing_deps += _locate_dependent_closing_args(arg)
return closing_deps
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
patched_callable = _patched_registry.get_callable(fn)
if patched_callable is None:
@ -614,6 +630,9 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non
if injection in patched_callable.reference_closing:
patched_callable.add_closing(injection, provider)
deps = _locate_dependent_closing_args(provider)
for key, dep in deps.items():
patched_callable.add_closing(key, dep)
def _unbind_injections(fn: Callable[..., Any]) -> None:

View File

@ -20,6 +20,11 @@ class Service:
cls.shutdown_counter += 1
class FactoryService:
def __init__(self, service: Service):
self.service = service
def init_service():
service = Service()
service.init()
@ -30,8 +35,14 @@ def init_service():
class Container(containers.DeclarativeContainer):
service = providers.Resource(init_service)
factory_service = providers.Factory(FactoryService, service)
@inject
def test_function(service: Service = Closing[Provide["service"]]):
return service
@inject
def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]):
return factory

View File

@ -289,6 +289,23 @@ def test_closing_resource():
assert result_1 is not result_2
@mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function_dependency()
assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 1
assert result_1.service.shutdown_counter == 1
result_2 = resourceclosing.test_function_dependency()
assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 2
assert result_2.service.shutdown_counter == 2
assert result_1 is not result_2
@mark.usefixtures("resourceclosing_container")
def test_closing_resource_bypass_marker_injection():
resourceclosing.Service.reset_counter()