From d602b5520e9a5a19479dbff5faef6f7401ba8acc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Medrano=20Llamas?= Date: Wed, 22 Mar 2017 22:20:49 +0100 Subject: [PATCH] Fix handling of issubclass() with typing.Union (#61) --- COPYING | 2 +- injector.py | 19 ++++++++++++------- injector_test.py | 22 ++++++++++++++++++++++ injector_test_py3.py | 24 +++++++++++++++++++++++- 4 files changed, 58 insertions(+), 9 deletions(-) diff --git a/COPYING b/COPYING index a19b705..4c1c065 100644 --- a/COPYING +++ b/COPYING @@ -1,4 +1,4 @@ -Copyright (c) 2010, Alec Thomas +Copyright (c) 2010, Alec Thomas, Google Inc. All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/injector.py b/injector.py index a7f3d1d..a432a05 100644 --- a/injector.py +++ b/injector.py @@ -482,13 +482,18 @@ if TYPING353: # issubclass(SomeGeneric[X], SomeGeneric) so we need some other way to # determine whether a particular object is a generic class with type parameters # provided. Fortunately there seems to be __origin__ attribute that's useful here. - return ( - hasattr(cls, '__origin__') and - # __origin__ is generic_class is a special case to handle Union as - # Union cannot be used in issubclass() check (it raises an exception - # by design). - (cls.__origin__ is generic_class or issubclass(cls.__origin__, generic_class)) - ) + if not hasattr(cls, '__origin__'): + return False + origin = cls.__origin__ + if not inspect.isclass(generic_class): + generic_class = type(generic_class) + if not inspect.isclass(origin): + origin = type(origin) + # __origin__ is generic_class is a special case to handle Union as + # Union cannot be used in issubclass() check (it raises an exception + # by design). + return origin is generic_class or issubclass(origin, generic_class) + else: # To maintain compatibility we fall back to an issubclass check. def _is_specialization(cls, generic_class): diff --git a/injector_test.py b/injector_test.py index dc29725..3e83856 100644 --- a/injector_test.py +++ b/injector_test.py @@ -1014,3 +1014,25 @@ def test_binding_an_instance_regression(): injector = Injector(configure) # This used to return empty bytes instead of the expected string assert injector.get(bytes) == text + + +def test_class_assisted_builder_of_partially_injected_class(): + class A(object): + pass + + class B(object): + @inject(a=A, b=str) + def __init__(self, a, b): + self.a = a + self.b = b + + class C(object): + @inject(a=A, builder=ClassAssistedBuilder[B]) + def __init__(self, a, builder): + self.a = a + self.b = builder.build(b='C') + + c = Injector().get(C) + assert isinstance(c, C) + assert isinstance(c.b, B) + assert isinstance(c.b.a, A) diff --git a/injector_test_py3.py b/injector_test_py3.py index cc962e9..9379437 100644 --- a/injector_test_py3.py +++ b/injector_test_py3.py @@ -3,7 +3,7 @@ from typing import Any import pytest from injector import ( - AssistedBuilder, inject, Injector, CallError, + AssistedBuilder, ClassAssistedBuilder, inject, Injector, CallError, Module, noninjectable, provider, provides, singleton, ) @@ -264,3 +264,25 @@ def test_optionals_are_ignored_for_now(): return s assert Injector().call_with_injection(fun) == '' + + +def test_class_assisted_builder_of_partially_injected_class(): + class A(object): + pass + + class B(object): + @inject + def __init__(self, a: A, b: str): + self.a = a + self.b = b + + class C(object): + @inject + def __init__(self, a: A, builder: ClassAssistedBuilder[B]): + self.a = a + self.b = builder.build(b='C') + + c = Injector().get(C) + assert isinstance(c, C) + assert isinstance(c.b, B) + assert isinstance(c.b.a, A)