From d51f2fa79abfdd6335dfad8da92b34e8b5a03d88 Mon Sep 17 00:00:00 2001 From: Roman Mogilatov Date: Mon, 14 Mar 2016 00:08:41 +0200 Subject: [PATCH] Add returning of overriding provider in provider overriding context --- dependency_injector/providers.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index f2b1c14a..247043e1 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -5,6 +5,7 @@ import six from .injections import _parse_args_injections from .injections import _parse_kwargs_injections +from .utils import is_provider from .utils import ensure_is_provider from .utils import is_attribute_injection from .utils import is_method_injection @@ -117,6 +118,10 @@ class Provider(object): if provider is self: raise Error('Provider {0} could not be overridden ' 'with itself'.format(self)) + + if not is_provider(provider): + provider = Object(provider) + if not self.is_overridden: self.overridden_by = (ensure_is_provider(provider),) else: @@ -126,7 +131,7 @@ class Provider(object): if self.__class__.__OPTIMIZED_CALLS__: self.__call__ = self.provide = self._call_last_overriding - return OverridingContext(self) + return OverridingContext(self, provider) def reset_last_overriding(self): """Reset last overriding provider. @@ -1056,16 +1061,21 @@ class OverridingContext(object): assert not provider.is_overridden """ - def __init__(self, overridden): + def __init__(self, overridden, overriding): """Initializer. - :param overridden: Overridden provider + :param overridden: Overridden provider. :type overridden: :py:class:`Provider` + + :param overriding: Overriding provider. + :type overriding: :py:class:`Provider` """ self.overridden = overridden + self.overriding = overriding def __enter__(self): """Do nothing.""" + return self.overriding def __exit__(self, *_): """Exit overriding context."""