From 2a704e35856bda908f5f122d1aaabebc6b469b0c Mon Sep 17 00:00:00 2001 From: Max Nikitenko Date: Sun, 7 Mar 2021 10:35:55 +0200 Subject: [PATCH] fix: non kombu json message decoding in SQS transport (#1306) * fix: non kombu json message decoding in SQS transport * fix: non kombu json message decoding in SQS transport - add tests Co-authored-by: Max Nikitenko --- kombu/transport/SQS.py | 16 +++++++++++++--- t/unit/transport/test_SQS.py | 11 +++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index d6c8b51f..caff36d3 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -370,11 +370,21 @@ class Channel(virtual.Channel): else: c.send_message(**kwargs) - def _message_to_python(self, message, queue_name, queue): + @staticmethod + def __b64_encoded(byte_string): try: - body = base64.b64decode(message['Body'].encode()) + return base64.b64encode(base64.b64decode(byte_string)) == byte_string + except Exception: # pylint: disable=broad-except + return False + + def _message_to_python(self, message, queue_name, queue): + body = message['Body'].encode() + try: + if self.__b64_encoded(body): + body = base64.b64decode(body) except TypeError: - body = message['Body'].encode() + pass + payload = loads(bytes_to_str(body)) if queue_name in self._noack_queues: queue = self._new_queue(queue_name) diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 12bc81e9..3302527e 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -4,8 +4,7 @@ NOTE: The SQSQueueMock and SQSConnectionMock classes originally come from http://github.com/pcsforeducation/sqs-mock-python. They have been patched slightly. """ - - +import base64 import os import pytest import random @@ -330,6 +329,14 @@ class test_Channel: with pytest.raises(Empty): self.channel._get_bulk(self.queue_name) + def test_is_base64_encoded(self): + raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \ + b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' + b64_enc = base64.b64encode(raw) + assert self.channel._Channel__b64_encoded(b64_enc) + assert not self.channel._Channel__b64_encoded(raw) + assert not self.channel._Channel__b64_encoded(b"test123") + def test_messages_to_python(self): from kombu.asynchronous.aws.sqs.message import Message