From 8f1de37a08318d388a048341ddca12af30019bf3 Mon Sep 17 00:00:00 2001 From: Antoine Busque Date: Wed, 29 Apr 2020 17:41:36 -0400 Subject: [PATCH] Fix: make SQLAlchemy Channel init thread-safe The init method for the SQLAlchemy transport's Channel implementation is not currently thread-safe. This means that if two threads in a process attempt to instantiate a Channel concurrently, there can be a race condition when registering the SQLAlchemy model classes, which leaves the SQLAlchemy ORM mapper in a broken state. This results in an exception like the following: ``` sqlalchemy.exc.InvalidRequestError: Table 'kombu_message' is already defined for this MetaData instance. Specify 'extend_existing=True' to redefine options and columns on an existing Table object. ``` Any subsequent calls to the SQLAlchemy ORM will then fail with an exception like this: ``` sqlalchemy.exc.InvalidRequestError: Multiple classes found for path "Message" in the registry of this declarative base. Please use a fully module-qualified path. ``` This also applies to any SQLAlchemy calls in the same process, not just those made by Kombu or Celery, and can only really be recovered from by killing the process and starting over. To avoid all of this, introduce a mutex which ensures that only one thread at a time can register a SQLAlchemy model when instantiating the Channel class. Signed-off-by: Antoine Busque --- kombu/transport/sqlalchemy/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/kombu/transport/sqlalchemy/__init__.py b/kombu/transport/sqlalchemy/__init__.py index bcf9ed5b..fe469fae 100644 --- a/kombu/transport/sqlalchemy/__init__.py +++ b/kombu/transport/sqlalchemy/__init__.py @@ -4,6 +4,7 @@ from __future__ import absolute_import, unicode_literals +import threading from json import loads, dumps from sqlalchemy import create_engine @@ -21,6 +22,8 @@ from .models import (ModelBase, Queue as QueueBase, Message as MessageBase, VERSION = (1, 1, 0) __version__ = '.'.join(map(str, VERSION)) +_MUTEX = threading.Lock() + class Channel(virtual.Channel): """The channel class.""" @@ -127,9 +130,10 @@ class Channel(virtual.Channel): return self._query_all(queue).count() def _declarative_cls(self, name, base, ns): - if name in class_registry: - return class_registry[name] - return type(str(name), (base, ModelBase), ns) + with _MUTEX: + if name in class_registry: + return class_registry[name] + return type(str(name), (base, ModelBase), ns) @cached_property def queue_cls(self):