diff --git a/dependency_injector/catalogs.py b/dependency_injector/catalogs.py index 41a5b5ea..75af345d 100644 --- a/dependency_injector/catalogs.py +++ b/dependency_injector/catalogs.py @@ -166,8 +166,18 @@ class DynamicCatalog(object): :type: tuple[ :py:class:`DeclarativeCatalog` | :py:class:`DynamicCatalog`] + + .. py:attribute:: provider_type + + If provider type is defined, :py:class:`DynamicCatalog` checks that + all of its providers are instances of + :py:attr:`DynamicCatalog.provider_type`. + + :type: type | None """ + provider_type = None + __IS_CATALOG__ = True __slots__ = ('name', 'providers', 'provider_names', 'overridden_by', 'Bundle') @@ -324,6 +334,11 @@ class DynamicCatalog(object): """ provider = ensure_is_provider(provider) + if (self.__class__.provider_type and + not isinstance(provider, self.__class__.provider_type)): + raise Error('{0} can contaon only {1} instances'.format( + self.name, self.__class__.provider_type)) + if name in self.providers: raise Error('Catalog {0} already has provider with ' 'such name - {1}'.format(self, name)) @@ -443,7 +458,13 @@ class DeclarativeCatalogMetaClass(type): cls = type.__new__(mcs, class_name, bases, attributes) - cls._catalog = DynamicCatalog() + if cls.provider_type: + cls._catalog = type('DynamicCatalog', + (DynamicCatalog,), + dict(provider_type=cls.provider_type))() + else: + cls._catalog = DynamicCatalog() + cls._catalog.name = '.'.join((cls.__module__, cls.__name__)) cls._catalog.bind_providers(dict(providers)) @@ -625,6 +646,14 @@ class DeclarativeCatalog(object): :type: :py:class:`DeclarativeCatalog` | :py:class:`DynamicCatalog` | None + + .. py:attribute:: provider_type + + If provider type is defined, :py:class:`DeclarativeCatalog` checks that + all of its providers are instances of + :py:attr:`DeclarativeCatalog.provider_type`. + + :type: type | None """ Bundle = CatalogBundle @@ -639,6 +668,8 @@ class DeclarativeCatalog(object): is_overridden = bool last_overriding = None + provider_type = None + _catalog = DynamicCatalog __IS_CATALOG__ = True diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index 3ba48e20..740365c3 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -131,6 +131,51 @@ class DynamicCatalogTests(unittest.TestCase): self.assertIs(self.catalog.py, py) self.assertIs(self.catalog.get_provider('py'), py) + def test_bind_provider_with_valid_provided_type(self): + """Test setting of provider with provider type restriction.""" + class SomeProvider(providers.Provider): + """Some provider.""" + + class SomeCatalog(catalogs.DynamicCatalog): + """Some catalog with provider type restriction.""" + + provider_type = SomeProvider + + px = SomeProvider() + py = SomeProvider() + catalog = SomeCatalog() + + catalog.bind_provider('px', px) + catalog.py = py + + self.assertIs(catalog.px, px) + self.assertIs(catalog.get_provider('px'), px) + + self.assertIs(catalog.py, py) + self.assertIs(catalog.get_provider('py'), py) + + def test_bind_provider_with_invalid_provided_type(self): + """Test setting of provider with provider type restriction.""" + class SomeProvider(providers.Provider): + """Some provider.""" + + class SomeCatalog(catalogs.DynamicCatalog): + """Some catalog with provider type restriction.""" + + provider_type = SomeProvider + + px = providers.Provider() + catalog = SomeCatalog() + + with self.assertRaises(errors.Error): + catalog.bind_provider('px', px) + + with self.assertRaises(errors.Error): + catalog.px = px + + with self.assertRaises(errors.Error): + catalog.bind_providers(dict(px=px)) + def test_bind_providers(self): """Test setting of provider via bind_providers() to catalog.""" px = providers.Provider() @@ -289,6 +334,49 @@ class DeclarativeCatalogTests(unittest.TestCase): del CatalogA.px del CatalogA.py + def test_bind_provider_with_valid_provided_type(self): + """Test setting of provider with provider type restriction.""" + class SomeProvider(providers.Provider): + """Some provider.""" + + class SomeCatalog(catalogs.DeclarativeCatalog): + """Some catalog with provider type restriction.""" + + provider_type = SomeProvider + + px = SomeProvider() + py = SomeProvider() + + SomeCatalog.bind_provider('px', px) + SomeCatalog.py = py + + self.assertIs(SomeCatalog.px, px) + self.assertIs(SomeCatalog.get_provider('px'), px) + + self.assertIs(SomeCatalog.py, py) + self.assertIs(SomeCatalog.get_provider('py'), py) + + def test_bind_provider_with_invalid_provided_type(self): + """Test setting of provider with provider type restriction.""" + class SomeProvider(providers.Provider): + """Some provider.""" + + class SomeCatalog(catalogs.DeclarativeCatalog): + """Some catalog with provider type restriction.""" + + provider_type = SomeProvider + + px = providers.Provider() + + with self.assertRaises(errors.Error): + SomeCatalog.bind_provider('px', px) + + with self.assertRaises(errors.Error): + SomeCatalog.px = px + + with self.assertRaises(errors.Error): + SomeCatalog.bind_providers(dict(px=px)) + def test_bind_providers(self): """Test setting of provider via bind_providers() to catalog.""" px = providers.Provider()