mirror of https://github.com/rq/rq.git
Add support for a callback on stopped jobs (#1909)
* Add support for a callback on stopped jobs This function will run when an active job is stopped using the send_stopped_job_command_function * Remove testing async job with stopped callback * Remove stopped job test from simpleworker case. I can't stop the job from the test until the work() method returns, at which point the job can't be stopped. * Improve coverage * Add test for stopped callback execution * Move stopped callback check out of execution func * Use SimpleWorker for stopped callback test * Call stopped callback directly in main proc * Remove unused imports * Fix import order * Fix import order * Fix death penalty class arg * Fix worker instance init Sorry these commits are so lazy
This commit is contained in:
parent
b756cf82bd
commit
192fbc9c50
|
@ -30,6 +30,7 @@ class job: # noqa
|
|||
retry: Optional['Retry'] = None,
|
||||
on_failure: Optional[Callable[..., Any]] = None,
|
||||
on_success: Optional[Callable[..., Any]] = None,
|
||||
on_stopped: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
"""A decorator that adds a ``delay`` method to the decorated function,
|
||||
which in turn creates a RQ job when called. Accepts a required
|
||||
|
@ -60,6 +61,7 @@ class job: # noqa
|
|||
retry (Optional[Retry], optional): A Retry object. Defaults to None.
|
||||
on_failure (Optional[Callable[..., Any]], optional): Callable to run on failure. Defaults to None.
|
||||
on_success (Optional[Callable[..., Any]], optional): Callable to run on success. Defaults to None.
|
||||
on_stopped (Optional[Callable[..., Any]], optional): Callable to run when stopped. Defaults to None.
|
||||
"""
|
||||
self.queue = queue
|
||||
self.queue_class = backend_class(self, 'queue_class', override=queue_class)
|
||||
|
@ -75,6 +77,7 @@ class job: # noqa
|
|||
self.retry = retry
|
||||
self.on_success = on_success
|
||||
self.on_failure = on_failure
|
||||
self.on_stopped = on_stopped
|
||||
|
||||
def __call__(self, f):
|
||||
@wraps(f)
|
||||
|
@ -110,6 +113,7 @@ class job: # noqa
|
|||
retry=self.retry,
|
||||
on_failure=self.on_failure,
|
||||
on_success=self.on_success,
|
||||
on_stopped=self.on_stopped,
|
||||
)
|
||||
|
||||
f.delay = delay
|
||||
|
|
52
rq/job.py
52
rq/job.py
|
@ -160,6 +160,7 @@ class Job:
|
|||
*,
|
||||
on_success: Optional[Union['Callback', Callable[..., Any]]] = None,
|
||||
on_failure: Optional[Union['Callback', Callable[..., Any]]] = None,
|
||||
on_stopped: Optional[Union['Callback', Callable[..., Any]]] = None,
|
||||
) -> 'Job':
|
||||
"""Creates a new Job instance for the given function, arguments, and
|
||||
keyword arguments.
|
||||
|
@ -196,6 +197,8 @@ class Job:
|
|||
when/if the Job finishes sucessfully. Defaults to None.
|
||||
on_failure (Optional[Callable[..., Any]], optional): A callback function, should be a callable to run
|
||||
when/if the Job fails. Defaults to None.
|
||||
on_stopped (Optional[Callable[..., Any]], optional): A callback function, should be a callable to run
|
||||
when/if the Job is stopped. Defaults to None.
|
||||
|
||||
Raises:
|
||||
TypeError: If `args` is not a tuple/list
|
||||
|
@ -203,6 +206,7 @@ class Job:
|
|||
TypeError: If the `func` is something other than a string or a Callable reference
|
||||
ValueError: If `on_failure` is not a function
|
||||
ValueError: If `on_success` is not a function
|
||||
ValueError: If `on_stopped` is not a function
|
||||
|
||||
Returns:
|
||||
Job: A job instance.
|
||||
|
@ -259,6 +263,15 @@ class Job:
|
|||
job._failure_callback_name = on_failure.name
|
||||
job._failure_callback_timeout = on_failure.timeout
|
||||
|
||||
if on_stopped:
|
||||
if not isinstance(on_stopped, Callback):
|
||||
warnings.warn(
|
||||
'Passing a `Callable` `on_stopped` is deprecated, pass `Callback` instead', DeprecationWarning
|
||||
)
|
||||
on_stopped = Callback(on_stopped) # backward compatibility
|
||||
job._stopped_callback_name = on_stopped.name
|
||||
job._stopped_callback_timeout = on_stopped.timeout
|
||||
|
||||
# Extra meta data
|
||||
job.description = description or job.get_call_string()
|
||||
job.result_ttl = parse_timeout(result_ttl)
|
||||
|
@ -442,6 +455,23 @@ class Job:
|
|||
|
||||
return self._failure_callback_timeout
|
||||
|
||||
@property
|
||||
def stopped_callback(self):
|
||||
if self._stopped_callback is UNEVALUATED:
|
||||
if self._stopped_callback_name:
|
||||
self._stopped_callback = import_attribute(self._stopped_callback_name)
|
||||
else:
|
||||
self._stopped_callback = None
|
||||
|
||||
return self._stopped_callback
|
||||
|
||||
@property
|
||||
def stopped_callback_timeout(self) -> int:
|
||||
if self._stopped_callback_timeout is None:
|
||||
return CALLBACK_TIMEOUT
|
||||
|
||||
return self._stopped_callback_timeout
|
||||
|
||||
def _deserialize_data(self):
|
||||
"""Deserializes the Job `data` into a tuple.
|
||||
This includes the `_func_name`, `_instance`, `_args` and `_kwargs`
|
||||
|
@ -607,6 +637,8 @@ class Job:
|
|||
self._success_callback = UNEVALUATED
|
||||
self._failure_callback_name = None
|
||||
self._failure_callback = UNEVALUATED
|
||||
self._stopped_callback_name = None
|
||||
self._stopped_callback = UNEVALUATED
|
||||
self.description: Optional[str] = None
|
||||
self.origin: str = ''
|
||||
self.enqueued_at: Optional[datetime] = None
|
||||
|
@ -617,6 +649,7 @@ class Job:
|
|||
self.timeout: Optional[float] = None
|
||||
self._success_callback_timeout: Optional[int] = None
|
||||
self._failure_callback_timeout: Optional[int] = None
|
||||
self._stopped_callback_timeout: Optional[int] = None
|
||||
self.result_ttl: Optional[int] = None
|
||||
self.failure_ttl: Optional[int] = None
|
||||
self.ttl: Optional[int] = None
|
||||
|
@ -913,6 +946,12 @@ class Job:
|
|||
if 'failure_callback_timeout' in obj:
|
||||
self._failure_callback_timeout = int(obj.get('failure_callback_timeout'))
|
||||
|
||||
if obj.get('stopped_callback_name'):
|
||||
self._stopped_callback_name = obj.get('stopped_callback_name').decode()
|
||||
|
||||
if 'stopped_callback_timeout' in obj:
|
||||
self._stopped_callback_timeout = int(obj.get('stopped_callback_timeout'))
|
||||
|
||||
dep_ids = obj.get('dependency_ids')
|
||||
dep_id = obj.get('dependency_id') # for backwards compatibility
|
||||
self._dependency_ids = json.loads(dep_ids.decode()) if dep_ids else [dep_id.decode()] if dep_id else []
|
||||
|
@ -967,6 +1006,7 @@ class Job:
|
|||
'data': zlib.compress(self.data),
|
||||
'success_callback_name': self._success_callback_name if self._success_callback_name else '',
|
||||
'failure_callback_name': self._failure_callback_name if self._failure_callback_name else '',
|
||||
'stopped_callback_name': self._stopped_callback_name if self._stopped_callback_name else '',
|
||||
'started_at': utcformat(self.started_at) if self.started_at else '',
|
||||
'ended_at': utcformat(self.ended_at) if self.ended_at else '',
|
||||
'last_heartbeat': utcformat(self.last_heartbeat) if self.last_heartbeat else '',
|
||||
|
@ -997,6 +1037,8 @@ class Job:
|
|||
obj['success_callback_timeout'] = self._success_callback_timeout
|
||||
if self._failure_callback_timeout is not None:
|
||||
obj['failure_callback_timeout'] = self._failure_callback_timeout
|
||||
if self._stopped_callback_timeout is not None:
|
||||
obj['stopped_callback_timeout'] = self._stopped_callback_timeout
|
||||
if self.result_ttl is not None:
|
||||
obj['result_ttl'] = self.result_ttl
|
||||
if self.failure_ttl is not None:
|
||||
|
@ -1386,6 +1428,16 @@ class Job:
|
|||
logger.exception(f'Job {self.id}: error while executing failure callback')
|
||||
raise
|
||||
|
||||
def execute_stopped_callback(self, death_penalty_class: Type[BaseDeathPenalty]):
|
||||
"""Executes stopped_callback with possible timeout"""
|
||||
logger.debug('Running stopped callbacks for %s', self.id)
|
||||
try:
|
||||
with death_penalty_class(self.stopped_callback_timeout, JobTimeoutException, job_id=self.id):
|
||||
self.stopped_callback(self, self.connection)
|
||||
except Exception: # noqa
|
||||
logger.exception(f'Job {self.id}: error while executing stopped callback')
|
||||
raise
|
||||
|
||||
def _handle_success(self, result_ttl: int, pipeline: 'Pipeline'):
|
||||
"""Saves and cleanup job after successful execution"""
|
||||
# self.log.debug('Setting job %s status to finished', job.id)
|
||||
|
|
16
rq/queue.py
16
rq/queue.py
|
@ -50,6 +50,7 @@ class EnqueueData(
|
|||
"retry",
|
||||
"on_success",
|
||||
"on_failure",
|
||||
"on_stopped",
|
||||
],
|
||||
)
|
||||
):
|
||||
|
@ -516,6 +517,7 @@ class Queue:
|
|||
*,
|
||||
on_success: Optional[Callable] = None,
|
||||
on_failure: Optional[Callable] = None,
|
||||
on_stopped: Optional[Callable] = None,
|
||||
) -> Job:
|
||||
"""Creates a job based on parameters given
|
||||
|
||||
|
@ -535,6 +537,7 @@ class Queue:
|
|||
retry (Optional[Retry], optional): The Retry Object. Defaults to None.
|
||||
on_success (Optional[Callable], optional): On success callable. Defaults to None.
|
||||
on_failure (Optional[Callable], optional): On failure callable. Defaults to None.
|
||||
on_stopped (Optional[Callable], optional): On stopped callable. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the timeout is 0
|
||||
|
@ -575,6 +578,7 @@ class Queue:
|
|||
serializer=self.serializer,
|
||||
on_success=on_success,
|
||||
on_failure=on_failure,
|
||||
on_stopped=on_stopped,
|
||||
)
|
||||
|
||||
if retry:
|
||||
|
@ -657,6 +661,7 @@ class Queue:
|
|||
retry: Optional['Retry'] = None,
|
||||
on_success: Optional[Callable[..., Any]] = None,
|
||||
on_failure: Optional[Callable[..., Any]] = None,
|
||||
on_stopped: Optional[Callable[..., Any]] = None,
|
||||
pipeline: Optional['Pipeline'] = None,
|
||||
) -> Job:
|
||||
"""Creates a job to represent the delayed function call and enqueues it.
|
||||
|
@ -681,6 +686,7 @@ class Queue:
|
|||
retry (Optional[Retry], optional): Retry object. Defaults to None.
|
||||
on_success (Optional[Callable[..., Any]], optional): Callable for on success. Defaults to None.
|
||||
on_failure (Optional[Callable[..., Any]], optional): Callable for on failure. Defaults to None.
|
||||
on_stopped (Optional[Callable[..., Any]], optional): Callable for on stopped. Defaults to None.
|
||||
pipeline (Optional[Pipeline], optional): The Redis Pipeline. Defaults to None.
|
||||
|
||||
Returns:
|
||||
|
@ -703,6 +709,7 @@ class Queue:
|
|||
retry=retry,
|
||||
on_success=on_success,
|
||||
on_failure=on_failure,
|
||||
on_stopped=on_stopped,
|
||||
)
|
||||
return self.enqueue_job(job, pipeline=pipeline, at_front=at_front)
|
||||
|
||||
|
@ -723,6 +730,7 @@ class Queue:
|
|||
retry: Optional['Retry'] = None,
|
||||
on_success: Optional[Callable] = None,
|
||||
on_failure: Optional[Callable] = None,
|
||||
on_stopped: Optional[Callable] = None,
|
||||
) -> EnqueueData:
|
||||
"""Need this till support dropped for python_version < 3.7, where defaults can be specified for named tuples
|
||||
And can keep this logic within EnqueueData
|
||||
|
@ -743,6 +751,7 @@ class Queue:
|
|||
retry (Optional[Retry], optional): Retry object. Defaults to None.
|
||||
on_success (Optional[Callable[..., Any]], optional): Callable for on success. Defaults to None.
|
||||
on_failure (Optional[Callable[..., Any]], optional): Callable for on failure. Defaults to None.
|
||||
on_stopped (Optional[Callable[..., Any]], optional): Callable for on stopped. Defaults to None.
|
||||
|
||||
Returns:
|
||||
EnqueueData: The EnqueueData
|
||||
|
@ -763,6 +772,7 @@ class Queue:
|
|||
retry,
|
||||
on_success,
|
||||
on_failure,
|
||||
on_stopped,
|
||||
)
|
||||
|
||||
def enqueue_many(self, job_datas: List['EnqueueData'], pipeline: Optional['Pipeline'] = None) -> List[Job]:
|
||||
|
@ -889,6 +899,7 @@ class Queue:
|
|||
retry = kwargs.pop('retry', None)
|
||||
on_success = kwargs.pop('on_success', None)
|
||||
on_failure = kwargs.pop('on_failure', None)
|
||||
on_stopped = kwargs.pop('on_stopped', None)
|
||||
pipeline = kwargs.pop('pipeline', None)
|
||||
|
||||
if 'args' in kwargs or 'kwargs' in kwargs:
|
||||
|
@ -910,6 +921,7 @@ class Queue:
|
|||
retry,
|
||||
on_success,
|
||||
on_failure,
|
||||
on_stopped,
|
||||
pipeline,
|
||||
args,
|
||||
kwargs,
|
||||
|
@ -941,6 +953,7 @@ class Queue:
|
|||
retry,
|
||||
on_success,
|
||||
on_failure,
|
||||
on_stopped,
|
||||
pipeline,
|
||||
args,
|
||||
kwargs,
|
||||
|
@ -962,6 +975,7 @@ class Queue:
|
|||
retry=retry,
|
||||
on_success=on_success,
|
||||
on_failure=on_failure,
|
||||
on_stopped=on_stopped,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
|
||||
|
@ -989,6 +1003,7 @@ class Queue:
|
|||
retry,
|
||||
on_success,
|
||||
on_failure,
|
||||
on_stopped,
|
||||
pipeline,
|
||||
args,
|
||||
kwargs,
|
||||
|
@ -1009,6 +1024,7 @@ class Queue:
|
|||
retry=retry,
|
||||
on_success=on_success,
|
||||
on_failure=on_failure,
|
||||
on_stopped=on_stopped,
|
||||
)
|
||||
if at_front:
|
||||
job.enqueue_at_front = True
|
||||
|
|
|
@ -1191,6 +1191,8 @@ class Worker(BaseWorker):
|
|||
elif self._stopped_job_id == job.id:
|
||||
# Work-horse killed deliberately
|
||||
self.log.warning('Job stopped by user, moving job to FailedJobRegistry')
|
||||
if job.stopped_callback:
|
||||
job.execute_stopped_callback(self.death_penalty_class)
|
||||
self.handle_job_failure(job, queue=queue, exc_string='Job stopped by user, work-horse terminated.')
|
||||
elif job_status not in [JobStatus.FINISHED, JobStatus.FAILED]:
|
||||
if not job.ended_at:
|
||||
|
|
|
@ -59,6 +59,11 @@ def div_by_zero(x):
|
|||
return x / 0
|
||||
|
||||
|
||||
def long_process():
|
||||
time.sleep(60)
|
||||
return
|
||||
|
||||
|
||||
def some_calculation(x, y, z=1):
|
||||
"""Some arbitrary calculation with three numbers. Choose z smartly if you
|
||||
want a division by zero exception.
|
||||
|
@ -287,6 +292,10 @@ def save_exception(job, connection, type, value, traceback):
|
|||
connection.set('failure_callback:%s' % job.id, str(value), ex=60)
|
||||
|
||||
|
||||
def save_result_if_not_stopped(job, connection, result=""):
|
||||
connection.set('stopped_callback:%s' % job.id, result, ex=60)
|
||||
|
||||
|
||||
def erroneous_callback(job):
|
||||
"""A callback that's not written properly"""
|
||||
pass
|
||||
|
|
|
@ -1,10 +1,19 @@
|
|||
from datetime import timedelta
|
||||
|
||||
from rq import Queue, Worker
|
||||
from rq.job import UNEVALUATED, Job, JobStatus
|
||||
from rq.job import Job, JobStatus, UNEVALUATED
|
||||
from rq.serializers import JSONSerializer
|
||||
from rq.worker import SimpleWorker
|
||||
from tests import RQTestCase
|
||||
from tests.fixtures import div_by_zero, erroneous_callback, save_exception, save_result, say_hello
|
||||
from tests.fixtures import (
|
||||
div_by_zero,
|
||||
erroneous_callback,
|
||||
long_process,
|
||||
save_exception,
|
||||
save_result,
|
||||
save_result_if_not_stopped,
|
||||
say_hello,
|
||||
)
|
||||
|
||||
|
||||
class QueueCallbackTestCase(RQTestCase):
|
||||
|
@ -44,6 +53,24 @@ class QueueCallbackTestCase(RQTestCase):
|
|||
job = Job.fetch(id=job.id, connection=self.testconn)
|
||||
self.assertEqual(job.failure_callback, print)
|
||||
|
||||
def test_enqueue_with_stopped_callback(self):
|
||||
"""queue.enqueue* methods with on_stopped is persisted correctly"""
|
||||
queue = Queue(connection=self.testconn)
|
||||
|
||||
# Only functions and builtins are supported as callback
|
||||
with self.assertRaises(ValueError):
|
||||
queue.enqueue(say_hello, on_stopped=Job.fetch)
|
||||
|
||||
job = queue.enqueue(long_process, on_stopped=print)
|
||||
|
||||
job = Job.fetch(id=job.id, connection=self.testconn)
|
||||
self.assertEqual(job.stopped_callback, print)
|
||||
|
||||
job = queue.enqueue_in(timedelta(seconds=10), long_process, on_stopped=print)
|
||||
|
||||
job = Job.fetch(id=job.id, connection=self.testconn)
|
||||
self.assertEqual(job.stopped_callback, print)
|
||||
|
||||
|
||||
class SyncJobCallback(RQTestCase):
|
||||
def test_success_callback(self):
|
||||
|
@ -70,6 +97,17 @@ class SyncJobCallback(RQTestCase):
|
|||
self.assertEqual(job.get_status(), JobStatus.FAILED)
|
||||
self.assertFalse(self.testconn.exists('failure_callback:%s' % job.id))
|
||||
|
||||
def test_stopped_callback(self):
|
||||
"""queue.enqueue* methods with on_stopped is persisted correctly"""
|
||||
connection = self.testconn
|
||||
queue = Queue('foo', connection=connection, serializer=JSONSerializer)
|
||||
worker = SimpleWorker('foo', connection=connection, serializer=JSONSerializer)
|
||||
job = queue.enqueue(long_process, on_stopped=save_result_if_not_stopped)
|
||||
job.execute_stopped_callback(
|
||||
worker.death_penalty_class
|
||||
) # Calling execute_stopped_callback directly for coverage
|
||||
self.assertTrue(self.testconn.exists('stopped_callback:%s' % job.id))
|
||||
|
||||
|
||||
class WorkerCallbackTestCase(RQTestCase):
|
||||
def test_success_callback(self):
|
||||
|
@ -159,3 +197,22 @@ class JobCallbackTestCase(RQTestCase):
|
|||
|
||||
job = Job.fetch(id=job.id, connection=self.testconn)
|
||||
self.assertEqual(job.failure_callback, print)
|
||||
|
||||
def test_job_creation_with_stopped_callback(self):
|
||||
"""Ensure stopped callbacks are persisted properly"""
|
||||
job = Job.create(say_hello)
|
||||
self.assertIsNone(job._stopped_callback_name)
|
||||
# _failure_callback starts with UNEVALUATED
|
||||
self.assertEqual(job._stopped_callback, UNEVALUATED)
|
||||
self.assertEqual(job.stopped_callback, None)
|
||||
# _stopped_callback becomes `None` after `job.stopped_callback` is called if there's no stopped callback
|
||||
self.assertEqual(job._stopped_callback, None)
|
||||
|
||||
# job.failure_callback is assigned properly
|
||||
job = Job.create(say_hello, on_stopped=print)
|
||||
self.assertIsNotNone(job._stopped_callback_name)
|
||||
self.assertEqual(job.stopped_callback, print)
|
||||
job.save()
|
||||
|
||||
job = Job.fetch(id=job.id, connection=self.testconn)
|
||||
self.assertEqual(job.stopped_callback, print)
|
||||
|
|
|
@ -223,6 +223,7 @@ class TestJob(RQTestCase):
|
|||
b'worker_name',
|
||||
b'success_callback_name',
|
||||
b'failure_callback_name',
|
||||
b'stopped_callback_name',
|
||||
},
|
||||
set(self.testconn.hkeys(job.key)),
|
||||
)
|
||||
|
@ -260,6 +261,7 @@ class TestJob(RQTestCase):
|
|||
func=fixtures.some_calculation,
|
||||
on_success=Callback(fixtures.say_hello, timeout=10),
|
||||
on_failure=fixtures.say_pid,
|
||||
on_stopped=fixtures.say_hello,
|
||||
) # deprecated callable
|
||||
job.save()
|
||||
stored_job = Job.fetch(job.id)
|
||||
|
@ -267,7 +269,9 @@ class TestJob(RQTestCase):
|
|||
self.assertEqual(fixtures.say_hello, stored_job.success_callback)
|
||||
self.assertEqual(10, stored_job.success_callback_timeout)
|
||||
self.assertEqual(fixtures.say_pid, stored_job.failure_callback)
|
||||
self.assertEqual(fixtures.say_hello, stored_job.stopped_callback)
|
||||
self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout)
|
||||
self.assertEqual(CALLBACK_TIMEOUT, stored_job.stopped_callback_timeout)
|
||||
|
||||
# None(s)
|
||||
job = Job.create(func=fixtures.some_calculation, on_failure=None)
|
||||
|
@ -279,6 +283,8 @@ class TestJob(RQTestCase):
|
|||
self.assertIsNone(stored_job.failure_callback)
|
||||
self.assertEqual(CALLBACK_TIMEOUT, job.failure_callback_timeout) # timeout should be never none
|
||||
self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout)
|
||||
self.assertEqual(CALLBACK_TIMEOUT, job.stopped_callback_timeout) # timeout should be never none
|
||||
self.assertIsNone(stored_job.stopped_callback)
|
||||
|
||||
def test_store_then_fetch(self):
|
||||
"""Store, then fetch."""
|
||||
|
|
Loading…
Reference in New Issue