diff --git a/docs/userguide/connections.rst b/docs/userguide/connections.rst index 73d08f29..f97b4b79 100644 --- a/docs/userguide/connections.rst +++ b/docs/userguide/connections.rst @@ -70,6 +70,9 @@ All of these are valid URLs:: # Using Redis redis://localhost:6379/ + # Using Redis over a Unix socket + redis+socket:///tmp/redis.sock + # Using virtual host '/foo' amqp://localhost//foo diff --git a/kombu/connection.py b/kombu/connection.py index 7dce4c8d..9133b050 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -579,7 +579,8 @@ class Connection(object): def as_uri(self, include_password=False): """Convert connection parameters to URL form.""" - if self.transport_cls in URI_PASSTHROUGH: + if (self.transport_cls in URI_PASSTHROUGH or + self.hostname.startswith('socket://')): return self.transport_cls + '+' + (self.hostname or 'localhost') quoteS = partial(quote, safe='') # strict quote fields = self.info() diff --git a/kombu/tests/transport/test_redis.py b/kombu/tests/transport/test_redis.py index fb71039b..1a8bd9d9 100644 --- a/kombu/tests/transport/test_redis.py +++ b/kombu/tests/transport/test_redis.py @@ -477,6 +477,15 @@ class test_Channel(TestCase): with self.assertRaises(InconsistencyError): self.channel.get_table('celery') + @skip_if_not_module('redis') + def test_socket_connection(self): + connection = Connection('redis+socket:///tmp/redis.sock', + transport=Transport) + connparams = connection.channel()._connparams() + self.assertEqual(connparams['connection_class'], + redis.redis.UnixDomainSocketConnection) + self.assertEqual(connparams['path'], '/tmp/redis.sock') + class test_Redis(TestCase): diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index 49bc3bfa..21b7e047 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -623,12 +623,19 @@ class Channel(virtual.Channel): except ValueError: raise ValueError( 'Database name must be int between 0 and limit - 1') - return {'host': conninfo.hostname or '127.0.0.1', - 'port': conninfo.port or DEFAULT_PORT, - 'db': database, - 'password': conninfo.password, - 'max_connections': self.max_connections, - 'socket_timeout': self.socket_timeout} + connparams = {'host': conninfo.hostname or '127.0.0.1', + 'port': conninfo.port or DEFAULT_PORT, + 'db': database, + 'password': conninfo.password, + 'max_connections': self.max_connections, + 'socket_timeout': self.socket_timeout} + if conninfo.hostname.split('://')[0] == 'socket': + connparams.update({ + 'connection_class': redis.UnixDomainSocketConnection, + 'path': conninfo.hostname.split('://')[1]}) + connparams.pop('host', None) + connparams.pop('port', None) + return connparams def _create_client(self): return self.Client(connection_pool=self.pool)