diff --git a/injector/__init__.py b/injector/__init__.py index 731c1c1..a62afbd 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -632,8 +632,9 @@ class Binder: def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']: is_scope = isinstance(interface, type) and issubclass(interface, Scope) + is_assisted_builder = _is_specialization(interface, AssistedBuilder) try: - return self._get_binding(interface, only_this_binder=is_scope) + return self._get_binding(interface, only_this_binder=is_scope or is_assisted_builder) except (KeyError, UnsatisfiedRequirement): if is_scope: scope = interface diff --git a/injector_test.py b/injector_test.py index 76e9bd2..80cca82 100644 --- a/injector_test.py +++ b/injector_test.py @@ -805,6 +805,19 @@ def test_assisted_builder_injection_is_safe_to_use_with_multiple_injectors(): assert (b1._injector, b2._injector) == (i1, i2) +def test_assisted_builder_injection_is_safe_to_use_with_child_injectors(): + class X: + @inject + def __init__(self, builder: AssistedBuilder[NeedsAssistance]): + self.builder = builder + + i1 = Injector() + i2 = i1.create_child_injector() + b1 = i1.get(X).builder + b2 = i2.get(X).builder + assert (b1._injector, b2._injector) == (i1, i2) + + class TestThreadSafety: def setup(self): self.event = threading.Event()