Adds Consumer.accept: A whitelist of content_types/serializer names to accept

This commit is contained in:
Ask Solem 2013-04-11 14:59:07 +01:00
parent b1a8d2f868
commit 24e441c79c
4 changed files with 46 additions and 14 deletions

View File

@ -63,6 +63,11 @@ class SerializerNotInstalled(KombuError):
pass
class ContentDisallowed(SerializerNotInstalled):
"""Consumer does not allow this content-type."""
pass
class InconsistencyError(StdConnectionError):
"""Data or environment has been found to be inconsistent,
depending on the cause it may be possible to retry the operation."""

View File

@ -318,16 +318,27 @@ class Consumer(object):
#: that occurred while trying to decode it.
on_decode_error = None
#: List of accepted content-types.
#:
#: An exception will be raised if the consumer receives
#: a message with an untrusted content type.
#: By default all content-types are accepted, but not if
#: :func:`kombu.disable_untrusted_serializers` was called,
#: in which case only json is allowed.
accept = None
_next_tag = count(1).next # global
def __init__(self, channel, queues=None, no_ack=None, auto_declare=None,
callbacks=None, on_decode_error=None, on_message=None):
callbacks=None, on_decode_error=None, on_message=None,
accept=None):
self.channel = channel
self.queues = self.queues or [] if queues is None else queues
self.no_ack = self.no_ack if no_ack is None else no_ack
self.callbacks = (self.callbacks or [] if callbacks is None
else callbacks)
self.on_message = on_message
self.accept = accept
self._active_tags = {}
if auto_declare is not None:
self.auto_declare = auto_declare
@ -528,6 +539,9 @@ class Consumer(object):
return tag
def _receive_callback(self, message):
accept = self.accept
if accept is not None:
message.accept = accept
on_m, channel, decoded = self.on_message, self.channel, None
try:
m2p = getattr(channel, 'message_to_python', None)

View File

@ -18,7 +18,7 @@ try:
except ImportError: # pragma: no cover
cpickle = None # noqa
from .exceptions import SerializerNotInstalled
from .exceptions import SerializerNotInstalled, ContentDisallowed
from .utils import entrypoints
from .utils.encoding import str_to_bytes, bytes_t
@ -74,6 +74,10 @@ def pickle_loads(s, load=pickle_load):
return load(BytesIO(s))
def parenthesize_alias(first, second):
return '%s (%s)' % (first, second) if first else second
class SerializerRegistry(object):
"""The registry keeps track of serialization methods."""
@ -163,10 +167,14 @@ class SerializerRegistry(object):
payload = encoder(data)
return content_type, content_encoding, payload
def decode(self, data, content_type, content_encoding, force=False):
if content_type in self._disabled_content_types and not force:
raise SerializerNotInstalled(
'Content-type %r has been disabled.' % (content_type, ))
def decode(self, data, content_type, content_encoding,
accept=None, force=False):
if accept is not None:
if content_type not in accept:
raise self._for_untrusted_content(content_type, 'untrusted')
else:
if content_type in self._disabled_content_types and not force:
raise self._for_untrusted_content(content_type, 'disabled')
content_type = content_type or 'application/data'
content_encoding = (content_encoding or 'utf-8').lower()
@ -179,13 +187,16 @@ class SerializerRegistry(object):
return _decode(data, content_encoding)
return data
def _for_untrusted_content(self, ctype, why):
return ContentDisallowed(
'Refusing to decode %(why)s content of type %(type)s' % {
'why': why,
'type': parenthesize_alias(self.type_to_name[ctype], ctype),
},
)
"""
.. data:: registry
Global registry of serializers/deserializers.
"""
#: Global registry of serializers/deserializers.
registry = SerializerRegistry()

View File

@ -47,12 +47,13 @@ class Message(object):
__slots__ = ('_state', 'channel', 'delivery_tag',
'content_type', 'content_encoding',
'delivery_info', 'headers', 'properties',
'body', '_decoded_cache', '__dict__')
'body', '_decoded_cache', 'accept', '__dict__')
MessageStateError = MessageStateError
def __init__(self, channel, body=None, delivery_tag=None,
content_type=None, content_encoding=None, delivery_info={},
properties=None, headers=None, postencode=None, **kwargs):
properties=None, headers=None, postencode=None,
accept=None, **kwargs):
self.channel = channel
self.delivery_tag = delivery_tag
self.content_type = content_type
@ -62,6 +63,7 @@ class Message(object):
self.properties = properties or {}
self._decoded_cache = None
self._state = 'RECEIVED'
self.accept = accept
try:
body = decompress(body, self.headers['compression'])
@ -142,7 +144,7 @@ class Message(object):
"""Deserialize the message body, returning the original
python structure sent by the publisher."""
return decode(self.body, self.content_type,
self.content_encoding)
self.content_encoding, accept=self.accept)
@property
def acknowledged(self):