Closes cyclic reference in Connection

This commit is contained in:
Ask Solem 2011-04-07 15:53:29 +02:00
parent baeafa0b9f
commit d6601ec92c
6 changed files with 83 additions and 6 deletions

View File

@ -5,11 +5,13 @@ import sys
import time import time
import unittest2 as unittest import unittest2 as unittest
import warnings import warnings
import weakref
from nose import SkipTest from nose import SkipTest
from kombu import BrokerConnection from kombu import BrokerConnection
from kombu import Producer, Consumer, Exchange, Queue from kombu import Producer, Consumer, Exchange, Queue
from kombu.tests.utils import skip_if_quick
if sys.version_info >= (2, 5): if sys.version_info >= (2, 5):
from hashlib import sha256 as _digest from hashlib import sha256 as _digest
@ -119,6 +121,7 @@ class TransportCase(unittest.TestCase):
def _digest(self, data): def _digest(self, data):
return _digest(data).hexdigest() return _digest(data).hexdigest()
@skip_if_quick
def test_produce__consume_large_messages(self, bytes=1048576, n=10, def test_produce__consume_large_messages(self, bytes=1048576, n=10,
charset=string.punctuation + string.letters + string.digits): charset=string.punctuation + string.letters + string.digits):
if not self.verify_alive(): if not self.verify_alive():
@ -156,8 +159,6 @@ class TransportCase(unittest.TestCase):
chan1.close() chan1.close()
self.purge([self.queue.name]) self.purge([self.queue.name])
def P(self, rest): def P(self, rest):
return "%s.%s" % (self.prefix, rest) return "%s.%s" % (self.prefix, rest)
@ -214,6 +215,45 @@ class TransportCase(unittest.TestCase):
self.assertEqual(m.payload, {"basic.get": "this"}) self.assertEqual(m.payload, {"basic.get": "this"})
chan2.close() chan2.close()
def test_cyclic_reference_transport(self):
if not self.verify_alive():
return
def _createref():
conn = self.get_connection()
conn.transport
conn.close()
return weakref.ref(conn)
self.assertIsNone(_createref()())
def test_cyclic_reference_connection(self):
if not self.verify_alive():
return
def _createref():
conn = self.get_connection()
conn.connect()
conn.close()
return weakref.ref(conn)
self.assertIsNone(_createref()())
def test_cyclic_reference_channel(self):
if not self.verify_alive():
return
def _createref():
conn = self.get_connection()
conn.connect()
channel = conn.channel()
channel.close()
conn.close()
return weakref.ref(conn)
self.assertIsNone(_createref()())
def tearDown(self): def tearDown(self):
if self.transport and self.connected: if self.transport and self.connected:
self.connection.close() self.connection.close()

View File

@ -14,7 +14,7 @@ if not os.environ.get("KOMBU_NO_EVAL", False):
from types import ModuleType from types import ModuleType
all_by_module = { all_by_module = {
"kombu.connection": ["BrokerConnection"], "kombu.connection": ["BrokerConnection", "Connection"],
"kombu.entity": ["Exchange", "Queue"], "kombu.entity": ["Exchange", "Queue"],
"kombu.messaging": ["Consumer", "Producer"], "kombu.messaging": ["Consumer", "Producer"],
} }

View File

@ -147,6 +147,9 @@ class BrokerConnection(object):
pass pass
self._connection = None self._connection = None
self._debug("closed") self._debug("closed")
if self._transport:
self._transport.client = None
self._transport = None
self._closed = True self._closed = True
def release(self): def release(self):
@ -456,6 +459,7 @@ class BrokerConnection(object):
def channel_errors(self): def channel_errors(self):
"""List of exceptions that may be raised by the channel.""" """List of exceptions that may be raised by the channel."""
return self.transport.channel_errors return self.transport.channel_errors
Connection = BrokerConnection
class Resource(object): class Resource(object):

View File

@ -1,9 +1,12 @@
import __builtin__ import __builtin__
import os
import sys import sys
import types import types
from StringIO import StringIO from StringIO import StringIO
from nose import SkipTest
from kombu.utils.functional import wraps from kombu.utils.functional import wraps
try: try:
@ -89,3 +92,23 @@ def mask_modules(*modnames):
return __inner return __inner
return _inner return _inner
def skip_if_environ(env_var_name):
def _wrap_test(fun):
@wraps(fun)
def _skips_if_environ(*args, **kwargs):
if os.environ.get(env_var_name):
raise SkipTest("SKIP %s: %s set\n" % (
fun.__name__, env_var_name))
return fun(*args, **kwargs)
return _skips_if_environ
return _wrap_test
def skip_if_quick(fun):
return skip_if_environ("QUICKTEST")(fun)

View File

@ -178,6 +178,12 @@ class Channel(_Channel):
"""Convert encoded message body back to a Python value.""" """Convert encoded message body back to a Python value."""
return self.Message(self, raw_message) return self.Message(self, raw_message)
def close(self):
try:
super(Channel, self).close()
finally:
self.connection = None
class Transport(base.Transport): class Transport(base.Transport):
Connection = Connection Connection = Connection
@ -222,6 +228,7 @@ class Transport(base.Transport):
def close_connection(self, connection): def close_connection(self, connection):
"""Close the AMQP broker connection.""" """Close the AMQP broker connection."""
connection.client = None
connection.close() connection.close()
def verify_connection(self, connection): def verify_connection(self, connection):

View File

@ -619,9 +619,12 @@ class Transport(base.Transport):
def close_channel(self, channel): def close_channel(self, channel):
try: try:
self.channels.remove(channel) try:
except ValueError: self.channels.remove(channel)
pass except ValueError:
pass
finally:
channel.connection = None
def establish_connection(self): def establish_connection(self):
self._avail_channels.append(self.create_channel(self)) self._avail_channels.append(self.create_channel(self))