From 41e18d2c893f56c8fe9b718477a56dd755079635 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Thu, 28 Jan 2021 19:18:47 -0500 Subject: [PATCH] Implement wiring autoloader --- src/dependency_injector/wiring.py | 101 ++++++++++++++++++++++++++ tests/unit/wiring/test_wiring_py36.py | 36 ++++++++- 2 files changed, 136 insertions(+), 1 deletion(-) diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index c738bf02..81fa3985 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -4,6 +4,7 @@ import asyncio import functools import inspect import importlib +import importlib.machinery import pkgutil import sys from types import ModuleType @@ -52,6 +53,11 @@ __all__ = ( 'Provide', 'Provider', 'Closing', + 'register_loader_containers', + 'unregister_loader_containers', + 'install_loader', + 'uninstall_loader', + 'is_loader_installed', ) T = TypeVar('T') @@ -535,3 +541,98 @@ class Provider(_Marker): class Closing(_Marker): ... + + +class AutoLoader: + """Auto-wiring module loader. + + Automatically wire containers when modules are imported. + """ + + def __init__(self): + self.containers = [] + self._path_hook = None + + def register_containers(self, *containers): + self.containers.extend(containers) + + if not self.installed: + self.install() + + def unregister_containers(self, *containers): + for container in containers: + self.containers.remove(container) + + if not self.containers: + self.uninstall() + + def wire_module(self, module): + for container in self.containers: + container.wire(modules=[module]) + + @property + def installed(self): + return self._path_hook is not None + + def install(self): + if self.installed: + return + + loader = self + + class SourcelessFileLoader(importlib.machinery.SourcelessFileLoader): + def exec_module(self, module): + super().exec_module(module) + loader.wire_module(module) + + class SourceFileLoader(importlib.machinery.SourceFileLoader): + def exec_module(self, module): + super().exec_module(module) + loader.wire_module(module) + + loader_details = [ + (SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES), + (SourceFileLoader, importlib.machinery.SOURCE_SUFFIXES), + ] + + self._path_hook = importlib.machinery.FileFinder.path_hook(*loader_details) + + sys.path_hooks.insert(0, self._path_hook) + sys.path_importer_cache.clear() + importlib.invalidate_caches() + + def uninstall(self): + if not self.installed: + return + + sys.path_hooks.remove(self._path_hook) + sys.path_importer_cache.clear() + importlib.invalidate_caches() + + +_loader = AutoLoader() + + +def register_loader_containers(*containers: Container) -> None: + """Register containers in auto-wiring module loader.""" + _loader.register_containers(*containers) + + +def unregister_loader_containers(*containers: Container) -> None: + """Unregister containers from auto-wiring module loader.""" + _loader.unregister_containers(*containers) + + +def install_loader() -> None: + """Install auto-wiring module loader hook.""" + _loader.install() + + +def uninstall_loader() -> None: + """Uninstall auto-wiring module loader hook.""" + _loader.uninstall() + + +def is_loader_installed() -> bool: + """Check if auto-wiring module loader hook is installed.""" + return _loader.installed diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index 84938a6a..25f4b9f2 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -1,7 +1,15 @@ +import contextlib from decimal import Decimal +import importlib import unittest -from dependency_injector.wiring import wire, Provide, Closing +from dependency_injector.wiring import ( + wire, + Provide, + Closing, + register_loader_containers, + unregister_loader_containers, +) from dependency_injector import errors # Runtime import to avoid syntax errors in samples on Python < 3.5 @@ -367,3 +375,29 @@ class WiringAsyncInjectionsTest(AsyncTestCase): self.assertIs(resource2, asyncinjections.resource2) self.assertEqual(asyncinjections.resource2.init_counter, 2) self.assertEqual(asyncinjections.resource2.shutdown_counter, 2) + + +class AutoLoaderTest(unittest.TestCase): + + container: Container + + def setUp(self) -> None: + self.container = Container(config={'a': {'b': {'c': 10}}}) + importlib.reload(module) + + def tearDown(self) -> None: + with contextlib.suppress(ValueError): + unregister_loader_containers(self.container) + + self.container.unwire() + + @classmethod + def tearDownClass(cls) -> None: + importlib.reload(module) + + def test_register_container(self): + register_loader_containers(self.container) + importlib.reload(module) + + service = module.test_function() + self.assertIsInstance(service, Service)