diff --git a/rq/decorators.py b/rq/decorators.py index 41b2da46..e15e3ffe 100644 --- a/rq/decorators.py +++ b/rq/decorators.py @@ -30,6 +30,7 @@ class job: # noqa retry: Optional['Retry'] = None, on_failure: Optional[Callable[..., Any]] = None, on_success: Optional[Callable[..., Any]] = None, + on_stopped: Optional[Callable[..., Any]] = None, ): """A decorator that adds a ``delay`` method to the decorated function, which in turn creates a RQ job when called. Accepts a required @@ -60,6 +61,7 @@ class job: # noqa retry (Optional[Retry], optional): A Retry object. Defaults to None. on_failure (Optional[Callable[..., Any]], optional): Callable to run on failure. Defaults to None. on_success (Optional[Callable[..., Any]], optional): Callable to run on success. Defaults to None. + on_stopped (Optional[Callable[..., Any]], optional): Callable to run when stopped. Defaults to None. """ self.queue = queue self.queue_class = backend_class(self, 'queue_class', override=queue_class) @@ -75,6 +77,7 @@ class job: # noqa self.retry = retry self.on_success = on_success self.on_failure = on_failure + self.on_stopped = on_stopped def __call__(self, f): @wraps(f) @@ -110,6 +113,7 @@ class job: # noqa retry=self.retry, on_failure=self.on_failure, on_success=self.on_success, + on_stopped=self.on_stopped, ) f.delay = delay diff --git a/rq/job.py b/rq/job.py index 17a80790..4fa712d1 100644 --- a/rq/job.py +++ b/rq/job.py @@ -160,6 +160,7 @@ class Job: *, on_success: Optional[Union['Callback', Callable[..., Any]]] = None, on_failure: Optional[Union['Callback', Callable[..., Any]]] = None, + on_stopped: Optional[Union['Callback', Callable[..., Any]]] = None, ) -> 'Job': """Creates a new Job instance for the given function, arguments, and keyword arguments. @@ -196,6 +197,8 @@ class Job: when/if the Job finishes sucessfully. Defaults to None. on_failure (Optional[Callable[..., Any]], optional): A callback function, should be a callable to run when/if the Job fails. Defaults to None. + on_stopped (Optional[Callable[..., Any]], optional): A callback function, should be a callable to run + when/if the Job is stopped. Defaults to None. Raises: TypeError: If `args` is not a tuple/list @@ -203,6 +206,7 @@ class Job: TypeError: If the `func` is something other than a string or a Callable reference ValueError: If `on_failure` is not a function ValueError: If `on_success` is not a function + ValueError: If `on_stopped` is not a function Returns: Job: A job instance. @@ -259,6 +263,15 @@ class Job: job._failure_callback_name = on_failure.name job._failure_callback_timeout = on_failure.timeout + if on_stopped: + if not isinstance(on_stopped, Callback): + warnings.warn( + 'Passing a `Callable` `on_stopped` is deprecated, pass `Callback` instead', DeprecationWarning + ) + on_stopped = Callback(on_stopped) # backward compatibility + job._stopped_callback_name = on_stopped.name + job._stopped_callback_timeout = on_stopped.timeout + # Extra meta data job.description = description or job.get_call_string() job.result_ttl = parse_timeout(result_ttl) @@ -442,6 +455,23 @@ class Job: return self._failure_callback_timeout + @property + def stopped_callback(self): + if self._stopped_callback is UNEVALUATED: + if self._stopped_callback_name: + self._stopped_callback = import_attribute(self._stopped_callback_name) + else: + self._stopped_callback = None + + return self._stopped_callback + + @property + def stopped_callback_timeout(self) -> int: + if self._stopped_callback_timeout is None: + return CALLBACK_TIMEOUT + + return self._stopped_callback_timeout + def _deserialize_data(self): """Deserializes the Job `data` into a tuple. This includes the `_func_name`, `_instance`, `_args` and `_kwargs` @@ -607,6 +637,8 @@ class Job: self._success_callback = UNEVALUATED self._failure_callback_name = None self._failure_callback = UNEVALUATED + self._stopped_callback_name = None + self._stopped_callback = UNEVALUATED self.description: Optional[str] = None self.origin: str = '' self.enqueued_at: Optional[datetime] = None @@ -617,6 +649,7 @@ class Job: self.timeout: Optional[float] = None self._success_callback_timeout: Optional[int] = None self._failure_callback_timeout: Optional[int] = None + self._stopped_callback_timeout: Optional[int] = None self.result_ttl: Optional[int] = None self.failure_ttl: Optional[int] = None self.ttl: Optional[int] = None @@ -913,6 +946,12 @@ class Job: if 'failure_callback_timeout' in obj: self._failure_callback_timeout = int(obj.get('failure_callback_timeout')) + if obj.get('stopped_callback_name'): + self._stopped_callback_name = obj.get('stopped_callback_name').decode() + + if 'stopped_callback_timeout' in obj: + self._stopped_callback_timeout = int(obj.get('stopped_callback_timeout')) + dep_ids = obj.get('dependency_ids') dep_id = obj.get('dependency_id') # for backwards compatibility self._dependency_ids = json.loads(dep_ids.decode()) if dep_ids else [dep_id.decode()] if dep_id else [] @@ -967,6 +1006,7 @@ class Job: 'data': zlib.compress(self.data), 'success_callback_name': self._success_callback_name if self._success_callback_name else '', 'failure_callback_name': self._failure_callback_name if self._failure_callback_name else '', + 'stopped_callback_name': self._stopped_callback_name if self._stopped_callback_name else '', 'started_at': utcformat(self.started_at) if self.started_at else '', 'ended_at': utcformat(self.ended_at) if self.ended_at else '', 'last_heartbeat': utcformat(self.last_heartbeat) if self.last_heartbeat else '', @@ -997,6 +1037,8 @@ class Job: obj['success_callback_timeout'] = self._success_callback_timeout if self._failure_callback_timeout is not None: obj['failure_callback_timeout'] = self._failure_callback_timeout + if self._stopped_callback_timeout is not None: + obj['stopped_callback_timeout'] = self._stopped_callback_timeout if self.result_ttl is not None: obj['result_ttl'] = self.result_ttl if self.failure_ttl is not None: @@ -1386,6 +1428,16 @@ class Job: logger.exception(f'Job {self.id}: error while executing failure callback') raise + def execute_stopped_callback(self, death_penalty_class: Type[BaseDeathPenalty]): + """Executes stopped_callback with possible timeout""" + logger.debug('Running stopped callbacks for %s', self.id) + try: + with death_penalty_class(self.stopped_callback_timeout, JobTimeoutException, job_id=self.id): + self.stopped_callback(self, self.connection) + except Exception: # noqa + logger.exception(f'Job {self.id}: error while executing stopped callback') + raise + def _handle_success(self, result_ttl: int, pipeline: 'Pipeline'): """Saves and cleanup job after successful execution""" # self.log.debug('Setting job %s status to finished', job.id) diff --git a/rq/queue.py b/rq/queue.py index 1709bae4..d415f84a 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -50,6 +50,7 @@ class EnqueueData( "retry", "on_success", "on_failure", + "on_stopped", ], ) ): @@ -516,6 +517,7 @@ class Queue: *, on_success: Optional[Callable] = None, on_failure: Optional[Callable] = None, + on_stopped: Optional[Callable] = None, ) -> Job: """Creates a job based on parameters given @@ -535,6 +537,7 @@ class Queue: retry (Optional[Retry], optional): The Retry Object. Defaults to None. on_success (Optional[Callable], optional): On success callable. Defaults to None. on_failure (Optional[Callable], optional): On failure callable. Defaults to None. + on_stopped (Optional[Callable], optional): On stopped callable. Defaults to None. Raises: ValueError: If the timeout is 0 @@ -575,6 +578,7 @@ class Queue: serializer=self.serializer, on_success=on_success, on_failure=on_failure, + on_stopped=on_stopped, ) if retry: @@ -657,6 +661,7 @@ class Queue: retry: Optional['Retry'] = None, on_success: Optional[Callable[..., Any]] = None, on_failure: Optional[Callable[..., Any]] = None, + on_stopped: Optional[Callable[..., Any]] = None, pipeline: Optional['Pipeline'] = None, ) -> Job: """Creates a job to represent the delayed function call and enqueues it. @@ -681,6 +686,7 @@ class Queue: retry (Optional[Retry], optional): Retry object. Defaults to None. on_success (Optional[Callable[..., Any]], optional): Callable for on success. Defaults to None. on_failure (Optional[Callable[..., Any]], optional): Callable for on failure. Defaults to None. + on_stopped (Optional[Callable[..., Any]], optional): Callable for on stopped. Defaults to None. pipeline (Optional[Pipeline], optional): The Redis Pipeline. Defaults to None. Returns: @@ -703,6 +709,7 @@ class Queue: retry=retry, on_success=on_success, on_failure=on_failure, + on_stopped=on_stopped, ) return self.enqueue_job(job, pipeline=pipeline, at_front=at_front) @@ -723,6 +730,7 @@ class Queue: retry: Optional['Retry'] = None, on_success: Optional[Callable] = None, on_failure: Optional[Callable] = None, + on_stopped: Optional[Callable] = None, ) -> EnqueueData: """Need this till support dropped for python_version < 3.7, where defaults can be specified for named tuples And can keep this logic within EnqueueData @@ -743,6 +751,7 @@ class Queue: retry (Optional[Retry], optional): Retry object. Defaults to None. on_success (Optional[Callable[..., Any]], optional): Callable for on success. Defaults to None. on_failure (Optional[Callable[..., Any]], optional): Callable for on failure. Defaults to None. + on_stopped (Optional[Callable[..., Any]], optional): Callable for on stopped. Defaults to None. Returns: EnqueueData: The EnqueueData @@ -763,6 +772,7 @@ class Queue: retry, on_success, on_failure, + on_stopped, ) def enqueue_many(self, job_datas: List['EnqueueData'], pipeline: Optional['Pipeline'] = None) -> List[Job]: @@ -889,6 +899,7 @@ class Queue: retry = kwargs.pop('retry', None) on_success = kwargs.pop('on_success', None) on_failure = kwargs.pop('on_failure', None) + on_stopped = kwargs.pop('on_stopped', None) pipeline = kwargs.pop('pipeline', None) if 'args' in kwargs or 'kwargs' in kwargs: @@ -910,6 +921,7 @@ class Queue: retry, on_success, on_failure, + on_stopped, pipeline, args, kwargs, @@ -941,6 +953,7 @@ class Queue: retry, on_success, on_failure, + on_stopped, pipeline, args, kwargs, @@ -962,6 +975,7 @@ class Queue: retry=retry, on_success=on_success, on_failure=on_failure, + on_stopped=on_stopped, pipeline=pipeline, ) @@ -989,6 +1003,7 @@ class Queue: retry, on_success, on_failure, + on_stopped, pipeline, args, kwargs, @@ -1009,6 +1024,7 @@ class Queue: retry=retry, on_success=on_success, on_failure=on_failure, + on_stopped=on_stopped, ) if at_front: job.enqueue_at_front = True diff --git a/rq/worker.py b/rq/worker.py index a5572780..4c38adfb 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -1191,6 +1191,8 @@ class Worker(BaseWorker): elif self._stopped_job_id == job.id: # Work-horse killed deliberately self.log.warning('Job stopped by user, moving job to FailedJobRegistry') + if job.stopped_callback: + job.execute_stopped_callback(self.death_penalty_class) self.handle_job_failure(job, queue=queue, exc_string='Job stopped by user, work-horse terminated.') elif job_status not in [JobStatus.FINISHED, JobStatus.FAILED]: if not job.ended_at: diff --git a/tests/fixtures.py b/tests/fixtures.py index 62ea8e1f..f0831ee6 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -59,6 +59,11 @@ def div_by_zero(x): return x / 0 +def long_process(): + time.sleep(60) + return + + def some_calculation(x, y, z=1): """Some arbitrary calculation with three numbers. Choose z smartly if you want a division by zero exception. @@ -287,6 +292,10 @@ def save_exception(job, connection, type, value, traceback): connection.set('failure_callback:%s' % job.id, str(value), ex=60) +def save_result_if_not_stopped(job, connection, result=""): + connection.set('stopped_callback:%s' % job.id, result, ex=60) + + def erroneous_callback(job): """A callback that's not written properly""" pass diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index c47ad84c..03cef246 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,10 +1,19 @@ from datetime import timedelta from rq import Queue, Worker -from rq.job import UNEVALUATED, Job, JobStatus +from rq.job import Job, JobStatus, UNEVALUATED +from rq.serializers import JSONSerializer from rq.worker import SimpleWorker from tests import RQTestCase -from tests.fixtures import div_by_zero, erroneous_callback, save_exception, save_result, say_hello +from tests.fixtures import ( + div_by_zero, + erroneous_callback, + long_process, + save_exception, + save_result, + save_result_if_not_stopped, + say_hello, +) class QueueCallbackTestCase(RQTestCase): @@ -44,6 +53,24 @@ class QueueCallbackTestCase(RQTestCase): job = Job.fetch(id=job.id, connection=self.testconn) self.assertEqual(job.failure_callback, print) + def test_enqueue_with_stopped_callback(self): + """queue.enqueue* methods with on_stopped is persisted correctly""" + queue = Queue(connection=self.testconn) + + # Only functions and builtins are supported as callback + with self.assertRaises(ValueError): + queue.enqueue(say_hello, on_stopped=Job.fetch) + + job = queue.enqueue(long_process, on_stopped=print) + + job = Job.fetch(id=job.id, connection=self.testconn) + self.assertEqual(job.stopped_callback, print) + + job = queue.enqueue_in(timedelta(seconds=10), long_process, on_stopped=print) + + job = Job.fetch(id=job.id, connection=self.testconn) + self.assertEqual(job.stopped_callback, print) + class SyncJobCallback(RQTestCase): def test_success_callback(self): @@ -70,6 +97,17 @@ class SyncJobCallback(RQTestCase): self.assertEqual(job.get_status(), JobStatus.FAILED) self.assertFalse(self.testconn.exists('failure_callback:%s' % job.id)) + def test_stopped_callback(self): + """queue.enqueue* methods with on_stopped is persisted correctly""" + connection = self.testconn + queue = Queue('foo', connection=connection, serializer=JSONSerializer) + worker = SimpleWorker('foo', connection=connection, serializer=JSONSerializer) + job = queue.enqueue(long_process, on_stopped=save_result_if_not_stopped) + job.execute_stopped_callback( + worker.death_penalty_class + ) # Calling execute_stopped_callback directly for coverage + self.assertTrue(self.testconn.exists('stopped_callback:%s' % job.id)) + class WorkerCallbackTestCase(RQTestCase): def test_success_callback(self): @@ -159,3 +197,22 @@ class JobCallbackTestCase(RQTestCase): job = Job.fetch(id=job.id, connection=self.testconn) self.assertEqual(job.failure_callback, print) + + def test_job_creation_with_stopped_callback(self): + """Ensure stopped callbacks are persisted properly""" + job = Job.create(say_hello) + self.assertIsNone(job._stopped_callback_name) + # _failure_callback starts with UNEVALUATED + self.assertEqual(job._stopped_callback, UNEVALUATED) + self.assertEqual(job.stopped_callback, None) + # _stopped_callback becomes `None` after `job.stopped_callback` is called if there's no stopped callback + self.assertEqual(job._stopped_callback, None) + + # job.failure_callback is assigned properly + job = Job.create(say_hello, on_stopped=print) + self.assertIsNotNone(job._stopped_callback_name) + self.assertEqual(job.stopped_callback, print) + job.save() + + job = Job.fetch(id=job.id, connection=self.testconn) + self.assertEqual(job.stopped_callback, print) diff --git a/tests/test_job.py b/tests/test_job.py index f855fa89..29c309f2 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -223,6 +223,7 @@ class TestJob(RQTestCase): b'worker_name', b'success_callback_name', b'failure_callback_name', + b'stopped_callback_name', }, set(self.testconn.hkeys(job.key)), ) @@ -260,6 +261,7 @@ class TestJob(RQTestCase): func=fixtures.some_calculation, on_success=Callback(fixtures.say_hello, timeout=10), on_failure=fixtures.say_pid, + on_stopped=fixtures.say_hello, ) # deprecated callable job.save() stored_job = Job.fetch(job.id) @@ -267,7 +269,9 @@ class TestJob(RQTestCase): self.assertEqual(fixtures.say_hello, stored_job.success_callback) self.assertEqual(10, stored_job.success_callback_timeout) self.assertEqual(fixtures.say_pid, stored_job.failure_callback) + self.assertEqual(fixtures.say_hello, stored_job.stopped_callback) self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout) + self.assertEqual(CALLBACK_TIMEOUT, stored_job.stopped_callback_timeout) # None(s) job = Job.create(func=fixtures.some_calculation, on_failure=None) @@ -279,6 +283,8 @@ class TestJob(RQTestCase): self.assertIsNone(stored_job.failure_callback) self.assertEqual(CALLBACK_TIMEOUT, job.failure_callback_timeout) # timeout should be never none self.assertEqual(CALLBACK_TIMEOUT, stored_job.failure_callback_timeout) + self.assertEqual(CALLBACK_TIMEOUT, job.stopped_callback_timeout) # timeout should be never none + self.assertIsNone(stored_job.stopped_callback) def test_store_then_fetch(self): """Store, then fetch."""