diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py index 95e539ac..d27cb375 100644 --- a/kombu/transport/virtual/base.py +++ b/kombu/transport/virtual/base.py @@ -462,14 +462,7 @@ class Channel(AbstractChannel, base.StdChannel): typ: cls(self) for typ, cls in self.exchange_types.items() } - try: - self.channel_id = self.connection._avail_channel_ids.pop() - except IndexError: - raise ResourceError( - 'No free channel ids, current={}, channel_max={}'.format( - len(self.connection.channels), - self.connection.channel_max), (20, 10), - ) + self.channel_id = self._get_free_channel_id() topts = self.connection.client.transport_options for opt_name in self.from_transport_options: @@ -844,6 +837,22 @@ class Channel(AbstractChannel, base.StdChannel): return (self.max_priority - priority) if reverse else priority + def _get_free_channel_id(self): + # Cast to a set for fast lookups, and keep stored as an array + # for lower memory usage. + used_channel_ids = set(self.connection._used_channel_ids) + + for channel_id in range(1, self.connection.channel_max + 1): + if channel_id not in used_channel_ids: + self.connection._used_channel_ids.append(channel_id) + return channel_id + + raise ResourceError( + 'No free channel ids, current={}, channel_max={}'.format( + len(self.connection.channels), + self.connection.channel_max), (20, 10), + ) + class Management(base.Management): """Base class for the AMQP management API.""" @@ -907,9 +916,7 @@ class Transport(base.Transport): polling_interval = client.transport_options.get('polling_interval') if polling_interval is not None: self.polling_interval = polling_interval - self._avail_channel_ids = array( - ARRAY_TYPE_H, range(self.channel_max, 0, -1), - ) + self._used_channel_ids = array(ARRAY_TYPE_H) def create_channel(self, connection): try: @@ -921,7 +928,11 @@ class Transport(base.Transport): def close_channel(self, channel): try: - self._avail_channel_ids.append(channel.channel_id) + try: + self._used_channel_ids.remove(channel.channel_id) + except ValueError: + # channel id already removed + pass try: self.channels.remove(channel) except ValueError: diff --git a/t/unit/transport/test_consul.py b/t/unit/transport/test_consul.py index ce6c4fcb..d77b7a41 100644 --- a/t/unit/transport/test_consul.py +++ b/t/unit/transport/test_consul.py @@ -1,3 +1,4 @@ +from array import array from queue import Empty from unittest.mock import Mock @@ -12,6 +13,8 @@ class test_Consul: def setup(self): self.connection = Mock() + self.connection._used_channel_ids = array('H') + self.connection.channel_max = 65535 self.connection.client.transport_options = {} self.connection.client.port = 303 self.consul = self.patching('consul.Consul').return_value diff --git a/t/unit/transport/virtual/test_base.py b/t/unit/transport/virtual/test_base.py index 681841a0..a5685ab9 100644 --- a/t/unit/transport/virtual/test_base.py +++ b/t/unit/transport/virtual/test_base.py @@ -1,6 +1,7 @@ import io import socket import warnings +from array import array from time import monotonic from unittest.mock import MagicMock, Mock, patch @@ -178,13 +179,19 @@ class test_Channel: if self.channel._qos is not None: self.channel._qos._on_collect.cancel() - def test_exceeds_channel_max(self): - c = client() - t = c.transport - avail = t._avail_channel_ids = Mock(name='_avail_channel_ids') - avail.pop.side_effect = IndexError() + def test_get_free_channel_id(self): + conn = client() + channel = conn.channel() + assert channel.channel_id == 1 + assert channel._get_free_channel_id() == 2 + + def test_get_free_channel_id__exceeds_channel_max(self): + conn = client() + conn.transport.channel_max = 2 + channel = conn.channel() + channel._get_free_channel_id() with pytest.raises(ResourceError): - virtual.Channel(t) + channel._get_free_channel_id() def test_exchange_bind_interface(self): with pytest.raises(NotImplementedError): @@ -577,6 +584,23 @@ class test_Transport: del(c1) # so pyflakes doesn't complain del(c2) + def test_create_channel(self): + """Ensure create_channel can create channels successfully.""" + assert self.transport.channels == [] + created_channel = self.transport.create_channel(self.transport) + assert self.transport.channels == [created_channel] + + def test_close_channel(self): + """Ensure close_channel actually removes the channel and updates + _used_channel_ids. + """ + assert self.transport._used_channel_ids == array('H') + created_channel = self.transport.create_channel(self.transport) + assert self.transport._used_channel_ids == array('H', (1,)) + self.transport.close_channel(created_channel) + assert self.transport.channels == [] + assert self.transport._used_channel_ids == array('H') + def test_drain_channel(self): channel = self.transport.create_channel(self.transport) with pytest.raises(virtual.Empty):