mirror of https://github.com/celery/kombu.git
Adds Consumer.accept: A whitelist of content_types/serializer names to accept
This commit is contained in:
parent
b1a8d2f868
commit
24e441c79c
|
@ -63,6 +63,11 @@ class SerializerNotInstalled(KombuError):
|
|||
pass
|
||||
|
||||
|
||||
class ContentDisallowed(SerializerNotInstalled):
|
||||
"""Consumer does not allow this content-type."""
|
||||
pass
|
||||
|
||||
|
||||
class InconsistencyError(StdConnectionError):
|
||||
"""Data or environment has been found to be inconsistent,
|
||||
depending on the cause it may be possible to retry the operation."""
|
||||
|
|
|
@ -318,16 +318,27 @@ class Consumer(object):
|
|||
#: that occurred while trying to decode it.
|
||||
on_decode_error = None
|
||||
|
||||
#: List of accepted content-types.
|
||||
#:
|
||||
#: An exception will be raised if the consumer receives
|
||||
#: a message with an untrusted content type.
|
||||
#: By default all content-types are accepted, but not if
|
||||
#: :func:`kombu.disable_untrusted_serializers` was called,
|
||||
#: in which case only json is allowed.
|
||||
accept = None
|
||||
|
||||
_next_tag = count(1).next # global
|
||||
|
||||
def __init__(self, channel, queues=None, no_ack=None, auto_declare=None,
|
||||
callbacks=None, on_decode_error=None, on_message=None):
|
||||
callbacks=None, on_decode_error=None, on_message=None,
|
||||
accept=None):
|
||||
self.channel = channel
|
||||
self.queues = self.queues or [] if queues is None else queues
|
||||
self.no_ack = self.no_ack if no_ack is None else no_ack
|
||||
self.callbacks = (self.callbacks or [] if callbacks is None
|
||||
else callbacks)
|
||||
self.on_message = on_message
|
||||
self.accept = accept
|
||||
self._active_tags = {}
|
||||
if auto_declare is not None:
|
||||
self.auto_declare = auto_declare
|
||||
|
@ -528,6 +539,9 @@ class Consumer(object):
|
|||
return tag
|
||||
|
||||
def _receive_callback(self, message):
|
||||
accept = self.accept
|
||||
if accept is not None:
|
||||
message.accept = accept
|
||||
on_m, channel, decoded = self.on_message, self.channel, None
|
||||
try:
|
||||
m2p = getattr(channel, 'message_to_python', None)
|
||||
|
|
|
@ -18,7 +18,7 @@ try:
|
|||
except ImportError: # pragma: no cover
|
||||
cpickle = None # noqa
|
||||
|
||||
from .exceptions import SerializerNotInstalled
|
||||
from .exceptions import SerializerNotInstalled, ContentDisallowed
|
||||
from .utils import entrypoints
|
||||
from .utils.encoding import str_to_bytes, bytes_t
|
||||
|
||||
|
@ -74,6 +74,10 @@ def pickle_loads(s, load=pickle_load):
|
|||
return load(BytesIO(s))
|
||||
|
||||
|
||||
def parenthesize_alias(first, second):
|
||||
return '%s (%s)' % (first, second) if first else second
|
||||
|
||||
|
||||
class SerializerRegistry(object):
|
||||
"""The registry keeps track of serialization methods."""
|
||||
|
||||
|
@ -163,10 +167,14 @@ class SerializerRegistry(object):
|
|||
payload = encoder(data)
|
||||
return content_type, content_encoding, payload
|
||||
|
||||
def decode(self, data, content_type, content_encoding, force=False):
|
||||
if content_type in self._disabled_content_types and not force:
|
||||
raise SerializerNotInstalled(
|
||||
'Content-type %r has been disabled.' % (content_type, ))
|
||||
def decode(self, data, content_type, content_encoding,
|
||||
accept=None, force=False):
|
||||
if accept is not None:
|
||||
if content_type not in accept:
|
||||
raise self._for_untrusted_content(content_type, 'untrusted')
|
||||
else:
|
||||
if content_type in self._disabled_content_types and not force:
|
||||
raise self._for_untrusted_content(content_type, 'disabled')
|
||||
content_type = content_type or 'application/data'
|
||||
content_encoding = (content_encoding or 'utf-8').lower()
|
||||
|
||||
|
@ -179,13 +187,16 @@ class SerializerRegistry(object):
|
|||
return _decode(data, content_encoding)
|
||||
return data
|
||||
|
||||
def _for_untrusted_content(self, ctype, why):
|
||||
return ContentDisallowed(
|
||||
'Refusing to decode %(why)s content of type %(type)s' % {
|
||||
'why': why,
|
||||
'type': parenthesize_alias(self.type_to_name[ctype], ctype),
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
.. data:: registry
|
||||
|
||||
Global registry of serializers/deserializers.
|
||||
|
||||
"""
|
||||
#: Global registry of serializers/deserializers.
|
||||
registry = SerializerRegistry()
|
||||
|
||||
|
||||
|
|
|
@ -47,12 +47,13 @@ class Message(object):
|
|||
__slots__ = ('_state', 'channel', 'delivery_tag',
|
||||
'content_type', 'content_encoding',
|
||||
'delivery_info', 'headers', 'properties',
|
||||
'body', '_decoded_cache', '__dict__')
|
||||
'body', '_decoded_cache', 'accept', '__dict__')
|
||||
MessageStateError = MessageStateError
|
||||
|
||||
def __init__(self, channel, body=None, delivery_tag=None,
|
||||
content_type=None, content_encoding=None, delivery_info={},
|
||||
properties=None, headers=None, postencode=None, **kwargs):
|
||||
properties=None, headers=None, postencode=None,
|
||||
accept=None, **kwargs):
|
||||
self.channel = channel
|
||||
self.delivery_tag = delivery_tag
|
||||
self.content_type = content_type
|
||||
|
@ -62,6 +63,7 @@ class Message(object):
|
|||
self.properties = properties or {}
|
||||
self._decoded_cache = None
|
||||
self._state = 'RECEIVED'
|
||||
self.accept = accept
|
||||
|
||||
try:
|
||||
body = decompress(body, self.headers['compression'])
|
||||
|
@ -142,7 +144,7 @@ class Message(object):
|
|||
"""Deserialize the message body, returning the original
|
||||
python structure sent by the publisher."""
|
||||
return decode(self.body, self.content_type,
|
||||
self.content_encoding)
|
||||
self.content_encoding, accept=self.accept)
|
||||
|
||||
@property
|
||||
def acknowledged(self):
|
||||
|
|
Loading…
Reference in New Issue