Removes Connection.more_to_read + .nb_keep_draining

This commit is contained in:
Ask Solem 2013-09-30 14:46:25 +01:00
parent ecf1457f13
commit b0234bf840
14 changed files with 92 additions and 122 deletions

View File

@ -37,7 +37,6 @@
.. autoattribute:: connection
.. autoattribute:: uri_prefix
.. autoattribute:: declared_entities
.. autoattribute:: more_to_read
.. autoattribute:: cycle
.. autoattribute:: host
.. autoattribute:: manager
@ -50,7 +49,6 @@
.. automethod:: connect
.. automethod:: channel
.. automethod:: drain_events
.. automethod:: drain_nowait
.. automethod:: release
.. automethod:: autoretry
.. automethod:: ensure_connection

View File

@ -225,7 +225,8 @@ class Hub(object):
self.consolidate.discard(fd)
def _loop(self, propagate=None,
sleep=sleep, min=min, Empty=Empty,
generator=generator, sleep=sleep, min=min, next=next,
Empty=Empty, StopIteration=StopIteration, KeyError=KeyError,
READ=READ, WRITE=WRITE, ERR=ERR):
readers, writers = self.readers, self.writers
poll = self.poller.poll
@ -235,10 +236,20 @@ class Hub(object):
consolidate = self.consolidate
consolidate_callback = self.consolidate_callback
on_tick = self.on_tick
remove_ticks = on_tick.difference_update
while 1:
outdated_ticks = set()
for tick_callback in on_tick:
tick_callback()
try:
if isinstance(tick_callback, generator):
next(tick_callback)
else:
tick_callback()
except StopIteration:
outdated_ticks.add(tick_callback)
remove_ticks(outdated_ticks)
poll_timeout = fire_timers(propagate=propagate) if scheduled else 1
#print('[[[HUB]]]: %s' % (self.repr_active(), ))
if readers or writers:
@ -279,7 +290,7 @@ class Hub(object):
raise
else:
try:
cb(fileno, event, *cbargs)
cb(*cbargs)
except Empty:
pass
if to_consolidate:

View File

@ -7,7 +7,6 @@ Broker connection and pools.
"""
from __future__ import absolute_import
import errno
import os
import socket
@ -29,7 +28,7 @@ from .five import Empty, range, string_t, text_t, LifoQueue as _LifoQueue
from .log import get_logger
from .transport import get_transport_cls, supports_librabbitmq
from .utils import cached_property, retry_over_time, shufflecycle
from .utils.compat import OrderedDict, get_errno
from .utils.compat import OrderedDict
from .utils.functional import lazy
from .utils.url import parse_url
@ -128,10 +127,6 @@ class Connection(object):
#: in case the server loses data.
declared_entities = None
#: This is set to True if there is still more data to read
#: after a call to :meth:`drain_nowait`.
more_to_read = False
#: Iterator returning the next broker URL to try in the event
#: of connection failure (initialized by :attr:`failover_strategy`).
cycle = None
@ -287,38 +282,6 @@ class Connection(object):
"""
return self.transport.drain_events(self.connection, **kwargs)
def drain_nowait(self, *args, **kwargs):
"""Non-blocking version of :meth:`drain_events`.
Sets :attr:`more_to_read` if there is more data to read.
The application MUST call this method until this is unset, and before
calling select/epoll/kqueue's poll() again.
"""
try:
self.drain_events(timeout=0)
except socket.timeout:
self.more_to_read = False
return False
except socket.error as exc:
if get_errno(exc) in (errno.EAGAIN, errno.EINTR):
self.more_to_read = False
return False
raise
self.more_to_read = True
return True
def drain_nowait_all(self, *args, **kwargs):
while 1:
try:
self.drain_events(timeout=0)
except socket.timeout:
break
except socket.error as exc:
if get_errno(exc) in (errno.EGAIN, errno.EINTR):
break
raise
def maybe_close_channel(self, channel):
"""Close given channel, but ignore connection and channel errors."""
try:

View File

@ -6,7 +6,7 @@ from kombu.async.hub import (
maybe_block, is_in_blocking_section,
)
from kombu.tests.case import Case, ContextMock, Mock
from kombu.tests.case import Case, ContextMock
class test_Utils(Case):
@ -63,4 +63,3 @@ class test_Hub(Case):
with self.hub.maybe_block():
self.assertTrue(self.hub.in_blocking_section)
self.assertFalse(self.hub.in_blocking_section)

View File

@ -1,6 +1,5 @@
from __future__ import absolute_import
import errno
import pickle
import socket
@ -261,32 +260,6 @@ class test_Connection(Case):
self.assertEqual(cb(KeyError(), intervals, 0), 0)
self.assertTrue(errback.called)
def test_drain_nowait(self):
c = Connection(transport=Mock)
c.drain_events = Mock()
c.drain_events.side_effect = socket.timeout()
c.more_to_read = True
self.assertFalse(c.drain_nowait())
self.assertFalse(c.more_to_read)
c.drain_events.side_effect = socket.error()
c.drain_events.side_effect.errno = errno.EAGAIN
c.more_to_read = True
self.assertFalse(c.drain_nowait())
self.assertFalse(c.more_to_read)
c.drain_events.side_effect = socket.error()
c.drain_events.side_effect.errno = errno.EPERM
with self.assertRaises(socket.error):
c.drain_nowait()
c.more_to_read = False
c.drain_events = Mock()
self.assertTrue(c.drain_nowait())
c.drain_events.assert_called_with(timeout=0)
self.assertTrue(c.more_to_read)
def test_supports_heartbeats(self):
c = Connection(transport=Mock)
c.transport.supports_heartbeats = False

View File

@ -134,7 +134,7 @@ class test_Transport(lrmqCase):
loop = Mock(name='loop')
self.T.register_with_event_loop(conn, loop)
loop.add_reader.assert_called_with(
conn.fileno(), self.T.client.drain_nowait_all,
conn.fileno(), self.T.on_readable, conn, loop,
)
def test_verify_connection(self):

View File

@ -163,7 +163,7 @@ class test_pyamqp(Case):
loop = Mock(name='loop')
t.register_with_event_loop(conn, loop)
loop.add_reader.assert_called_with(
conn.sock, t.client.drain_nowait_all,
conn.sock, t.on_readable, conn, loop,
)
def test_heartbeat_check(self):

View File

@ -655,26 +655,28 @@ class test_Channel(Case):
on_poll_start()
transport.cycle.on_poll_start.assert_called_with()
loop.add_reader.assert_has_calls([
call(12, transport.handle_event), call(13, transport.handle_event),
call(12, transport.on_readable, 12),
call(13, transport.on_readable, 13),
])
def test_transport_handle_event(self):
def test_transport_on_readable(self):
transport = self.connection.transport
cycle = transport.cycle = Mock(name='cyle')
cycle.handle_event.return_value = None
cycle.on_readable.return_value = None
redis.Transport.handle_event(transport, 13, redis.READ)
cycle.handle_event.assert_called_with(13, redis.READ)
cycle.handle_event.reset_mock()
redis.Transport.on_readable(transport, 13)
cycle.on_readable.assert_called_with(13)
cycle.on_readable.reset_mock()
ret = (Mock(name='message'), Mock(name='queue')), Mock(name='channel')
cycle.handle_event.return_value = ret
queue = Mock(name='queue')
ret = (Mock(name='message'), queue)
cycle.on_readable.return_value = ret
with self.assertRaises(KeyError):
redis.Transport.handle_event(transport, 14, redis.READ)
redis.Transport.on_readable(transport, 14)
cb = transport._callbacks[ret[0][1]] = Mock(name='callback')
redis.Transport.handle_event(transport, 14, redis.READ)
cb.assert_called_with(ret[0][0])
cb = transport._callbacks[queue] = Mock(name='callback')
redis.Transport.on_readable(transport, 14)
cb.assert_called_with(ret[0])
@skip_if_not_module('redis')
def test_transport_get_errors(self):
@ -762,7 +764,7 @@ class test_Redis(Case):
connection = Connection(transport=Transport)
channel = connection.channel()
producer = Producer(channel, self.exchange, routing_key='test_Redis')
consumer = Consumer(channel, self.queue)
consumer = Consumer(channel, queues=[self.queue])
producer.publish({'hello2': 'world2'})
_received = []

View File

@ -311,14 +311,13 @@ class Transport(base.Transport):
)
channel_errors = base.Transport.channel_errors + (AMQPChannelException, )
nb_keep_draining = True
driver_name = "amqplib"
driver_type = "amqp"
driver_name = 'amqplib'
driver_type = 'amqp'
supports_ev = True
def __init__(self, client, **kwargs):
self.client = client
self.default_port = kwargs.get("default_port") or self.default_port
self.default_port = kwargs.get('default_port') or self.default_port
def create_channel(self, connection):
return connection.channel()
@ -370,7 +369,7 @@ class Transport(base.Transport):
def register_with_event_loop(self, connection, loop):
loop.add_reader(connection.method_reader.source.sock,
self.client.drain_nowait_all)
self.on_readable, connection, loop)
@property
def default_connection_params(self):

View File

@ -7,9 +7,13 @@ Base transport interface.
"""
from __future__ import absolute_import
import errno
import socket
from kombu.exceptions import ChannelError, ConnectionError
from kombu.message import Message
from kombu.utils import cached_property
from kombu.utils.compat import get_errno
__all__ = ['Message', 'StdChannel', 'Management', 'Transport']
@ -71,10 +75,6 @@ class Transport(object):
#: Tuple of errors that can happen due to channel/method failure.
channel_errors = (ChannelError, )
#: For non-blocking use, an eventloop should keep
#: draining events as long as ``connection.more_to_read`` is True.
nb_keep_draining = False
#: Type of driver, can be used to separate transports
#: using the AMQP protocol (driver_type: 'amqp'),
#: Redis (driver_type: 'redis'), etc...
@ -90,6 +90,8 @@ class Transport(object):
#: Set to true if the transport supports the AIO interface.
supports_ev = False
__reader = None
def __init__(self, client, **kwargs):
self.client = client
@ -120,6 +122,30 @@ class Transport(object):
def verify_connection(self, connection):
return True
def _reader(self, connection, timeout=socket.timeout, error=socket.error,
get_errno=get_errno, _unavail=(errno.EAGAIN, errno.EINTR)):
drain_events = connection.drain_events
while 1:
try:
yield drain_events(timeout=0)
except timeout:
break
except error as exc:
if get_errno(exc) in _unavail:
break
raise
def on_readable(self, connection, loop):
reader = self.__reader
if reader is None:
reader = self.__reader = self._reader(connection)
try:
next(reader)
except StopIteration:
reader = self.__reader = self._reader(connection)
next(reader, None)
loop.on_tick.add(reader)
@property
def default_connection_params(self):
return {}

View File

@ -83,11 +83,11 @@ class Transport(base.Transport):
driver_name = 'librabbitmq'
supports_ev = True
nb_keep_draining = True
def __init__(self, client, **kwargs):
self.client = client
self.default_port = kwargs.get('default_port') or self.default_port
self.__reader = None
def driver_version(self):
return amqp.__version__
@ -143,7 +143,9 @@ class Transport(base.Transport):
return connection.connected
def register_with_event_loop(self, connection, loop):
loop.add_reader(connection.fileno(), self.client.drain_nowait_all)
loop.add_reader(
connection.fileno(), self.on_readable, connection, loop,
)
def get_manager(self, *args, **kwargs):
return get_manager(self.client, *args, **kwargs)

View File

@ -71,7 +71,6 @@ class Transport(base.Transport):
amqp.Connection.recoverable_connection_errors
recoverable_channel_errors = amqp.Connection.recoverable_channel_errors
nb_keep_draining = True
driver_name = 'py-amqp'
driver_type = 'amqp'
supports_heartbeats = True
@ -119,7 +118,7 @@ class Transport(base.Transport):
connection.close()
def register_with_event_loop(self, connection, loop):
loop.add_reader(connection.sock, self.client.drain_nowait_all)
loop.add_reader(connection.sock, self.on_readable, connection, loop)
def heartbeat_check(self, connection, rate=2):
return connection.heartbeat_tick(rate=rate)

View File

@ -124,7 +124,7 @@ class QoS(virtual.QoS):
def reject(self, delivery_tag, requeue=False):
if requeue:
self.restore_by_tag(tag, leftmost=True)
self.restore_by_tag(delivery_tag, leftmost=True)
self.ack(delivery_tag)
@contextmanager
@ -271,11 +271,14 @@ class MultiChannelPoller(object):
num=channel.unacked_restore_limit,
)
def on_readable(self, fileno):
chan, type = self._fd_to_chan[fileno]
if chan.qos.can_consume():
return chan.handlers[type]()
def handle_event(self, fileno, event):
if event & READ:
chan, type = self._fd_to_chan[fileno]
if chan.qos.can_consume():
return chan.handlers[type](), self
return self.on_readable(fileno), self
elif event & ERR:
chan, type = self._fd_to_chan[fileno]
chan._poll_error(type)
@ -763,19 +766,18 @@ class Transport(virtual.Transport):
cycle.on_poll_init(loop.poller)
cycle_poll_start = cycle.on_poll_start
add_reader = loop.add_reader
handle_event = self.handle_event
on_readable = self.on_readable
def on_poll_start():
cycle_poll_start()
[add_reader(fd, handle_event) for fd in cycle.fds]
[add_reader(fd, on_readable, fd) for fd in cycle.fds]
loop.on_tick.add(on_poll_start)
loop.call_repeatedly(10, cycle.maybe_restore_messages)
def handle_event(self, fileno, event):
def on_readable(self, fileno):
"""Handle AIO event for one of our file descriptors."""
ret = self.cycle.handle_event(fileno, event)
if ret:
item, channel = ret
item = self.cycle.on_readable(fileno)
if item:
message, queue = item
if not queue or queue not in self._callbacks:
raise KeyError(

View File

@ -71,16 +71,16 @@ class MultiChannelPoller(object):
for channel in self._channels:
self._register(channel)
def handle_event(self, fileno, event):
def on_readable(self, fileno):
chan = self._fd_to_chan[fileno]
return (chan.drain_events(), chan)
return chan.drain_events(), chan
def get(self, timeout=None):
self.on_poll_start()
events = self.poller.poll(timeout)
for fileno, event in events or []:
return self.handle_event(fileno, event)
for fileno, _ in events or []:
return self.on_readable(fileno)
raise Empty()
@ -238,7 +238,6 @@ class Transport(virtual.Transport):
supports_ev = True
polling_interval = None
nb_keep_draining = True
def __init__(self, *args, **kwargs):
if zmq is None:
@ -253,21 +252,18 @@ class Transport(virtual.Transport):
cycle = self.cycle
cycle.poller = loop.poller
add_reader = loop.add_reader
handle_event = self.handle_event
on_readable = self.on_readable
cycle_poll_start = cycle.on_poll_start
def on_poll_start():
cycle_poll_start()
[add_reader(fd, handle_event) for fd in cycle.fds]
for fd in cycle.fds:
add_reader(fd, handle_event)
[add_reader(fd, on_readable, fd) for fd in cycle.fds]
loop.on_tick.add(on_poll_start)
def handle_event(self, fileno, event):
evt = self.cycle.handle_event(fileno, event)
self._handle_event(evt)
def on_readable(self, fileno):
self._handle_event(self.cycle.on_readable(fileno))
def drain_events(self, connection, timeout=None):
more_to_read = False