diff --git a/kombu/exceptions.py b/kombu/exceptions.py index 914dca5b..32261506 100644 --- a/kombu/exceptions.py +++ b/kombu/exceptions.py @@ -63,6 +63,11 @@ class SerializerNotInstalled(KombuError): pass +class ContentDisallowed(SerializerNotInstalled): + """Consumer does not allow this content-type.""" + pass + + class InconsistencyError(StdConnectionError): """Data or environment has been found to be inconsistent, depending on the cause it may be possible to retry the operation.""" diff --git a/kombu/messaging.py b/kombu/messaging.py index cb13f2c7..20069c37 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -318,16 +318,27 @@ class Consumer(object): #: that occurred while trying to decode it. on_decode_error = None + #: List of accepted content-types. + #: + #: An exception will be raised if the consumer receives + #: a message with an untrusted content type. + #: By default all content-types are accepted, but not if + #: :func:`kombu.disable_untrusted_serializers` was called, + #: in which case only json is allowed. + accept = None + _next_tag = count(1).next # global def __init__(self, channel, queues=None, no_ack=None, auto_declare=None, - callbacks=None, on_decode_error=None, on_message=None): + callbacks=None, on_decode_error=None, on_message=None, + accept=None): self.channel = channel self.queues = self.queues or [] if queues is None else queues self.no_ack = self.no_ack if no_ack is None else no_ack self.callbacks = (self.callbacks or [] if callbacks is None else callbacks) self.on_message = on_message + self.accept = accept self._active_tags = {} if auto_declare is not None: self.auto_declare = auto_declare @@ -528,6 +539,9 @@ class Consumer(object): return tag def _receive_callback(self, message): + accept = self.accept + if accept is not None: + message.accept = accept on_m, channel, decoded = self.on_message, self.channel, None try: m2p = getattr(channel, 'message_to_python', None) diff --git a/kombu/serialization.py b/kombu/serialization.py index 3393f1fb..0e190d6e 100644 --- a/kombu/serialization.py +++ b/kombu/serialization.py @@ -18,7 +18,7 @@ try: except ImportError: # pragma: no cover cpickle = None # noqa -from .exceptions import SerializerNotInstalled +from .exceptions import SerializerNotInstalled, ContentDisallowed from .utils import entrypoints from .utils.encoding import str_to_bytes, bytes_t @@ -74,6 +74,10 @@ def pickle_loads(s, load=pickle_load): return load(BytesIO(s)) +def parenthesize_alias(first, second): + return '%s (%s)' % (first, second) if first else second + + class SerializerRegistry(object): """The registry keeps track of serialization methods.""" @@ -163,10 +167,14 @@ class SerializerRegistry(object): payload = encoder(data) return content_type, content_encoding, payload - def decode(self, data, content_type, content_encoding, force=False): - if content_type in self._disabled_content_types and not force: - raise SerializerNotInstalled( - 'Content-type %r has been disabled.' % (content_type, )) + def decode(self, data, content_type, content_encoding, + accept=None, force=False): + if accept is not None: + if content_type not in accept: + raise self._for_untrusted_content(content_type, 'untrusted') + else: + if content_type in self._disabled_content_types and not force: + raise self._for_untrusted_content(content_type, 'disabled') content_type = content_type or 'application/data' content_encoding = (content_encoding or 'utf-8').lower() @@ -179,13 +187,16 @@ class SerializerRegistry(object): return _decode(data, content_encoding) return data + def _for_untrusted_content(self, ctype, why): + return ContentDisallowed( + 'Refusing to decode %(why)s content of type %(type)s' % { + 'why': why, + 'type': parenthesize_alias(self.type_to_name[ctype], ctype), + }, + ) -""" -.. data:: registry -Global registry of serializers/deserializers. - -""" +#: Global registry of serializers/deserializers. registry = SerializerRegistry() diff --git a/kombu/transport/base.py b/kombu/transport/base.py index 917c6198..7a2a6e91 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -47,12 +47,13 @@ class Message(object): __slots__ = ('_state', 'channel', 'delivery_tag', 'content_type', 'content_encoding', 'delivery_info', 'headers', 'properties', - 'body', '_decoded_cache', '__dict__') + 'body', '_decoded_cache', 'accept', '__dict__') MessageStateError = MessageStateError def __init__(self, channel, body=None, delivery_tag=None, content_type=None, content_encoding=None, delivery_info={}, - properties=None, headers=None, postencode=None, **kwargs): + properties=None, headers=None, postencode=None, + accept=None, **kwargs): self.channel = channel self.delivery_tag = delivery_tag self.content_type = content_type @@ -62,6 +63,7 @@ class Message(object): self.properties = properties or {} self._decoded_cache = None self._state = 'RECEIVED' + self.accept = accept try: body = decompress(body, self.headers['compression']) @@ -142,7 +144,7 @@ class Message(object): """Deserialize the message body, returning the original python structure sent by the publisher.""" return decode(self.body, self.content_type, - self.content_encoding) + self.content_encoding, accept=self.accept) @property def acknowledged(self):