Implement TTL for deferred jobs (#2111)

* Add new deferred_ttl attribute to jobs

* Add DeferredJobRegistry to cleanup

* Use normal ttl for deferred jobs as well

* Test that jobs landing in deferred queue get a TTL

* Pass pipeline in job cleanup

* Remove cleanup call

We pass a ttl of -1, which does not do anything, so the call can be removed.

* Pass exc_info to add

The add implementation overwrites exc_info, so passing it to add ensures it is not lost.

* Remove extraneous save call

The add function already saves the job for us, so there is no need to save it
twice.

* Tune cleanup function description

* Test cleanup also works for deleted jobs

* Replace testconn with connection

---------

Co-authored-by: Raymond Guo <raymond.guo@databricks.com>
This commit is contained in:
Harm Berntsen 2024-08-11 14:23:28 +02:00 committed by GitHub
parent e80d4009ff
commit 193de26cff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 108 additions and 7 deletions

View File

@ -1569,7 +1569,7 @@ class Job:
registry = DeferredJobRegistry( registry = DeferredJobRegistry(
self.origin, connection=self.connection, job_class=self.__class__, serializer=self.serializer self.origin, connection=self.connection, job_class=self.__class__, serializer=self.serializer
) )
registry.add(self, pipeline=pipeline) registry.add(self, pipeline=pipeline, ttl=self.ttl)
connection = pipeline if pipeline is not None else self.connection connection = pipeline if pipeline is not None else self.connection

View File

@ -398,11 +398,43 @@ class DeferredJobRegistry(BaseRegistry):
key_template = 'rq:deferred:{0}' key_template = 'rq:deferred:{0}'
def cleanup(self): def cleanup(self, timestamp=None):
"""This method is only here to prevent errors because this method is """Remove expired jobs from registry and add them to FailedJobRegistry.
automatically called by `count()` and `get_job_ids()` methods Removes jobs with an expiry time earlier than timestamp, specified as
implemented in BaseRegistry.""" seconds since the Unix epoch. timestamp defaults to call time if
pass unspecified. Removed jobs are added to the failed job registry.
"""
# Cutoff score: the caller-supplied timestamp, or the current time when none is given.
score = timestamp if timestamp is not None else current_timestamp()
job_ids = self.get_expired_job_ids(score)
if job_ids:
failed_job_registry = FailedJobRegistry(self.name, self.connection, serializer=self.serializer)
# Commands are queued on a pipeline and sent with pipeline.execute() below.
with self.connection.pipeline() as pipeline:
for job_id in job_ids:
try:
job = self.job_class.fetch(job_id, connection=self.connection, serializer=self.serializer)
except NoSuchJobError:
# The job could not be fetched, so there is nothing to mark as failed; skip this id.
continue
job.set_status(JobStatus.FAILED, pipeline=pipeline)
# Explanation text handed to the failed job registry as exc_info.
exc_info = "Expired in DeferredJobRegistry, moved to FailedJobRegistry at %s" % datetime.now()
failed_job_registry.add(job, job.failure_ttl, exc_info, pipeline, True)
# Remove registry entries whose score lies between 0 and the cutoff.
pipeline.zremrangebyscore(self.key, 0, score)
pipeline.execute()
# The ids returned are the expired ids looked up above, including any that were skipped.
return job_ids
def add(self, job, ttl=None, pipeline=None, xx=False):
    """
    Register a job in this registry with an expiry time of now + ttl.

    An omitted ttl is replaced by -1 (never expire) before the call is
    delegated to the parent registry's add().
    """
    effective_ttl = -1 if ttl is None else ttl
    return super(DeferredJobRegistry, self).add(job, effective_ttl, pipeline, xx)
class ScheduledJobRegistry(BaseRegistry): class ScheduledJobRegistry(BaseRegistry):
@ -501,7 +533,7 @@ class CanceledJobRegistry(BaseRegistry):
def clean_registries(queue: 'Queue', exception_handlers: list = None): def clean_registries(queue: 'Queue', exception_handlers: list = None):
"""Cleans StartedJobRegistry, FinishedJobRegistry and FailedJobRegistry of a queue. """Cleans StartedJobRegistry, FinishedJobRegistry, FailedJobRegistry and DeferredJobRegistry of a queue.
Args: Args:
queue (Queue): The queue to clean queue (Queue): The queue to clean
@ -517,3 +549,7 @@ def clean_registries(queue: 'Queue', exception_handlers: list = None):
FailedJobRegistry( FailedJobRegistry(
name=queue.name, connection=queue.connection, job_class=queue.job_class, serializer=queue.serializer name=queue.name, connection=queue.connection, job_class=queue.job_class, serializer=queue.serializer
).cleanup() ).cleanup()
DeferredJobRegistry(
name=queue.name, connection=queue.connection, job_class=queue.job_class, serializer=queue.serializer
).cleanup()

View File

@ -1,5 +1,6 @@
from rq import Queue, SimpleWorker, Worker from rq import Queue, SimpleWorker, Worker
from rq.job import Dependency, Job, JobStatus from rq.job import Dependency, Job, JobStatus
from rq.utils import current_timestamp
from tests import RQTestCase from tests import RQTestCase
from tests.fixtures import check_dependencies_are_met, div_by_zero, say_hello from tests.fixtures import check_dependencies_are_met, div_by_zero, say_hello
@ -154,6 +155,20 @@ class TestDependencies(RQTestCase):
self.assertEqual(parent_job.get_status(), JobStatus.FINISHED) self.assertEqual(parent_job.get_status(), JobStatus.FINISHED)
self.assertEqual(job.get_status(), JobStatus.FINISHED) self.assertEqual(job.get_status(), JobStatus.FINISHED)
def test_enqueue_job_dependency_sets_ttl(self):
    """Ensures that the TTL of jobs in the deferred queue is set"""
    ttl = 5
    queue = Queue(connection=self.connection)
    parent = Job.create(say_hello, connection=self.connection)
    parent.save()

    started_at = current_timestamp()
    dependent = Job.create(say_hello, connection=self.connection, depends_on=parent, ttl=ttl)
    queue.enqueue_job(dependent)

    # The stored score must land within two seconds of timestamp + ttl.
    expected = started_at + ttl
    score = self.connection.zscore(queue.deferred_job_registry.key, dependent.id)
    self.assertLess(score, expected + 2)
    self.assertGreater(score, expected - 2)
def test_dependencies_are_met_if_parent_is_canceled(self): def test_dependencies_are_met_if_parent_is_canceled(self):
"""When parent job is canceled, it should be treated as failed""" """When parent job is canceled, it should be treated as failed"""
queue = Queue(connection=self.connection) queue = Queue(connection=self.connection)

View File

@ -402,6 +402,24 @@ class TestDeferredRegistry(RQTestCase):
job_ids = [as_text(job_id) for job_id in self.connection.zrange(self.registry.key, 0, -1)] job_ids = [as_text(job_id) for job_id in self.connection.zrange(self.registry.key, 0, -1)]
self.assertEqual(job_ids, [job.id]) self.assertEqual(job_ids, [job.id])
def test_add_with_deferred_ttl(self):
    """Job TTL defaults to +inf"""
    job = Queue(connection=self.connection).enqueue(say_hello)
    registry_key = self.registry.key

    # Without a ttl the entry is stored with an infinite score.
    self.registry.add(job)
    self.assertEqual(self.connection.zscore(registry_key, job.id), float("inf"))

    # With an explicit ttl the score must land within two seconds of timestamp + ttl.
    started_at = current_timestamp()
    ttl = 5
    self.registry.add(job, ttl=ttl)
    expected = started_at + ttl
    score = self.connection.zscore(registry_key, job.id)
    self.assertLess(score, expected + 2)
    self.assertGreater(score, expected - 2)
def test_register_dependency(self): def test_register_dependency(self):
"""Ensure job creation and deletion works with DeferredJobRegistry.""" """Ensure job creation and deletion works with DeferredJobRegistry."""
queue = Queue(connection=self.connection) queue = Queue(connection=self.connection)
@ -415,6 +433,38 @@ class TestDeferredRegistry(RQTestCase):
job2.delete() job2.delete()
self.assertEqual(registry.get_job_ids(), []) self.assertEqual(registry.get_job_ids(), [])
def test_cleanup_supports_deleted_jobs(self):
    """Cleanup empties the registry even when the job itself was deleted."""
    job = Queue(connection=self.connection).enqueue(say_hello)
    self.registry.add(job, ttl=10)
    self.assertEqual(self.registry.count, 1)

    # After the job is deleted, the registry entry is still counted.
    job.delete(remove_from_queue=False)
    self.assertEqual(self.registry.count, 1)

    # Clean up with a cutoff 100 seconds ahead, well past the 10-second ttl.
    cutoff = current_timestamp() + 100
    self.registry.cleanup(cutoff)
    self.assertEqual(self.registry.count, 0)
def test_cleanup_moves_jobs_to_failed_job_registry(self):
    """Moving expired jobs to FailedJobRegistry."""
    queue = Queue(connection=self.connection)
    failed_registry = FailedJobRegistry(connection=self.connection)
    job = queue.enqueue(say_hello)

    # Place the job in the deferred registry with a score of 2.
    self.connection.zadd(self.registry.key, {job.id: 2})

    # Cleaning up with a cutoff of 1 must leave the job where it is.
    self.registry.cleanup(1)
    self.assertNotIn(job, failed_registry)
    self.assertIn(job, self.registry)

    # Cleaning up with the default cutoff must move it to the failed registry.
    self.registry.cleanup()
    self.assertIn(job.id, failed_registry)
    self.assertNotIn(job, self.registry)

    job.refresh()
    self.assertEqual(job.get_status(), JobStatus.FAILED)
    self.assertTrue(job.exc_info)  # explanation is written to exc_info
class TestFailedJobRegistry(RQTestCase): class TestFailedJobRegistry(RQTestCase):
def test_default_failure_ttl(self): def test_default_failure_ttl(self):