diff --git a/injector/__init__.py b/injector/__init__.py index 0add517..fc06425 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -863,14 +863,43 @@ class Injector: return dependencies -def get_bindings(callable): +def get_bindings(callable: Callable) -> Dict[str, type]: + """Get bindings of injectable parameters from a callable. + + If the callable is not decorated with :func:`inject` an empty dictionary will + be returned. Otherwise the returned dictionary will contain a mapping + between parameter names and their types with the exception of parameters + excluded from dependency injection with :func:`noninjectable`. For example:: + + >>> def function1(a: int) -> None: + ... pass + ... + >>> get_bindings(function1) + {} + >>> @inject + ... def function2(a: int) -> None: + ... pass + ... + >>> get_bindings(function2) + {'a': int} + >>> @inject + ... @noninjectable('b') + ... def function3(a: int, b: str) -> None: + ... pass + ... + >>> get_bindings(function3) + {'a': int} + + This function is used internally so by calling it you can learn what exactly + Injector is going to try to provide to a callable. + """ if not hasattr(callable, '__bindings__'): return {} - if callable.__bindings__ == 'deferred': + if cast(Any, callable).__bindings__ == 'deferred': read_and_store_bindings(callable, _infer_injected_bindings(callable)) noninjectables = getattr(callable, '__noninjectables__', set()) - return {k: v for k, v in callable.__bindings__.items() if k not in noninjectables} + return {k: v for k, v in cast(Any, callable).__bindings__.items() if k not in noninjectables} class _BindingNotYetAvailable(Exception): diff --git a/injector_test.py b/injector_test.py index c74b766..e305c18 100644 --- a/injector_test.py +++ b/injector_test.py @@ -29,6 +29,7 @@ from injector import ( Scope, InstanceProvider, ClassProvider, + get_bindings, inject, multiprovider, noninjectable, @@ -1417,3 +1418,23 @@ class Data: injector = Injector([configure]) assert injector.get(Data).name == 'data' + + +def test_get_bindings(): + def function1(a: int) -> None: + pass + + assert get_bindings(function1) == {} + + @inject + def function2(a: int) -> None: + pass + + assert get_bindings(function2) == {'a': int} + + @inject + @noninjectable('b') + def function3(a: int, b: str) -> None: + pass + + assert get_bindings(function3) == {'a': int}