Fix handling of issubclass() with typing.Union (#61)
This commit is contained in:
parent
bcc5ed0535
commit
d602b5520e
2
COPYING
2
COPYING
|
@ -1,4 +1,4 @@
|
|||
Copyright (c) 2010, Alec Thomas
|
||||
Copyright (c) 2010, Alec Thomas, Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
|
|
19
injector.py
19
injector.py
|
@ -482,13 +482,18 @@ if TYPING353:
|
|||
# issubclass(SomeGeneric[X], SomeGeneric) so we need some other way to
|
||||
# determine whether a particular object is a generic class with type parameters
|
||||
# provided. Fortunately there seems to be __origin__ attribute that's useful here.
|
||||
return (
|
||||
hasattr(cls, '__origin__') and
|
||||
# __origin__ is generic_class is a special case to handle Union as
|
||||
# Union cannot be used in issubclass() check (it raises an exception
|
||||
# by design).
|
||||
(cls.__origin__ is generic_class or issubclass(cls.__origin__, generic_class))
|
||||
)
|
||||
if not hasattr(cls, '__origin__'):
|
||||
return False
|
||||
origin = cls.__origin__
|
||||
if not inspect.isclass(generic_class):
|
||||
generic_class = type(generic_class)
|
||||
if not inspect.isclass(origin):
|
||||
origin = type(origin)
|
||||
# __origin__ is generic_class is a special case to handle Union as
|
||||
# Union cannot be used in issubclass() check (it raises an exception
|
||||
# by design).
|
||||
return origin is generic_class or issubclass(origin, generic_class)
|
||||
|
||||
else:
|
||||
# To maintain compatibility we fall back to an issubclass check.
|
||||
def _is_specialization(cls, generic_class):
|
||||
|
|
|
@ -1014,3 +1014,25 @@ def test_binding_an_instance_regression():
|
|||
injector = Injector(configure)
|
||||
# This used to return empty bytes instead of the expected string
|
||||
assert injector.get(bytes) == text
|
||||
|
||||
|
||||
def test_class_assisted_builder_of_partially_injected_class():
|
||||
class A(object):
|
||||
pass
|
||||
|
||||
class B(object):
|
||||
@inject(a=A, b=str)
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
class C(object):
|
||||
@inject(a=A, builder=ClassAssistedBuilder[B])
|
||||
def __init__(self, a, builder):
|
||||
self.a = a
|
||||
self.b = builder.build(b='C')
|
||||
|
||||
c = Injector().get(C)
|
||||
assert isinstance(c, C)
|
||||
assert isinstance(c.b, B)
|
||||
assert isinstance(c.b.a, A)
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Any
|
|||
import pytest
|
||||
|
||||
from injector import (
|
||||
AssistedBuilder, inject, Injector, CallError,
|
||||
AssistedBuilder, ClassAssistedBuilder, inject, Injector, CallError,
|
||||
Module, noninjectable, provider, provides, singleton,
|
||||
)
|
||||
|
||||
|
@ -264,3 +264,25 @@ def test_optionals_are_ignored_for_now():
|
|||
return s
|
||||
|
||||
assert Injector().call_with_injection(fun) == ''
|
||||
|
||||
|
||||
def test_class_assisted_builder_of_partially_injected_class():
|
||||
class A(object):
|
||||
pass
|
||||
|
||||
class B(object):
|
||||
@inject
|
||||
def __init__(self, a: A, b: str):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
class C(object):
|
||||
@inject
|
||||
def __init__(self, a: A, builder: ClassAssistedBuilder[B]):
|
||||
self.a = a
|
||||
self.b = builder.build(b='C')
|
||||
|
||||
c = Injector().get(C)
|
||||
assert isinstance(c, C)
|
||||
assert isinstance(c.b, B)
|
||||
assert isinstance(c.b.a, A)
|
||||
|
|
Loading…
Reference in New Issue