diff --git a/README.rst b/README.rst index 7dbbc82..6bdd643 100644 --- a/README.rst +++ b/README.rst @@ -115,6 +115,7 @@ A means of providing an instance of a type. Built-in providers include ``ClassProvider`` (creates a new instance from a class), ``InstanceProvider`` (returns an existing instance directly) and ``CallableProvider`` (provides an instance by calling a function). +``AssistedFactoryProvider`` (provides a factory which can be used for assisted injection) Scope ----- @@ -267,6 +268,44 @@ Or transitively:: >>> user.description 'Sherlock is a man of astounding insight' +Assisted injection +------------------ +Sometimes there are classes that have injectable and non-injectable parameters in their +constructors. Let's have for example:: + + >>> Database = Key('Database') + >>> class User(object): + ... def __init__(self, name): + ... self.name = name + + >>> class UserUpdater(object): + ... @inject(db = Database) + ... def __init__(self, db, user): + ... pass + + You may want to have database connection ``db`` injected into ``UserUpdater`` constructor, + but in the same time provide ``user`` object by yourself, and assuming that ``user`` object + is a value object and there's many users in your application it doesn't make much sense + to inject objects of class ``User``. + + In this situation there's technique called Assisted injection:: + + >>> UserUpdaterFactory = Key('UserUpdaterFactory') + >>> def module(binder): + ... binder.bind(UserUpdaterFactory, to=AssistedFactoryProvider(UserUpdater)) + + ... injector = Injector(module) + ... factory = injector.get(UserUpdaterFactory) + ... user = User('John') + ... user_updater = factory.create(user=user) + + This way we don't make ``UserUpdater`` directly injectable - we provide injectable factory. + Such factory has ``create(**kwargs)`` method which takes non-injectable parameters, combines + them with injectable dependencies of ``UserUpdater`` and calls ``UserUpdater`` initializer + using all of them. + + More information on this topic: `"How to use Google Guice to create objects that require parameters?" on Stack Overflow `_ + Scopes ====== diff --git a/injector.py b/injector.py index c75fcd3..c75f158 100644 --- a/injector.py +++ b/injector.py @@ -124,6 +124,23 @@ class MapBindProvider(ListOfProviders): map.update(provider.get()) return map +class AssistedFactoryProvider(Provider): + + class AssistedFactory(object): + def __init__(self, injector, cls): + self._injector = injector + self._cls = cls + + def create(self, **kwargs): + return self._injector.create_object(self._cls, additional_kwargs=kwargs) + + _injector = None + + def __init__(self, cls): + self._cls = cls + + def get(self): + return self.AssistedFactory(self._injector, self._cls) # These classes are used internally by the Binder. class BindingKey(tuple): @@ -234,6 +251,8 @@ class Binder(object): def provider_for(self, interface, to=None): if isinstance(to, Provider): + if isinstance(to, AssistedFactoryProvider): + to._injector = self.injector return to elif isinstance(to, (types.FunctionType, types.LambdaType, types.MethodType, types.BuiltinFunctionType, @@ -449,15 +468,17 @@ class Injector(object): 'with Binder.bind_scope(scope_cls)' % e) return scope_instance.get(key, binding.provider).get() - def create_object(self, cls): + def create_object(self, cls, additional_kwargs=None): """Create a new instance, satisfying any dependencies on cls.""" + + additional_kwargs = additional_kwargs or {} instance = cls.__new__(cls) try: self.install_into(instance) except AttributeError: # Some builtin types can not be modified. pass - instance.__init__() + instance.__init__(**additional_kwargs) return instance def install_into(self, instance): diff --git a/injector_test.py b/injector_test.py index a2a54cd..a37c835 100644 --- a/injector_test.py +++ b/injector_test.py @@ -16,7 +16,7 @@ import threading import pytest -from injector import (Binder, Injector, Scope, InstanceProvider, ClassProvider, +from injector import (AssistedFactoryProvider, Binder, Injector, Scope, InstanceProvider, ClassProvider, inject, singleton, threadlocal, UnsatisfiedRequirement, CircularDependency, Module, provides, Key, extends, SingletonScope, ScopeDecorator, with_injector) @@ -592,3 +592,20 @@ def test_binder_provider_for_type_with_metaclass(): binder = Injector().binder assert (isinstance(binder.provider_for(A, None).get(), A)) + +def test_assisted_factory_provider_works(): + class A(object): + @inject(aaa=str) + def __init__(self, aaa, bbb): + self.aaa = aaa + self.bbb = bbb + + AFactory = Key('AFactory') + def conf(binder): + binder.bind(AFactory, to=AssistedFactoryProvider(A)) + + injector = Injector(conf) + factory = injector.get(AFactory) + a = factory.create(bbb=123) + assert (a.aaa == str()) + assert (a.bbb == 123)