From 9c062bdca541911a99a110be01a650ae473e6135 Mon Sep 17 00:00:00 2001 From: Paul Brown Date: Thu, 30 Dec 2021 06:28:11 +0000 Subject: [PATCH] allow getting recoverable_connection_errors without an active transport (#1471) * allow getting recoverable_connection_errors without an active transport * move redis transport errors to class * move consul transport errors to class * move etcd transport errors to class * remove redis.Transport._get_errors and references in tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix flake8 errors * add integration test for redis ConnectionError Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- kombu/connection.py | 8 +-- kombu/transport/consul.py | 17 +++--- kombu/transport/etcd.py | 17 +++--- kombu/transport/redis.py | 9 ++-- t/integration/test_redis.py | 7 ++- t/unit/test_connection.py | 94 ++++++++++++++++++++++++++++++++-- t/unit/transport/test_redis.py | 20 +++++--- 7 files changed, 135 insertions(+), 37 deletions(-) diff --git a/kombu/connection.py b/kombu/connection.py index e4558e92..4b2cbd62 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -932,7 +932,7 @@ class Connection: but where the connection must be closed and re-established first. """ try: - return self.transport.recoverable_connection_errors + return self.get_transport_cls().recoverable_connection_errors except AttributeError: # There were no such classification before, # and all errors were assumed to be recoverable, @@ -948,19 +948,19 @@ class Connection: recovered from without re-establishing the connection. """ try: - return self.transport.recoverable_channel_errors + return self.get_transport_cls().recoverable_channel_errors except AttributeError: return () @cached_property def connection_errors(self): """List of exceptions that may be raised by the connection.""" - return self.transport.connection_errors + return self.get_transport_cls().connection_errors @cached_property def channel_errors(self): """List of exceptions that may be raised by the channel.""" - return self.transport.channel_errors + return self.get_transport_cls().channel_errors @property def supports_heartbeats(self): diff --git a/kombu/transport/consul.py b/kombu/transport/consul.py index ea275c95..4369b1f8 100644 --- a/kombu/transport/consul.py +++ b/kombu/transport/consul.py @@ -276,24 +276,25 @@ class Transport(virtual.Transport): driver_type = 'consul' driver_name = 'consul' - def __init__(self, *args, **kwargs): - if consul is None: - raise ImportError('Missing python-consul library') - - super().__init__(*args, **kwargs) - - self.connection_errors = ( + if consul: + connection_errors = ( virtual.Transport.connection_errors + ( consul.ConsulException, consul.base.ConsulException ) ) - self.channel_errors = ( + channel_errors = ( virtual.Transport.channel_errors + ( consul.ConsulException, consul.base.ConsulException ) ) + def __init__(self, *args, **kwargs): + if consul is None: + raise ImportError('Missing python-consul library') + + super().__init__(*args, **kwargs) + def verify_connection(self, connection): port = connection.client.port or self.default_port host = connection.client.hostname or DEFAULT_HOST diff --git a/kombu/transport/etcd.py b/kombu/transport/etcd.py index 4d0b0364..ba9e4150 100644 --- a/kombu/transport/etcd.py +++ b/kombu/transport/etcd.py @@ -242,6 +242,15 @@ class Transport(virtual.Transport): implements = virtual.Transport.implements.extend( exchange_type=frozenset(['direct'])) + if etcd: + connection_errors = ( + virtual.Transport.connection_errors + (etcd.EtcdException, ) + ) + + channel_errors = ( + virtual.Transport.channel_errors + (etcd.EtcdException, ) + ) + def __init__(self, *args, **kwargs): """Create a new instance of etcd.Transport.""" if etcd is None: @@ -249,14 +258,6 @@ class Transport(virtual.Transport): super().__init__(*args, **kwargs) - self.connection_errors = ( - virtual.Transport.connection_errors + (etcd.EtcdException, ) - ) - - self.channel_errors = ( - virtual.Transport.channel_errors + (etcd.EtcdException, ) - ) - def verify_connection(self, connection): """Verify the connection works.""" port = connection.client.port or self.default_port diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index 090b411d..45558641 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -1214,13 +1214,14 @@ class Transport(virtual.Transport): exchange_type=frozenset(['direct', 'topic', 'fanout']) ) + if redis: + connection_errors, channel_errors = get_redis_error_classes() + def __init__(self, *args, **kwargs): if redis is None: raise ImportError('Missing redis library (pip install redis)') super().__init__(*args, **kwargs) - # Get redis-py exceptions. - self.connection_errors, self.channel_errors = self._get_errors() # All channels share the same poller. self.cycle = MultiChannelPoller() @@ -1265,10 +1266,6 @@ class Transport(virtual.Transport): """Handle AIO event for one of our file descriptors.""" self.cycle.on_readable(fileno) - def _get_errors(self): - """Utility to import redis-py's exceptions at runtime.""" - return get_redis_error_classes() - if sentinel: class SentinelManagedSSLConnection( diff --git a/t/integration/test_redis.py b/t/integration/test_redis.py index 5647bf06..522adf8e 100644 --- a/t/integration/test_redis.py +++ b/t/integration/test_redis.py @@ -5,6 +5,7 @@ import pytest import redis import kombu +from kombu.transport.redis import Transport from .common import (BaseExchangeTypes, BaseMessage, BasePriority, BasicFunctionality) @@ -56,7 +57,11 @@ def test_failed_credentials(): @pytest.mark.env('redis') @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_RedisBasicFunctionality(BasicFunctionality): - pass + def test_failed_connection__ConnectionError(self, invalid_connection): + # method raises transport exception + with pytest.raises(redis.exceptions.ConnectionError) as ex: + invalid_connection.connection + assert ex.type in Transport.connection_errors @pytest.mark.env('redis') diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 83f233cf..b65416e8 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -293,7 +293,9 @@ class test_Connection: assert not c.is_evented def test_register_with_event_loop(self): - c = Connection(transport=Mock) + transport = Mock(name='transport') + transport.connection_errors = [] + c = Connection(transport=transport) loop = Mock(name='loop') c.register_with_event_loop(loop) c.transport.register_with_event_loop.assert_called_with( @@ -477,7 +479,7 @@ class test_Connection: def publish(): raise _ConnectionError('failed connection') - self.conn.transport.connection_errors = (_ConnectionError,) + self.conn.get_transport_cls().connection_errors = (_ConnectionError,) ensured = self.conn.ensure(self.conn, publish) with pytest.raises(OperationalError): ensured() @@ -485,7 +487,7 @@ class test_Connection: def test_autoretry(self): myfun = Mock() - self.conn.transport.connection_errors = (KeyError,) + self.conn.get_transport_cls().connection_errors = (KeyError,) def on_call(*args, **kwargs): myfun.side_effect = None @@ -571,6 +573,18 @@ class test_Connection: conn = Connection(transport=MyTransport) assert conn.channel_errors == (KeyError, ValueError) + def test_channel_errors__exception_no_cache(self): + """Ensure the channel_errors can be retrieved without an initialized + transport. + """ + + class MyTransport(Transport): + channel_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.channel_errors == (KeyError,) + def test_connection_errors(self): class MyTransport(Transport): @@ -579,6 +593,80 @@ class test_Connection: conn = Connection(transport=MyTransport) assert conn.connection_errors == (KeyError, ValueError) + def test_connection_errors__exception_no_cache(self): + """Ensure the connection_errors can be retrieved without an + initialized transport. + """ + + class MyTransport(Transport): + connection_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.connection_errors == (KeyError,) + + def test_recoverable_connection_errors(self): + + class MyTransport(Transport): + recoverable_connection_errors = (KeyError, ValueError) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_connection_errors == (KeyError, ValueError) + + def test_recoverable_connection_errors__fallback(self): + """Ensure missing recoverable_connection_errors on the Transport does + not cause a fatal error. + """ + + class MyTransport(Transport): + connection_errors = (KeyError,) + channel_errors = (ValueError,) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_connection_errors == (KeyError, ValueError) + + def test_recoverable_connection_errors__exception_no_cache(self): + """Ensure the recoverable_connection_errors can be retrieved without + an initialized transport. + """ + + class MyTransport(Transport): + recoverable_connection_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.recoverable_connection_errors == (KeyError,) + + def test_recoverable_channel_errors(self): + + class MyTransport(Transport): + recoverable_channel_errors = (KeyError, ValueError) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_channel_errors == (KeyError, ValueError) + + def test_recoverable_channel_errors__fallback(self): + """Ensure missing recoverable_channel_errors on the Transport does not + cause a fatal error. + """ + + class MyTransport(Transport): + pass + + conn = Connection(transport=MyTransport) + assert conn.recoverable_channel_errors == () + + def test_recoverable_channel_errors__exception_no_cache(self): + """Ensure the recoverable_channel_errors can be retrieved without an + initialized transport. + """ + class MyTransport(Transport): + recoverable_channel_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.recoverable_channel_errors == (KeyError,) + def test_multiple_urls_hostname(self): conn = Connection(['example.com;amqp://example.com']) assert conn.as_uri() == 'amqp://guest:**@example.com:5672//' diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index 1daaf738..7046cf7b 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -269,9 +269,8 @@ class Channel(redis.Channel): class Transport(redis.Transport): Channel = Channel - - def _get_errors(self): - return ((KeyError,), (IndexError,)) + connection_errors = (KeyError,) + channel_errors = (IndexError,) class test_Channel: @@ -907,15 +906,22 @@ class test_Channel: redis.Transport.on_readable(transport, 13) cycle.on_readable.assert_called_with(13) - def test_transport_get_errors(self): - assert redis.Transport._get_errors(self.connection.transport) + def test_transport_connection_errors(self): + """Ensure connection_errors are populated.""" + assert redis.Transport.connection_errors + + def test_transport_channel_errors(self): + """Ensure connection_errors are populated.""" + assert redis.Transport.channel_errors def test_transport_driver_version(self): assert redis.Transport.driver_version(self.connection.transport) - def test_transport_get_errors_when_InvalidData_used(self): + def test_transport_errors_when_InvalidData_used(self): from redis import exceptions + from kombu.transport.redis import get_redis_error_classes + class ID(Exception): pass @@ -924,7 +930,7 @@ class test_Channel: exceptions.InvalidData = ID exceptions.DataError = None try: - errors = redis.Transport._get_errors(self.connection.transport) + errors = get_redis_error_classes() assert errors assert ID in errors[1] finally: