diff --git a/README.rst b/README.rst index 439a2b1..9fa700b 100644 --- a/README.rst +++ b/README.rst @@ -410,6 +410,14 @@ class instance on the time of method call:: After such call all ``inject``-decorated methods will work just as you'd expect them to work. +Thread safety +============= + +The following functions are thread safe: + +* ``Injector.get`` +* injection provided by ``inject`` decorator (please note, however, that it doesn't say anything about decorated function thread safety) + Footnote ======== This framework is similar to snake-guice, but aims for simplification. diff --git a/injector.py b/injector.py index fe961f4..f2d142a 100644 --- a/injector.py +++ b/injector.py @@ -29,6 +29,17 @@ __author__ = 'Alec Thomas ' __version__ = '0.5.2' __version_tag__ = '' +def synchronized(lock): + def outside_wrapper(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + with lock: + return function(*args, **kwargs) + return wrapper + return outside_wrapper + +lock = threading.RLock() + class Error(Exception): """Base exception.""" @@ -353,6 +364,7 @@ class SingletonScope(Scope): def configure(self): self._context = {} + @synchronized(lock) def get(self, key, provider): try: return self._context[key] @@ -493,6 +505,7 @@ class Injector(object): """ instance.__injector__ = self + @synchronized(lock) def args_to_inject(self, function, bindings, owner_key): """Inject arguments into a function. diff --git a/injector_test.py b/injector_test.py index 177c432..e770f33 100644 --- a/injector_test.py +++ b/injector_test.py @@ -11,6 +11,7 @@ """Functional tests for the Pollute dependency injection framework.""" from contextlib import contextmanager +from time import sleep import abc import threading @@ -664,3 +665,44 @@ def test_assisted_builder_works_when_injected(): injector = Injector() x = injector.get(X) assert ((x.obj.a, x.obj.b) == (str(), 234)) + +class TestThreadSafety(object): + def setup(self): + def configure(binder): + binder.bind(str, to=lambda: sleep(1) and 'this is str') + + class XXX(object): + @inject(s=str) + def __init__(self, s): + pass + + self.injector = Injector(configure) + self.cls = XXX + + def gather_results(self, count): + objects = [] + lock = threading.Lock() + + def target(): + o = self.injector.get(self.cls) + with lock: + objects.append(o) + + threads = [threading.Thread(target=target) for i in range(2)] + + for t in threads: + t.start() + + for t in threads: + t.join() + + return objects + + def test_injection_is_thread_safe(self): + objects = self.gather_results(2) + assert (len(objects) == 2) + + def test_singleton_scope_is_thread_safe(self): + self.injector.binder.bind(self.cls, scope=singleton) + a, b = self.gather_results(2) + assert (a is b)