From 7684bc7dcb96aa9ad50655b1a1cc80ef449be500 Mon Sep 17 00:00:00 2001 From: Roman Mogilatov Date: Wed, 11 Nov 2015 15:45:39 +0200 Subject: [PATCH] Refactor DynamicCatalog --- dependency_injector/catalogs.py | 66 +++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/dependency_injector/catalogs.py b/dependency_injector/catalogs.py index 75ac72f7..2796fe46 100644 --- a/dependency_injector/catalogs.py +++ b/dependency_injector/catalogs.py @@ -69,32 +69,24 @@ class DynamicCatalog(object): """Catalog of providers.""" __IS_CATALOG__ = True - __slots__ = ('Bundle', 'name', 'providers', 'provider_names', - 'overridden_by') + __slots__ = ('name', 'providers', 'provider_names', 'overridden_by', + 'Bundle') def __init__(self, **providers): """Initializer. - :param name: Catalog's name - :type name: str - - :type kwargs: dict[str, dependency_injector.providers.Provider] + :type providers: dict[str, dependency_injector.providers.Provider] """ - self.Bundle = CatalogBundle.sub_cls_factory(self) self.name = '.'.join((self.__class__.__module__, self.__class__.__name__)) self.providers = dict() self.provider_names = dict() - for name, provider in six.iteritems(providers): - provider = ensure_is_provider(provider) - if provider in self.provider_names: - raise Error('Provider {0} could not be bound to the same ' - 'catalog (or catalogs hierarchy) more ' - 'than once'.format(provider)) - self.provider_names[provider] = name - self.providers[name] = provider self.overridden_by = tuple() + self.Bundle = CatalogBundle.sub_cls_factory(self) + + self.bind_providers(providers) + def is_bundle_owner(self, bundle): """Check if catalog is bundle owner.""" return ensure_is_catalog_bundle(bundle) and bundle.catalog is self @@ -136,7 +128,7 @@ class DynamicCatalog(object): """ self.overridden_by += (overriding,) for name, provider in six.iteritems(overriding.providers): - self.get(name).override(provider) + self.get_provider(name).override(provider) def reset_last_overriding(self): """Reset last overriding catalog.""" @@ -152,7 +144,7 @@ class DynamicCatalog(object): for provider in six.itervalues(self.providers): provider.reset_override() - def get(self, name): + def get_provider(self, name): """Return provider with specified name or raise an error.""" try: return self.providers[name] @@ -160,13 +152,37 @@ class DynamicCatalog(object): raise Error('{0} has no provider with such name - {1}'.format( self, name)) - def has(self, name): + def bind_provider(self, name, provider): + """Bind provider to catalog with specified name.""" + provider = ensure_is_provider(provider) + + if name in self.providers: + raise Error('Catalog {0} already has provider with ' + 'such name - {1}'.format(self, name)) + if provider in self.provider_names: + raise Error('Catalog {0} already has such provider ' + 'instance - {1}'.format(self, provider)) + + self.providers[name] = provider + self.provider_names[provider] = name + + def bind_providers(self, providers): + """Bind providers dictionary to catalog.""" + for name, provider in six.iteritems(providers): + self.bind_provider(name, provider) + + def has_provider(self, name): """Check if there is provider with certain name.""" return name in self.providers + def __getattr__(self, name): + """Return provider with specified name or raise en error.""" + return self.get_provider(name) + def __repr__(self): """Return Python representation of catalog.""" - return ''.format(self.name) + return '<{0}({1})>'.format(self.name, + ', '.join(six.iterkeys(self.providers))) __str__ = __repr__ @@ -191,12 +207,16 @@ class DeclarativeCatalogMetaClass(type): providers = cls_providers + inherited_providers cls.name = '.'.join((cls.__module__, cls.__name__)) - cls.catalog = DynamicCatalog(**dict(providers)) + + cls.catalog = DynamicCatalog() cls.catalog.name = cls.name - cls.Bundle = cls.catalog.Bundle + cls.catalog.bind_providers(dict(providers)) + cls.cls_providers = dict(cls_providers) cls.inherited_providers = dict(inherited_providers) + cls.Bundle = cls.catalog.Bundle + return cls @property @@ -316,12 +336,12 @@ class DeclarativeCatalog(object): @classmethod def get(cls, name): """Return provider with specified name or raises error.""" - return cls.catalog.get(name) + return cls.catalog.get_provider(name) @classmethod def has(cls, name): """Check if there is provider with certain name.""" - return cls.catalog.has(name) + return cls.catalog.has_provider(name) # Backward compatibility for versions < 0.11.*