Add support for a callback on stopped jobs (#1909)

* Add support for a callback on stopped jobs

This callback runs when an active job is stopped via the send_stop_job_command
function; a usage sketch follows the commit metadata below.

* Remove testing async job with stopped callback

* Remove stopped job test from simpleworker case.

I can't stop the job from the test until the work() method returns, at which point the
job can't be stopped.

* Improve coverage

* Add test for stopped callback execution

* Move stopped callback check out of execution func

* Use SimpleWorker for stopped callback test

* Call stopped callback directly in main proc

* Remove unused imports

* Fix import order

* Fix import order

* Fix death penalty class arg

* Fix worker instance init

Sorry these commits are so lazy
Author: Ethan Wolinsky, 2023-05-21 22:06:02 -04:00 (committed by GitHub)
Parent: b756cf82bd
Commit: 192fbc9c50
7 changed files with 148 additions and 2 deletions
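A minimal end-to-end sketch of how the new on_stopped callback might be used. The task long_task, the callback report_stopped, and the local Redis connection are assumptions for illustration; send_stop_job_command is the RQ command that asks a worker to stop an active job.

import time

from redis import Redis
from rq import Queue
from rq.command import send_stop_job_command

def report_stopped(job, connection):
    # Stopped callbacks are invoked as callback(job, connection), matching
    # Job.execute_stopped_callback introduced below.
    connection.set('stopped:%s' % job.id, '1', ex=60)

def long_task():
    time.sleep(60)

redis = Redis()
queue = Queue(connection=redis)

# In a real setup both functions must live in a module the worker can import.
job = queue.enqueue(long_task, on_stopped=report_stopped)

# With a worker processing the queue in another process, deliberately stopping
# the job runs report_stopped before the job is moved to the FailedJobRegistry.
send_stop_job_command(redis, job.id)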


@ -30,6 +30,7 @@ class job: # noqa
retry: Optional['Retry'] = None,
on_failure: Optional[Callable[..., Any]] = None,
on_success: Optional[Callable[..., Any]] = None,
on_stopped: Optional[Callable[..., Any]] = None,
):
"""A decorator that adds a ``delay`` method to the decorated function,
which in turn creates a RQ job when called. Accepts a required
@ -60,6 +61,7 @@ class job: # noqa
retry (Optional[Retry], optional): A Retry object. Defaults to None.
on_failure (Optional[Callable[..., Any]], optional): Callable to run on failure. Defaults to None.
on_success (Optional[Callable[..., Any]], optional): Callable to run on success. Defaults to None.
on_stopped (Optional[Callable[..., Any]], optional): Callable to run when stopped. Defaults to None.
"""
self.queue = queue
self.queue_class = backend_class(self, 'queue_class', override=queue_class)
@ -75,6 +77,7 @@ class job: # noqa
self.retry = retry
self.on_success = on_success
self.on_failure = on_failure
self.on_stopped = on_stopped
def __call__(self, f):
@wraps(f)
@ -110,6 +113,7 @@ class job: # noqa
retry=self.retry,
on_failure=self.on_failure,
on_success=self.on_success,
on_stopped=self.on_stopped,
)
f.delay = delay
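The decorator simply threads on_stopped through to the enqueue call, so a decorated function can attach the callback directly. A minimal sketch; the queue name, connection, and report_stopped are assumptions:

import time

from redis import Redis
from rq.decorators import job

def report_stopped(job, connection):
    # Hypothetical callback: runs when the decorated job is stopped deliberately.
    print('stopped:', job.id)

@job('default', connection=Redis(), on_stopped=report_stopped)
def long_task():
    time.sleep(60)

# .delay() enqueues long_task with the stopped callback attached to the job.
long_task.delay()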


@ -160,6 +160,7 @@ class Job:
*,
on_success: Optional[Union['Callback', Callable[..., Any]]] = None,
on_failure: Optional[Union['Callback', Callable[..., Any]]] = None,
on_stopped: Optional[Union['Callback', Callable[..., Any]]] = None,
) -> 'Job':
"""Creates a new Job instance for the given function, arguments, and
keyword arguments.
@ -196,6 +197,8 @@ class Job:
when/if the Job finishes successfully. Defaults to None.
on_failure (Optional[Callable[..., Any]], optional): A callback function, should be a callable to run
when/if the Job fails. Defaults to None.
on_stopped (Optional[Callable[..., Any]], optional): A callback function, should be a callable to run
when/if the Job is stopped. Defaults to None.
Raises:
TypeError: If `args` is not a tuple/list
@ -203,6 +206,7 @@ class Job:
TypeError: If the `func` is something other than a string or a Callable reference
ValueError: If `on_failure` is not a function
ValueError: If `on_success` is not a function
ValueError: If `on_stopped` is not a function
Returns:
Job: A job instance.
@ -259,6 +263,15 @@ class Job:
job._failure_callback_name = on_failure.name
job._failure_callback_timeout = on_failure.timeout
if on_stopped:
if not isinstance(on_stopped, Callback):
warnings.warn(
'Passing a `Callable` `on_stopped` is deprecated, pass `Callback` instead', DeprecationWarning
)
on_stopped = Callback(on_stopped) # backward compatibility
job._stopped_callback_name = on_stopped.name
job._stopped_callback_timeout = on_stopped.timeout
# Extra meta data
job.description = description or job.get_call_string()
job.result_ttl = parse_timeout(result_ttl)
@ -442,6 +455,23 @@ class Job:
return self._failure_callback_timeout
@property
def stopped_callback(self):
if self._stopped_callback is UNEVALUATED:
if self._stopped_callback_name:
self._stopped_callback = import_attribute(self._stopped_callback_name)
else:
self._stopped_callback = None
return self._stopped_callback
@property
def stopped_callback_timeout(self) -> int:
if self._stopped_callback_timeout is None:
return CALLBACK_TIMEOUT
return self._stopped_callback_timeout
def _deserialize_data(self):
"""Deserializes the Job `data` into a tuple.
This includes the `_func_name`, `_instance`, `_args` and `_kwargs`
@ -607,6 +637,8 @@ class Job:
self._success_callback = UNEVALUATED
self._failure_callback_name = None
self._failure_callback = UNEVALUATED
self._stopped_callback_name = None
self._stopped_callback = UNEVALUATED
self.description: Optional[str] = None
self.origin: str = ''
self.enqueued_at: Optional[datetime] = None
@ -617,6 +649,7 @@ class Job:
self.timeout: Optional[float] = None
self._success_callback_timeout: Optional[int] = None
self._failure_callback_timeout: Optional[int] = None
self._stopped_callback_timeout: Optional[int] = None
self.result_ttl: Optional[int] = None
self.failure_ttl: Optional[int] = None
self.ttl: Optional[int] = None
@ -913,6 +946,12 @@ class Job:
if 'failure_callback_timeout' in obj:
self._failure_callback_timeout = int(obj.get('failure_callback_timeout'))
if obj.get('stopped_callback_name'):
self._stopped_callback_name = obj.get('stopped_callback_name').decode()
if 'stopped_callback_timeout' in obj:
self._stopped_callback_timeout = int(obj.get('stopped_callback_timeout'))
dep_ids = obj.get('dependency_ids')
dep_id = obj.get('dependency_id') # for backwards compatibility
self._dependency_ids = json.loads(dep_ids.decode()) if dep_ids else [dep_id.decode()] if dep_id else []
@ -967,6 +1006,7 @@ class Job:
'data': zlib.compress(self.data),
'success_callback_name': self._success_callback_name if self._success_callback_name else '',
'failure_callback_name': self._failure_callback_name if self._failure_callback_name else '',
'stopped_callback_name': self._stopped_callback_name if self._stopped_callback_name else '',
'started_at': utcformat(self.started_at) if self.started_at else '',
'ended_at': utcformat(self.ended_at) if self.ended_at else '',
'last_heartbeat': utcformat(self.last_heartbeat) if self.last_heartbeat else '',
@ -997,6 +1037,8 @@ class Job:
obj['success_callback_timeout'] = self._success_callback_timeout
if self._failure_callback_timeout is not None:
obj['failure_callback_timeout'] = self._failure_callback_timeout
if self._stopped_callback_timeout is not None:
obj['stopped_callback_timeout'] = self._stopped_callback_timeout
if self.result_ttl is not None:
obj['result_ttl'] = self.result_ttl
if self.failure_ttl is not None:
@ -1386,6 +1428,16 @@ class Job:
logger.exception(f'Job {self.id}: error while executing failure callback')
raise
def execute_stopped_callback(self, death_penalty_class: Type[BaseDeathPenalty]):
"""Executes stopped_callback with possible timeout"""
logger.debug('Running stopped callbacks for %s', self.id)
try:
with death_penalty_class(self.stopped_callback_timeout, JobTimeoutException, job_id=self.id):
self.stopped_callback(self, self.connection)
except Exception: # noqa
logger.exception(f'Job {self.id}: error while executing stopped callback')
raise
def _handle_success(self, result_ttl: int, pipeline: 'Pipeline'):
"""Saves and cleanup job after successful execution"""
# self.log.debug('Setting job %s status to finished', job.id)
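As with on_success and on_failure, passing a bare callable for on_stopped still works but now emits a DeprecationWarning; wrapping it in Callback keeps the new style and lets a per-callback timeout be set, which execute_stopped_callback enforces through the death penalty class. A sketch, reusing queue, long_task, and report_stopped from the earlier example:

from rq.job import Callback

# Attach the stopped callback explicitly, limited to 10 seconds.
job = queue.enqueue(long_task, on_stopped=Callback(report_stopped, timeout=10))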


@ -50,6 +50,7 @@ class EnqueueData(
"retry",
"on_success",
"on_failure",
"on_stopped",
],
)
):
@ -516,6 +517,7 @@ class Queue:
*,
on_success: Optional[Callable] = None,
on_failure: Optional[Callable] = None,
on_stopped: Optional[Callable] = None,
) -> Job:
"""Creates a job based on parameters given
@ -535,6 +537,7 @@ class Queue:
retry (Optional[Retry], optional): The Retry Object. Defaults to None.
on_success (Optional[Callable], optional): On success callable. Defaults to None.
on_failure (Optional[Callable], optional): On failure callable. Defaults to None.
on_stopped (Optional[Callable], optional): On stopped callable. Defaults to None.
Raises:
ValueError: If the timeout is 0
@ -575,6 +578,7 @@ class Queue:
serializer=self.serializer,
on_success=on_success,
on_failure=on_failure,
on_stopped=on_stopped,
)
if retry:
@ -657,6 +661,7 @@ class Queue:
retry: Optional['Retry'] = None,
on_success: Optional[Callable[..., Any]] = None,
on_failure: Optional[Callable[..., Any]] = None,
on_stopped: Optional[Callable[..., Any]] = None,
pipeline: Optional['Pipeline'] = None,
) -> Job:
"""Creates a job to represent the delayed function call and enqueues it.
@ -681,6 +686,7 @@ class Queue:
retry (Optional[Retry], optional): Retry object. Defaults to None.
on_success (Optional[Callable[..., Any]], optional): Callable for on success. Defaults to None.
on_failure (Optional[Callable[..., Any]], optional): Callable for on failure. Defaults to None.
on_stopped (Optional[Callable[..., Any]], optional): Callable for on stopped. Defaults to None.
pipeline (Optional[Pipeline], optional): The Redis Pipeline. Defaults to None.
Returns:
@ -703,6 +709,7 @@ class Queue:
retry=retry,
on_success=on_success,
on_failure=on_failure,
on_stopped=on_stopped,
)
return self.enqueue_job(job, pipeline=pipeline, at_front=at_front)
@ -723,6 +730,7 @@ class Queue:
retry: Optional['Retry'] = None,
on_success: Optional[Callable] = None,
on_failure: Optional[Callable] = None,
on_stopped: Optional[Callable] = None,
) -> EnqueueData:
"""Need this till support dropped for python_version < 3.7, where defaults can be specified for named tuples
And can keep this logic within EnqueueData
@ -743,6 +751,7 @@ class Queue:
retry (Optional[Retry], optional): Retry object. Defaults to None.
on_success (Optional[Callable[..., Any]], optional): Callable for on success. Defaults to None.
on_failure (Optional[Callable[..., Any]], optional): Callable for on failure. Defaults to None.
on_stopped (Optional[Callable[..., Any]], optional): Callable for on stopped. Defaults to None.
Returns:
EnqueueData: The EnqueueData
@ -763,6 +772,7 @@ class Queue:
retry,
on_success,
on_failure,
on_stopped,
)
def enqueue_many(self, job_datas: List['EnqueueData'], pipeline: Optional['Pipeline'] = None) -> List[Job]:
@ -889,6 +899,7 @@ class Queue:
retry = kwargs.pop('retry', None)
on_success = kwargs.pop('on_success', None)
on_failure = kwargs.pop('on_failure', None)
on_stopped = kwargs.pop('on_stopped', None)
pipeline = kwargs.pop('pipeline', None)
if 'args' in kwargs or 'kwargs' in kwargs:
@ -910,6 +921,7 @@ class Queue:
retry,
on_success,
on_failure,
on_stopped,
pipeline,
args,
kwargs,
@ -941,6 +953,7 @@ class Queue:
retry,
on_success,
on_failure,
on_stopped,
pipeline,
args,
kwargs,
@ -962,6 +975,7 @@ class Queue:
retry=retry,
on_success=on_success,
on_failure=on_failure,
on_stopped=on_stopped,
pipeline=pipeline,
)
@ -989,6 +1003,7 @@ class Queue:
retry,
on_success,
on_failure,
on_stopped,
pipeline,
args,
kwargs,
@ -1009,6 +1024,7 @@ class Queue:
retry=retry,
on_success=on_success,
on_failure=on_failure,
on_stopped=on_stopped,
)
if at_front:
job.enqueue_at_front = True
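EnqueueData and prepare_data now carry on_stopped as well, so bulk enqueueing can attach the callback too. A sketch, again reusing the hypothetical long_task and report_stopped:

from redis import Redis
from rq import Queue

queue = Queue(connection=Redis())
jobs = queue.enqueue_many(
    [
        Queue.prepare_data(long_task, on_stopped=report_stopped),
        Queue.prepare_data(long_task, job_id='second-job', on_stopped=report_stopped),
    ]
)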


@ -1191,6 +1191,8 @@ class Worker(BaseWorker):
elif self._stopped_job_id == job.id:
# Work-horse killed deliberately
self.log.warning('Job stopped by user, moving job to FailedJobRegistry')
if job.stopped_callback:
job.execute_stopped_callback(self.death_penalty_class)
self.handle_job_failure(job, queue=queue, exc_string='Job stopped by user, work-horse terminated.')
elif job_status not in [JobStatus.FINISHED, JobStatus.FAILED]:
if not job.ended_at:
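The stopped callback runs only in the branch above, i.e. when the work horse was killed deliberately via the stop command; jobs that fail with an exception still go through the failure callback. A sketch combining both, with hypothetical callbacks and the queue and long_task from the first example:

def save_failure(job, connection, type, value, traceback):
    # Failure callback: receives the exception info when the job raises.
    connection.set('failed:%s' % job.id, str(value), ex=60)

def save_stopped(job, connection):
    # Stopped callback: runs when the work horse is killed via send_stop_job_command.
    connection.set('stopped:%s' % job.id, 'stopped', ex=60)

job = queue.enqueue(long_task, on_failure=save_failure, on_stopped=save_stopped)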


@ -59,6 +59,11 @@ def div_by_zero(x):
return x / 0
def long_process():
time.sleep(60)
return
def some_calculation(x, y, z=1):
"""Some arbitrary calculation with three numbers. Choose z smartly if you
want a division by zero exception.
@ -287,6 +292,10 @@ def save_exception(job, connection, type, value, traceback):
connection.set('failure_callback:%s' % job.id, str(value), ex=60)
def save_result_if_not_stopped(job, connection, result=""):
connection.set('stopped_callback:%s' % job.id, result, ex=60)
def erroneous_callback(job):
"""A callback that's not written properly"""
pass


@ -1,10 +1,19 @@
from datetime import timedelta
from rq import Queue, Worker
from rq.job import UNEVALUATED, Job, JobStatus
from rq.job import Job, JobStatus, UNEVALUATED
from rq.serializers import JSONSerializer
from rq.worker import SimpleWorker
from tests import RQTestCase
from tests.fixtures import div_by_zero, erroneous_callback, save_exception, save_result, say_hello
from tests.fixtures import (
div_by_zero,
erroneous_callback,
long_process,
save_exception,
save_result,
save_result_if_not_stopped,
say_hello,
)
class QueueCallbackTestCase(RQTestCase):
@ -44,6 +53,24 @@ class QueueCallbackTestCase(RQTestCase):
job = Job.fetch(id=job.id, connection=self.testconn)
self.assertEqual(job.failure_callback, print)
def test_enqueue_with_stopped_callback(self):
"""queue.enqueue* methods with on_stopped is persisted correctly"""
queue = Queue(connection=self.testconn)
# Only functions and builtins are supported as callback
with self.assertRaises(ValueError):
queue.enqueue(say_hello, on_stopped=Job.fetch)
job = queue.enqueue(long_process, on_stopped=print)
job = Job.fetch(id=job.id, connection=self.testconn)
self.assertEqual(job.stopped_callback, print)
job = queue.enqueue_in(timedelta(seconds=10), long_process, on_stopped=print)
job = Job.fetch(id=job.id, connection=self.testconn)
self.assertEqual(job.stopped_callback, print)
class SyncJobCallback(RQTestCase):
def test_success_callback(self):
@ -70,6 +97,17 @@ class SyncJobCallback(RQTestCase):
self.assertEqual(job.get_status(), JobStatus.FAILED)
self.assertFalse(self.testconn.exists('failure_callback:%s' % job.id))
def test_stopped_callback(self):
"""queue.enqueue* methods with on_stopped is persisted correctly"""
connection = self.testconn
queue = Queue('foo', connection=connection, serializer=JSONSerializer)
worker = SimpleWorker('foo', connection=connection, serializer=JSONSerializer)
job = queue.enqueue(long_process, on_stopped=save_result_if_not_stopped)
job.execute_stopped_callback(
worker.death_penalty_class
) # Calling execute_stopped_callback directly for coverage
self.assertTrue(self.testconn.exists('stopped_callback:%s' % job.id))
class WorkerCallbackTestCase(RQTestCase):
def test_success_callback(self):
@ -159,3 +197,22 @@ class JobCallbackTestCase(RQTestCase):
job = Job.fetch(id=job.id, connection=self.testconn)
self.assertEqual(job.failure_callback, print)
def test_job_creation_with_stopped_callback(self):
"""Ensure stopped callbacks are persisted properly"""
job = Job.create(say_hello)
self.assertIsNone(job._stopped_callback_name)
# _stopped_callback starts with UNEVALUATED
self.assertEqual(job._stopped_callback, UNEVALUATED)
self.assertEqual(job.stopped_callback, None)
# _stopped_callback becomes `None` after `job.stopped_callback` is called if there's no stopped callback
self.assertEqual(job._stopped_callback, None)
# job.stopped_callback is assigned properly
job = Job.create(say_hello, on_stopped=print)
self.assertIsNotNone(job._stopped_callback_name)
self.assertEqual(job.stopped_callback, print)
job.save()
job = Job.fetch(id=job.id, connection=self.testconn)
self.assertEqual(job.stopped_callback, print)


@ -223,6 +223,7 @@ class TestJob(RQTestCase):
b'worker_name',
b'success_callback_name',
b'failure_callback_name',
b'stopped_callback_name',
},
set(self.testconn.hkeys(job.key)),
)
@ -260,6 +261,7 @@ class TestJob(RQTestCase):
func=fixtures.some_calculation,
on_success=Callback(fixtures.say_hello, timeout=10),
on_failure=fixtures.say_pid,
on_stopped=fixtures.say_hello,
) # deprecated callable
job.save()
stored_job = Job.fetch(job.id)
@ -267,7 +269,9 @@ class TestJob(RQTestCase):
self.assertEqual(fixtures.say_hello, stored_job.success_callback)
self.assertEqual(10, stored_job.success_callback_timeout)
self.assertEqual(fixtures.say_pid, stored_job.failure_callback)
self.assertEqual(fixtures.say_hello, stored_job.stopped_callback)
self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout)
self.assertEqual(CALLBACK_TIMEOUT, stored_job.stopped_callback_timeout)
# None(s)
job = Job.create(func=fixtures.some_calculation, on_failure=None)
@ -279,6 +283,8 @@ class TestJob(RQTestCase):
self.assertIsNone(stored_job.failure_callback)
self.assertEqual(CALLBACK_TIMEOUT, job.failure_callback_timeout) # timeout should be never none
self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout)
self.assertEqual(CALLBACK_TIMEOUT, job.stopped_callback_timeout) # timeout should be never none
self.assertIsNone(stored_job.stopped_callback)
def test_store_then_fetch(self):
"""Store, then fetch."""