diff --git a/README.rst b/README.rst index a463bbe..19472ad 100644 --- a/README.rst +++ b/README.rst @@ -32,10 +32,8 @@ We'll use an in-memory SQLite database for our example:: And make up an imaginary RequestHandler class that uses the SQLite connection:: - >>> class RequestHandler(object): - ... @inject(db=sqlite3.Connection) - ... def __init__(self, db): - ... self._db = db + >>> @inject(_db=sqlite3.Connection) + ... class RequestHandler(object): ... def get(self): ... cursor = self._db.cursor() ... cursor.execute('SELECT key, value FROM data ORDER by key') @@ -236,6 +234,20 @@ constructor of a normal class:: ... def describe(self, name): ... return '%s is a man of astounding insight' % name +You can also ``inject``-decorate class itself. This code:: + + >>> @inject(name=Name) + ... class Item(object): + ... pass + +is equivalent to:: + + >>> class Item(object): + ... @inject(name=Name) + ... def __init__(self, name): + ... self.name = name + + Injector -------- The ``Injector`` brings everything together. It takes a list of @@ -278,9 +290,9 @@ constructors. Let's have for example:: ... def __init__(self, name): ... self.name = name + >>> @inject(db=Database) >>> class UserUpdater(object): - ... @inject(db = Database) - ... def __init__(self, db, user): + ... def __init__(self, user): ... pass You may want to have database connection ``db`` injected into ``UserUpdater`` constructor, @@ -303,10 +315,10 @@ using all of them. ``AssistedBuilder(X)`` is injectable just as anything else, if you need instance of it you just ask for it like that:: - >>> class NeedsUserUpdater(object): - ... @inject(builder=AssistedBuilder(UserUpdater)) - ... def method(self, builder): - ... pass + >>> @inject(updater_builder=AssistedBuilder(UserUpdater)) + ... class NeedsUserUpdater(object): + ... def method(self): + ... updater = self.updater_builder.build(user=None) More information on this topic: diff --git a/injector.py b/injector.py index 3abf4e7..67780bf 100644 --- a/injector.py +++ b/injector.py @@ -641,7 +641,7 @@ def inject(**bindings): 'Get my friends' """ - def wrapper(f): + def method_wrapper(f): for key, value in bindings.items(): bindings[key] = BindingKey(value, None) if hasattr(inspect, 'getfullargspec'): @@ -674,7 +674,24 @@ def inject(**bindings): inject = f inject.__bindings__ = bindings return inject - return wrapper + + def class_wrapper(cls): + orig_init = cls.__init__ + @inject(**bindings) + def init(self, *args, **kwargs): + for key in bindings: + setattr(self, key, kwargs.pop(key)) + orig_init(self, *args, **kwargs) + cls.__init__ = init + return cls + + def multi_wrapper(something): + if type(something) is type: + return class_wrapper(something) + else: + return method_wrapper(something) + + return multi_wrapper class BaseAnnotation(object): diff --git a/injector_test.py b/injector_test.py index 7a16814..f9033e4 100644 --- a/injector_test.py +++ b/injector_test.py @@ -735,3 +735,66 @@ class TestThreadSafety(object): self.injector.binder.bind(self.cls, scope=singleton) a, b = self.gather_results(2) assert (a is b) + +class TestClassInjection(object): + def setup(self): + class A(object): + counter = 0 + + def __init__(self): + A.counter += 1 + + @inject(a=A) + class B(object): + pass + + @inject(a=A) + class C(object): + def __init__(self, noninjectable): + self.noninjectable = noninjectable + + self.injector = Injector() + self.A = A + self.B = B + self.C = C + + def test_instantiation_still_requires_parameters(self): + for cls in (self.B, self.C): + with pytest.raises(Exception): + obj = cls() + + with pytest.raises(Exception): + c = self.C(noninjectable=1) + + with pytest.raises(Exception): + c = self.C(a=self.A()) + + def test_injection_works(self): + b = self.injector.get(self.B) + a = b.a + assert (type(a) == self.A) + + def test_assisted_injection_works(self): + builder = self.injector.get(AssistedBuilder(self.C)) + c = builder.build(noninjectable=5) + + assert((type(c.a), c.noninjectable) == (self.A, 5)) + + def test_members_are_injected_only_once(self): + b = self.injector.get(self.B) + _1 = b.a + _2 = b.a + assert (self.A.counter == 1 and _1 is _2) + + def test_each_instance_gets_new_injection(self): + count = 3 + objs = [self.injector.get(self.B).a for i in range(count)] + + assert (self.A.counter == count) + assert (len(set(objs)) == count) + + def test_members_can_be_overwritten(self): + b = self.injector.get(self.B) + b.a = 123 + + assert (b.a == 123)