Make inject decorator work also on classes and save some keystrokes

This commit is contained in:
Jakub Stasiak 2013-02-14 00:35:58 +00:00
parent afeae01483
commit f0d69d8f8b
3 changed files with 104 additions and 12 deletions

View File

@ -32,10 +32,8 @@ We'll use an in-memory SQLite database for our example::
And make up an imaginary RequestHandler class that uses the SQLite connection::
>>> class RequestHandler(object):
... @inject(db=sqlite3.Connection)
... def __init__(self, db):
... self._db = db
>>> @inject(_db=sqlite3.Connection)
... class RequestHandler(object):
... def get(self):
... cursor = self._db.cursor()
... cursor.execute('SELECT key, value FROM data ORDER by key')
@ -236,6 +234,20 @@ constructor of a normal class::
... def describe(self, name):
... return '%s is a man of astounding insight' % name
You can also ``inject``-decorate class itself. This code::
>>> @inject(name=Name)
... class Item(object):
... pass
is equivalent to::
>>> class Item(object):
... @inject(name=Name)
... def __init__(self, name):
... self.name = name
Injector
--------
The ``Injector`` brings everything together. It takes a list of
@ -278,9 +290,9 @@ constructors. Let's have for example::
... def __init__(self, name):
... self.name = name
>>> @inject(db=Database)
>>> class UserUpdater(object):
... @inject(db = Database)
... def __init__(self, db, user):
... def __init__(self, user):
... pass
You may want to have database connection ``db`` injected into ``UserUpdater`` constructor,
@ -303,10 +315,10 @@ using all of them.
``AssistedBuilder(X)`` is injectable just as anything else, if you need instance of it you
just ask for it like that::
>>> class NeedsUserUpdater(object):
... @inject(builder=AssistedBuilder(UserUpdater))
... def method(self, builder):
... pass
>>> @inject(updater_builder=AssistedBuilder(UserUpdater))
... class NeedsUserUpdater(object):
... def method(self):
... updater = self.updater_builder.build(user=None)
More information on this topic:

View File

@ -641,7 +641,7 @@ def inject(**bindings):
'Get my friends'
"""
def wrapper(f):
def method_wrapper(f):
for key, value in bindings.items():
bindings[key] = BindingKey(value, None)
if hasattr(inspect, 'getfullargspec'):
@ -674,7 +674,24 @@ def inject(**bindings):
inject = f
inject.__bindings__ = bindings
return inject
return wrapper
def class_wrapper(cls):
orig_init = cls.__init__
@inject(**bindings)
def init(self, *args, **kwargs):
for key in bindings:
setattr(self, key, kwargs.pop(key))
orig_init(self, *args, **kwargs)
cls.__init__ = init
return cls
def multi_wrapper(something):
if type(something) is type:
return class_wrapper(something)
else:
return method_wrapper(something)
return multi_wrapper
class BaseAnnotation(object):

View File

@ -735,3 +735,66 @@ class TestThreadSafety(object):
self.injector.binder.bind(self.cls, scope=singleton)
a, b = self.gather_results(2)
assert (a is b)
class TestClassInjection(object):
def setup(self):
class A(object):
counter = 0
def __init__(self):
A.counter += 1
@inject(a=A)
class B(object):
pass
@inject(a=A)
class C(object):
def __init__(self, noninjectable):
self.noninjectable = noninjectable
self.injector = Injector()
self.A = A
self.B = B
self.C = C
def test_instantiation_still_requires_parameters(self):
for cls in (self.B, self.C):
with pytest.raises(Exception):
obj = cls()
with pytest.raises(Exception):
c = self.C(noninjectable=1)
with pytest.raises(Exception):
c = self.C(a=self.A())
def test_injection_works(self):
b = self.injector.get(self.B)
a = b.a
assert (type(a) == self.A)
def test_assisted_injection_works(self):
builder = self.injector.get(AssistedBuilder(self.C))
c = builder.build(noninjectable=5)
assert((type(c.a), c.noninjectable) == (self.A, 5))
def test_members_are_injected_only_once(self):
b = self.injector.get(self.B)
_1 = b.a
_2 = b.a
assert (self.A.counter == 1 and _1 is _2)
def test_each_instance_gets_new_injection(self):
count = 3
objs = [self.injector.get(self.B).a for i in range(count)]
assert (self.A.counter == count)
assert (len(set(objs)) == count)
def test_members_can_be_overwritten(self):
b = self.injector.get(self.B)
b.a = 123
assert (b.a == 123)