Refactor chroot normalization

This commit is contained in:
Dima Kurguzov 2017-04-21 10:57:54 +03:00 committed by George Psarakis
parent 5f35ef996c
commit 221685618f
2 changed files with 17 additions and 20 deletions

View File

@ -29,6 +29,7 @@ import socket
from kombu.five import Empty from kombu.five import Empty
from kombu.utils.encoding import bytes_to_str, ensure_bytes from kombu.utils.encoding import bytes_to_str, ensure_bytes
from kombu.utils.json import dumps, loads from kombu.utils.json import dumps, loads
from . import virtual from . import virtual
try: try:
@ -66,8 +67,8 @@ try:
socket.error, socket.error,
) )
except ImportError: except ImportError:
kazoo = None # noqa kazoo = None # noqa
KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = () # noqa KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = () # noqa
DEFAULT_PORT = 2181 DEFAULT_PORT = 2181
@ -80,8 +81,13 @@ class Channel(virtual.Channel):
_client = None _client = None
_queues = {} _queues = {}
def __init__(self, connection, **kwargs):
super(Channel, self).__init__(connection, **kwargs)
vhost = self.connection.client.virtual_host
self._vhost = '/{}'.format(vhost.strip('/'))
def _get_path(self, queue_name): def _get_path(self, queue_name):
return os.path.join(self.vhost, queue_name) return os.path.join(self._vhost, queue_name)
def _get_queue(self, queue_name): def _get_queue(self, queue_name):
queue = self._queues.get(queue_name, None) queue = self._queues.get(queue_name, None)
@ -140,7 +146,6 @@ class Channel(virtual.Channel):
def _open(self): def _open(self):
conninfo = self.connection.client conninfo = self.connection.client
self.vhost = self._normalize_chroot(conninfo.virtual_host)
hosts = [] hosts = []
if conninfo.alt: if conninfo.alt:
for host_port in conninfo.alt: for host_port in conninfo.alt:
@ -165,13 +170,6 @@ class Channel(virtual.Channel):
conn.start() conn.start()
return conn return conn
@staticmethod
def _normalize_chroot(chroot):
chroot = chroot.rstrip('/')
if not len(chroot) or chroot[0] != '/':
chroot = '/' + chroot
return chroot
@property @property
def client(self): def client(self):
if self._client is None: if self._client is None:

View File

@ -2,7 +2,6 @@ from __future__ import absolute_import, unicode_literals
import pytest import pytest
from case import skip from case import skip
from kombu import Connection from kombu import Connection
from kombu.transport import zookeeper from kombu.transport import zookeeper
@ -27,11 +26,11 @@ class test_Channel:
self.channel._queues['foo'] = AssertQueue() self.channel._queues['foo'] = AssertQueue()
self.channel._put(queue='foo', message='bar') self.channel._put(queue='foo', message='bar')
@pytest.mark.parametrize('input,expected', (
@pytest.mark.parametrize('input,expected', ( ('', '/'),
('/', '/'), ('/root', '/root'),
('/root', '/root'), ('/root/', '/root'),
('/root/', '/root'), ))
)) def test_virtual_host_normalization(self, input, expected):
def test_normalize_chroot(input, expected): with self.create_connection(virtual_host=input) as conn:
assert zookeeper.Channel._normalize_chroot(input) == expected assert conn.default_channel._vhost == expected