allow getting recoverable_connection_errors without an active transport (#1471)

* allow getting recoverable_connection_errors without an active transport

* move redis transport errors to class

* move consul transport errors to class

* move etcd transport errors to class

* remove redis.Transport._get_errors and references in tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix flake8 errors

* add integration test for redis ConnectionError

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Paul Brown 2021-12-30 06:28:11 +00:00 committed by GitHub
parent b6b4408575
commit 9c062bdca5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 135 additions and 37 deletions

View File

@ -932,7 +932,7 @@ class Connection:
but where the connection must be closed and re-established first.
"""
try:
return self.transport.recoverable_connection_errors
return self.get_transport_cls().recoverable_connection_errors
except AttributeError:
# There were no such classification before,
# and all errors were assumed to be recoverable,
@ -948,19 +948,19 @@ class Connection:
recovered from without re-establishing the connection.
"""
try:
return self.transport.recoverable_channel_errors
return self.get_transport_cls().recoverable_channel_errors
except AttributeError:
return ()
@cached_property
def connection_errors(self):
"""List of exceptions that may be raised by the connection."""
return self.transport.connection_errors
return self.get_transport_cls().connection_errors
@cached_property
def channel_errors(self):
"""List of exceptions that may be raised by the channel."""
return self.transport.channel_errors
return self.get_transport_cls().channel_errors
@property
def supports_heartbeats(self):

View File

@ -276,24 +276,25 @@ class Transport(virtual.Transport):
driver_type = 'consul'
driver_name = 'consul'
def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')
super().__init__(*args, **kwargs)
self.connection_errors = (
if consul:
connection_errors = (
virtual.Transport.connection_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
self.channel_errors = (
channel_errors = (
virtual.Transport.channel_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')
super().__init__(*args, **kwargs)
def verify_connection(self, connection):
port = connection.client.port or self.default_port
host = connection.client.hostname or DEFAULT_HOST

View File

@ -242,6 +242,15 @@ class Transport(virtual.Transport):
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct']))
if etcd:
connection_errors = (
virtual.Transport.connection_errors + (etcd.EtcdException, )
)
channel_errors = (
virtual.Transport.channel_errors + (etcd.EtcdException, )
)
def __init__(self, *args, **kwargs):
"""Create a new instance of etcd.Transport."""
if etcd is None:
@ -249,14 +258,6 @@ class Transport(virtual.Transport):
super().__init__(*args, **kwargs)
self.connection_errors = (
virtual.Transport.connection_errors + (etcd.EtcdException, )
)
self.channel_errors = (
virtual.Transport.channel_errors + (etcd.EtcdException, )
)
def verify_connection(self, connection):
"""Verify the connection works."""
port = connection.client.port or self.default_port

View File

@ -1214,13 +1214,14 @@ class Transport(virtual.Transport):
exchange_type=frozenset(['direct', 'topic', 'fanout'])
)
if redis:
connection_errors, channel_errors = get_redis_error_classes()
def __init__(self, *args, **kwargs):
if redis is None:
raise ImportError('Missing redis library (pip install redis)')
super().__init__(*args, **kwargs)
# Get redis-py exceptions.
self.connection_errors, self.channel_errors = self._get_errors()
# All channels share the same poller.
self.cycle = MultiChannelPoller()
@ -1265,10 +1266,6 @@ class Transport(virtual.Transport):
"""Handle AIO event for one of our file descriptors."""
self.cycle.on_readable(fileno)
def _get_errors(self):
"""Utility to import redis-py's exceptions at runtime."""
return get_redis_error_classes()
if sentinel:
class SentinelManagedSSLConnection(

View File

@ -5,6 +5,7 @@ import pytest
import redis
import kombu
from kombu.transport.redis import Transport
from .common import (BaseExchangeTypes, BaseMessage, BasePriority,
BasicFunctionality)
@ -56,7 +57,11 @@ def test_failed_credentials():
@pytest.mark.env('redis')
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_RedisBasicFunctionality(BasicFunctionality):
pass
def test_failed_connection__ConnectionError(self, invalid_connection):
# method raises transport exception
with pytest.raises(redis.exceptions.ConnectionError) as ex:
invalid_connection.connection
assert ex.type in Transport.connection_errors
@pytest.mark.env('redis')

View File

@ -293,7 +293,9 @@ class test_Connection:
assert not c.is_evented
def test_register_with_event_loop(self):
c = Connection(transport=Mock)
transport = Mock(name='transport')
transport.connection_errors = []
c = Connection(transport=transport)
loop = Mock(name='loop')
c.register_with_event_loop(loop)
c.transport.register_with_event_loop.assert_called_with(
@ -477,7 +479,7 @@ class test_Connection:
def publish():
raise _ConnectionError('failed connection')
self.conn.transport.connection_errors = (_ConnectionError,)
self.conn.get_transport_cls().connection_errors = (_ConnectionError,)
ensured = self.conn.ensure(self.conn, publish)
with pytest.raises(OperationalError):
ensured()
@ -485,7 +487,7 @@ class test_Connection:
def test_autoretry(self):
myfun = Mock()
self.conn.transport.connection_errors = (KeyError,)
self.conn.get_transport_cls().connection_errors = (KeyError,)
def on_call(*args, **kwargs):
myfun.side_effect = None
@ -571,6 +573,18 @@ class test_Connection:
conn = Connection(transport=MyTransport)
assert conn.channel_errors == (KeyError, ValueError)
def test_channel_errors__exception_no_cache(self):
"""Ensure the channel_errors can be retrieved without an initialized
transport.
"""
class MyTransport(Transport):
channel_errors = (KeyError,)
conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.channel_errors == (KeyError,)
def test_connection_errors(self):
class MyTransport(Transport):
@ -579,6 +593,80 @@ class test_Connection:
conn = Connection(transport=MyTransport)
assert conn.connection_errors == (KeyError, ValueError)
def test_connection_errors__exception_no_cache(self):
"""Ensure the connection_errors can be retrieved without an
initialized transport.
"""
class MyTransport(Transport):
connection_errors = (KeyError,)
conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.connection_errors == (KeyError,)
def test_recoverable_connection_errors(self):
class MyTransport(Transport):
recoverable_connection_errors = (KeyError, ValueError)
conn = Connection(transport=MyTransport)
assert conn.recoverable_connection_errors == (KeyError, ValueError)
def test_recoverable_connection_errors__fallback(self):
"""Ensure missing recoverable_connection_errors on the Transport does
not cause a fatal error.
"""
class MyTransport(Transport):
connection_errors = (KeyError,)
channel_errors = (ValueError,)
conn = Connection(transport=MyTransport)
assert conn.recoverable_connection_errors == (KeyError, ValueError)
def test_recoverable_connection_errors__exception_no_cache(self):
"""Ensure the recoverable_connection_errors can be retrieved without
an initialized transport.
"""
class MyTransport(Transport):
recoverable_connection_errors = (KeyError,)
conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.recoverable_connection_errors == (KeyError,)
def test_recoverable_channel_errors(self):
class MyTransport(Transport):
recoverable_channel_errors = (KeyError, ValueError)
conn = Connection(transport=MyTransport)
assert conn.recoverable_channel_errors == (KeyError, ValueError)
def test_recoverable_channel_errors__fallback(self):
"""Ensure missing recoverable_channel_errors on the Transport does not
cause a fatal error.
"""
class MyTransport(Transport):
pass
conn = Connection(transport=MyTransport)
assert conn.recoverable_channel_errors == ()
def test_recoverable_channel_errors__exception_no_cache(self):
"""Ensure the recoverable_channel_errors can be retrieved without an
initialized transport.
"""
class MyTransport(Transport):
recoverable_channel_errors = (KeyError,)
conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.recoverable_channel_errors == (KeyError,)
def test_multiple_urls_hostname(self):
conn = Connection(['example.com;amqp://example.com'])
assert conn.as_uri() == 'amqp://guest:**@example.com:5672//'

View File

@ -269,9 +269,8 @@ class Channel(redis.Channel):
class Transport(redis.Transport):
Channel = Channel
def _get_errors(self):
return ((KeyError,), (IndexError,))
connection_errors = (KeyError,)
channel_errors = (IndexError,)
class test_Channel:
@ -907,15 +906,22 @@ class test_Channel:
redis.Transport.on_readable(transport, 13)
cycle.on_readable.assert_called_with(13)
def test_transport_get_errors(self):
assert redis.Transport._get_errors(self.connection.transport)
def test_transport_connection_errors(self):
"""Ensure connection_errors are populated."""
assert redis.Transport.connection_errors
def test_transport_channel_errors(self):
"""Ensure connection_errors are populated."""
assert redis.Transport.channel_errors
def test_transport_driver_version(self):
assert redis.Transport.driver_version(self.connection.transport)
def test_transport_get_errors_when_InvalidData_used(self):
def test_transport_errors_when_InvalidData_used(self):
from redis import exceptions
from kombu.transport.redis import get_redis_error_classes
class ID(Exception):
pass
@ -924,7 +930,7 @@ class test_Channel:
exceptions.InvalidData = ID
exceptions.DataError = None
try:
errors = redis.Transport._get_errors(self.connection.transport)
errors = get_redis_error_classes()
assert errors
assert ID in errors[1]
finally: