reduce memory usage of Transport (#1470)

* reduce memory usage of Transport

* fix flake8 errors

* move channel_id login into _get_free_channel_id
This commit is contained in:
Paul Brown 2021-12-23 14:41:30 +00:00 committed by GitHub
parent 4a6e1647b5
commit 507b306400
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 18 deletions

View File

@ -462,14 +462,7 @@ class Channel(AbstractChannel, base.StdChannel):
typ: cls(self) for typ, cls in self.exchange_types.items()
}
try:
self.channel_id = self.connection._avail_channel_ids.pop()
except IndexError:
raise ResourceError(
'No free channel ids, current={}, channel_max={}'.format(
len(self.connection.channels),
self.connection.channel_max), (20, 10),
)
self.channel_id = self._get_free_channel_id()
topts = self.connection.client.transport_options
for opt_name in self.from_transport_options:
@ -844,6 +837,22 @@ class Channel(AbstractChannel, base.StdChannel):
return (self.max_priority - priority) if reverse else priority
def _get_free_channel_id(self):
# Cast to a set for fast lookups, and keep stored as an array
# for lower memory usage.
used_channel_ids = set(self.connection._used_channel_ids)
for channel_id in range(1, self.connection.channel_max + 1):
if channel_id not in used_channel_ids:
self.connection._used_channel_ids.append(channel_id)
return channel_id
raise ResourceError(
'No free channel ids, current={}, channel_max={}'.format(
len(self.connection.channels),
self.connection.channel_max), (20, 10),
)
class Management(base.Management):
"""Base class for the AMQP management API."""
@ -907,9 +916,7 @@ class Transport(base.Transport):
polling_interval = client.transport_options.get('polling_interval')
if polling_interval is not None:
self.polling_interval = polling_interval
self._avail_channel_ids = array(
ARRAY_TYPE_H, range(self.channel_max, 0, -1),
)
self._used_channel_ids = array(ARRAY_TYPE_H)
def create_channel(self, connection):
try:
@ -921,7 +928,11 @@ class Transport(base.Transport):
def close_channel(self, channel):
try:
self._avail_channel_ids.append(channel.channel_id)
try:
self._used_channel_ids.remove(channel.channel_id)
except ValueError:
# channel id already removed
pass
try:
self.channels.remove(channel)
except ValueError:

View File

@ -1,3 +1,4 @@
from array import array
from queue import Empty
from unittest.mock import Mock
@ -12,6 +13,8 @@ class test_Consul:
def setup(self):
self.connection = Mock()
self.connection._used_channel_ids = array('H')
self.connection.channel_max = 65535
self.connection.client.transport_options = {}
self.connection.client.port = 303
self.consul = self.patching('consul.Consul').return_value

View File

@ -1,6 +1,7 @@
import io
import socket
import warnings
from array import array
from time import monotonic
from unittest.mock import MagicMock, Mock, patch
@ -178,13 +179,19 @@ class test_Channel:
if self.channel._qos is not None:
self.channel._qos._on_collect.cancel()
def test_exceeds_channel_max(self):
c = client()
t = c.transport
avail = t._avail_channel_ids = Mock(name='_avail_channel_ids')
avail.pop.side_effect = IndexError()
def test_get_free_channel_id(self):
conn = client()
channel = conn.channel()
assert channel.channel_id == 1
assert channel._get_free_channel_id() == 2
def test_get_free_channel_id__exceeds_channel_max(self):
conn = client()
conn.transport.channel_max = 2
channel = conn.channel()
channel._get_free_channel_id()
with pytest.raises(ResourceError):
virtual.Channel(t)
channel._get_free_channel_id()
def test_exchange_bind_interface(self):
with pytest.raises(NotImplementedError):
@ -577,6 +584,23 @@ class test_Transport:
del(c1) # so pyflakes doesn't complain
del(c2)
def test_create_channel(self):
"""Ensure create_channel can create channels successfully."""
assert self.transport.channels == []
created_channel = self.transport.create_channel(self.transport)
assert self.transport.channels == [created_channel]
def test_close_channel(self):
"""Ensure close_channel actually removes the channel and updates
_used_channel_ids.
"""
assert self.transport._used_channel_ids == array('H')
created_channel = self.transport.create_channel(self.transport)
assert self.transport._used_channel_ids == array('H', (1,))
self.transport.close_channel(created_channel)
assert self.transport.channels == []
assert self.transport._used_channel_ids == array('H')
def test_drain_channel(self):
channel = self.transport.create_channel(self.transport)
with pytest.raises(virtual.Empty):