diff --git a/kombu/transport/zookeeper.py b/kombu/transport/zookeeper.py index 027afc68..d00c60e3 100644 --- a/kombu/transport/zookeeper.py +++ b/kombu/transport/zookeeper.py @@ -29,6 +29,7 @@ import socket from kombu.five import Empty from kombu.utils.encoding import bytes_to_str, ensure_bytes from kombu.utils.json import dumps, loads + from . import virtual try: @@ -66,8 +67,8 @@ try: socket.error, ) except ImportError: - kazoo = None # noqa - KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = () # noqa + kazoo = None # noqa + KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = () # noqa DEFAULT_PORT = 2181 @@ -80,8 +81,13 @@ class Channel(virtual.Channel): _client = None _queues = {} + def __init__(self, connection, **kwargs): + super(Channel, self).__init__(connection, **kwargs) + vhost = self.connection.client.virtual_host + self._vhost = '/{}'.format(vhost.strip('/')) + def _get_path(self, queue_name): - return os.path.join(self.vhost, queue_name) + return os.path.join(self._vhost, queue_name) def _get_queue(self, queue_name): queue = self._queues.get(queue_name, None) @@ -140,7 +146,6 @@ class Channel(virtual.Channel): def _open(self): conninfo = self.connection.client - self.vhost = self._normalize_chroot(conninfo.virtual_host) hosts = [] if conninfo.alt: for host_port in conninfo.alt: @@ -165,13 +170,6 @@ class Channel(virtual.Channel): conn.start() return conn - @staticmethod - def _normalize_chroot(chroot): - chroot = chroot.rstrip('/') - if not len(chroot) or chroot[0] != '/': - chroot = '/' + chroot - return chroot - @property def client(self): if self._client is None: diff --git a/t/unit/transport/test_zookeeper.py b/t/unit/transport/test_zookeeper.py index 03ba5134..7505a6d9 100644 --- a/t/unit/transport/test_zookeeper.py +++ b/t/unit/transport/test_zookeeper.py @@ -2,7 +2,6 @@ from __future__ import absolute_import, unicode_literals import pytest from case import skip - from kombu import Connection from kombu.transport import zookeeper @@ -27,11 +26,11 @@ class test_Channel: self.channel._queues['foo'] = AssertQueue() self.channel._put(queue='foo', message='bar') - -@pytest.mark.parametrize('input,expected', ( - ('/', '/'), - ('/root', '/root'), - ('/root/', '/root'), -)) -def test_normalize_chroot(input, expected): - assert zookeeper.Channel._normalize_chroot(input) == expected + @pytest.mark.parametrize('input,expected', ( + ('', '/'), + ('/root', '/root'), + ('/root/', '/root'), + )) + def test_virtual_host_normalization(self, input, expected): + with self.create_connection(virtual_host=input) as conn: + assert conn.default_channel._vhost == expected