Added Unix socket connection support to the redis transport

cf. https://github.com/celery/celery/issues/1283
and https://github.com/piquadrat/celery-redis-unixsocket/issues/3
This commit is contained in:
Maxime Rouyrre 2013-06-13 11:00:07 +02:00
parent 9204bed46d
commit f30f669adc
4 changed files with 27 additions and 7 deletions

View File

@ -70,6 +70,9 @@ All of these are valid URLs::
# Using Redis # Using Redis
redis://localhost:6379/ redis://localhost:6379/
# Using Redis over a Unix socket
redis+socket:///tmp/redis.sock
# Using virtual host '/foo' # Using virtual host '/foo'
amqp://localhost//foo amqp://localhost//foo

View File

@ -579,7 +579,8 @@ class Connection(object):
def as_uri(self, include_password=False): def as_uri(self, include_password=False):
"""Convert connection parameters to URL form.""" """Convert connection parameters to URL form."""
if self.transport_cls in URI_PASSTHROUGH: if (self.transport_cls in URI_PASSTHROUGH or
self.hostname.startswith('socket://')):
return self.transport_cls + '+' + (self.hostname or 'localhost') return self.transport_cls + '+' + (self.hostname or 'localhost')
quoteS = partial(quote, safe='') # strict quote quoteS = partial(quote, safe='') # strict quote
fields = self.info() fields = self.info()

View File

@ -477,6 +477,15 @@ class test_Channel(TestCase):
with self.assertRaises(InconsistencyError): with self.assertRaises(InconsistencyError):
self.channel.get_table('celery') self.channel.get_table('celery')
@skip_if_not_module('redis')
def test_socket_connection(self):
connection = Connection('redis+socket:///tmp/redis.sock',
transport=Transport)
connparams = connection.channel()._connparams()
self.assertEqual(connparams['connection_class'],
redis.redis.UnixDomainSocketConnection)
self.assertEqual(connparams['path'], '/tmp/redis.sock')
class test_Redis(TestCase): class test_Redis(TestCase):

View File

@ -623,12 +623,19 @@ class Channel(virtual.Channel):
except ValueError: except ValueError:
raise ValueError( raise ValueError(
'Database name must be int between 0 and limit - 1') 'Database name must be int between 0 and limit - 1')
return {'host': conninfo.hostname or '127.0.0.1', connparams = {'host': conninfo.hostname or '127.0.0.1',
'port': conninfo.port or DEFAULT_PORT, 'port': conninfo.port or DEFAULT_PORT,
'db': database, 'db': database,
'password': conninfo.password, 'password': conninfo.password,
'max_connections': self.max_connections, 'max_connections': self.max_connections,
'socket_timeout': self.socket_timeout} 'socket_timeout': self.socket_timeout}
if conninfo.hostname.split('://')[0] == 'socket':
connparams.update({
'connection_class': redis.UnixDomainSocketConnection,
'path': conninfo.hostname.split('://')[1]})
connparams.pop('host', None)
connparams.pop('port', None)
return connparams
def _create_client(self): def _create_client(self):
return self.Client(connection_pool=self.pool) return self.Client(connection_pool=self.pool)