diff --git a/kombu/connection.py b/kombu/connection.py index 50dbe10d..0c9779b5 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -482,7 +482,7 @@ class Connection: def ensure(self, obj, fun, errback=None, max_retries=None, interval_start=1, interval_step=1, interval_max=1, - on_revive=None): + on_revive=None, retry_errors=None): """Ensure operation completes. Regardless of any channel/connection errors occurring. @@ -511,6 +511,9 @@ class Connection: each retry. on_revive (Callable): Optional callback called whenever revival completes successfully + retry_errors (tuple): Optional list of errors to retry on + regardless of the connection state. Must provide max_retries + if this is specified. Examples: >>> from kombu import Connection, Producer @@ -525,6 +528,15 @@ class Connection: ... errback=errback, max_retries=3) >>> publish({'hello': 'world'}, routing_key='dest') """ + if retry_errors is None: + retry_errors = tuple() + elif max_retries is None: + # If the retry_errors is specified, but max_retries is not, + # this could lead into an infinite loop potentially. + raise ValueError( + "max_retries must be specified if retry_errors is specified" + ) + def _ensured(*args, **kwargs): got_connection = 0 conn_errors = self.recoverable_connection_errors @@ -536,6 +548,11 @@ class Connection: for retries in count(0): # for infinity try: return fun(*args, **kwargs) + except retry_errors as exc: + if max_retries is not None and retries >= max_retries: + raise + self._debug('ensure retry policy error: %r', + exc, exc_info=1) except conn_errors as exc: if got_connection and not has_modern_errors: # transport can not distinguish between diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 740bd6dc..c2daee3b 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -497,6 +497,43 @@ class test_Connection: with pytest.raises(OperationalError): ensured() + def test_ensure_retry_errors_is_not_looping_infinitely(self): + class _MessageNacked(Exception): + pass + + def publish(): + raise _MessageNacked('NACK') + + with pytest.raises(ValueError): + self.conn.ensure( + self.conn, + publish, + retry_errors=(_MessageNacked,) + ) + + def test_ensure_retry_errors_is_limited_by_max_retries(self): + class _MessageNacked(Exception): + pass + + tries = 0 + + def publish(): + nonlocal tries + tries += 1 + if tries <= 3: + raise _MessageNacked('NACK') + # On the 4th try, we let it pass + return 'ACK' + + ensured = self.conn.ensure( + self.conn, + publish, + max_retries=3, # 3 retries + 1 initial try = 4 tries + retry_errors=(_MessageNacked,) + ) + + assert ensured() == 'ACK' + def test_autoretry(self): myfun = Mock()