Make inject decorator work also on classes and save some keystrokes
This commit is contained in:
parent
afeae01483
commit
f0d69d8f8b
32
README.rst
32
README.rst
|
@ -32,10 +32,8 @@ We'll use an in-memory SQLite database for our example::
|
|||
|
||||
And make up an imaginary RequestHandler class that uses the SQLite connection::
|
||||
|
||||
>>> class RequestHandler(object):
|
||||
... @inject(db=sqlite3.Connection)
|
||||
... def __init__(self, db):
|
||||
... self._db = db
|
||||
>>> @inject(_db=sqlite3.Connection)
|
||||
... class RequestHandler(object):
|
||||
... def get(self):
|
||||
... cursor = self._db.cursor()
|
||||
... cursor.execute('SELECT key, value FROM data ORDER by key')
|
||||
|
@ -236,6 +234,20 @@ constructor of a normal class::
|
|||
... def describe(self, name):
|
||||
... return '%s is a man of astounding insight' % name
|
||||
|
||||
You can also ``inject``-decorate class itself. This code::
|
||||
|
||||
>>> @inject(name=Name)
|
||||
... class Item(object):
|
||||
... pass
|
||||
|
||||
is equivalent to::
|
||||
|
||||
>>> class Item(object):
|
||||
... @inject(name=Name)
|
||||
... def __init__(self, name):
|
||||
... self.name = name
|
||||
|
||||
|
||||
Injector
|
||||
--------
|
||||
The ``Injector`` brings everything together. It takes a list of
|
||||
|
@ -278,9 +290,9 @@ constructors. Let's have for example::
|
|||
... def __init__(self, name):
|
||||
... self.name = name
|
||||
|
||||
>>> @inject(db=Database)
|
||||
>>> class UserUpdater(object):
|
||||
... @inject(db = Database)
|
||||
... def __init__(self, db, user):
|
||||
... def __init__(self, user):
|
||||
... pass
|
||||
|
||||
You may want to have database connection ``db`` injected into ``UserUpdater`` constructor,
|
||||
|
@ -303,10 +315,10 @@ using all of them.
|
|||
``AssistedBuilder(X)`` is injectable just as anything else, if you need instance of it you
|
||||
just ask for it like that::
|
||||
|
||||
>>> class NeedsUserUpdater(object):
|
||||
... @inject(builder=AssistedBuilder(UserUpdater))
|
||||
... def method(self, builder):
|
||||
... pass
|
||||
>>> @inject(updater_builder=AssistedBuilder(UserUpdater))
|
||||
... class NeedsUserUpdater(object):
|
||||
... def method(self):
|
||||
... updater = self.updater_builder.build(user=None)
|
||||
|
||||
More information on this topic:
|
||||
|
||||
|
|
21
injector.py
21
injector.py
|
@ -641,7 +641,7 @@ def inject(**bindings):
|
|||
'Get my friends'
|
||||
"""
|
||||
|
||||
def wrapper(f):
|
||||
def method_wrapper(f):
|
||||
for key, value in bindings.items():
|
||||
bindings[key] = BindingKey(value, None)
|
||||
if hasattr(inspect, 'getfullargspec'):
|
||||
|
@ -674,7 +674,24 @@ def inject(**bindings):
|
|||
inject = f
|
||||
inject.__bindings__ = bindings
|
||||
return inject
|
||||
return wrapper
|
||||
|
||||
def class_wrapper(cls):
|
||||
orig_init = cls.__init__
|
||||
@inject(**bindings)
|
||||
def init(self, *args, **kwargs):
|
||||
for key in bindings:
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
orig_init(self, *args, **kwargs)
|
||||
cls.__init__ = init
|
||||
return cls
|
||||
|
||||
def multi_wrapper(something):
|
||||
if type(something) is type:
|
||||
return class_wrapper(something)
|
||||
else:
|
||||
return method_wrapper(something)
|
||||
|
||||
return multi_wrapper
|
||||
|
||||
|
||||
class BaseAnnotation(object):
|
||||
|
|
|
@ -735,3 +735,66 @@ class TestThreadSafety(object):
|
|||
self.injector.binder.bind(self.cls, scope=singleton)
|
||||
a, b = self.gather_results(2)
|
||||
assert (a is b)
|
||||
|
||||
class TestClassInjection(object):
|
||||
def setup(self):
|
||||
class A(object):
|
||||
counter = 0
|
||||
|
||||
def __init__(self):
|
||||
A.counter += 1
|
||||
|
||||
@inject(a=A)
|
||||
class B(object):
|
||||
pass
|
||||
|
||||
@inject(a=A)
|
||||
class C(object):
|
||||
def __init__(self, noninjectable):
|
||||
self.noninjectable = noninjectable
|
||||
|
||||
self.injector = Injector()
|
||||
self.A = A
|
||||
self.B = B
|
||||
self.C = C
|
||||
|
||||
def test_instantiation_still_requires_parameters(self):
|
||||
for cls in (self.B, self.C):
|
||||
with pytest.raises(Exception):
|
||||
obj = cls()
|
||||
|
||||
with pytest.raises(Exception):
|
||||
c = self.C(noninjectable=1)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
c = self.C(a=self.A())
|
||||
|
||||
def test_injection_works(self):
|
||||
b = self.injector.get(self.B)
|
||||
a = b.a
|
||||
assert (type(a) == self.A)
|
||||
|
||||
def test_assisted_injection_works(self):
|
||||
builder = self.injector.get(AssistedBuilder(self.C))
|
||||
c = builder.build(noninjectable=5)
|
||||
|
||||
assert((type(c.a), c.noninjectable) == (self.A, 5))
|
||||
|
||||
def test_members_are_injected_only_once(self):
|
||||
b = self.injector.get(self.B)
|
||||
_1 = b.a
|
||||
_2 = b.a
|
||||
assert (self.A.counter == 1 and _1 is _2)
|
||||
|
||||
def test_each_instance_gets_new_injection(self):
|
||||
count = 3
|
||||
objs = [self.injector.get(self.B).a for i in range(count)]
|
||||
|
||||
assert (self.A.counter == count)
|
||||
assert (len(set(objs)) == count)
|
||||
|
||||
def test_members_can_be_overwritten(self):
|
||||
b = self.injector.get(self.B)
|
||||
b.a = 123
|
||||
|
||||
assert (b.a == 123)
|
||||
|
|
Loading…
Reference in New Issue