add: params for all callbacks

This commit is contained in:
Kenneth Goh 2023-10-15 17:30:04 +08:00
parent 56a71ac351
commit f91cd733cc
1 changed files with 18 additions and 4 deletions

View File

@ -265,6 +265,7 @@ class Job:
on_failure = Callback(on_failure) # backward compatibility
job._failure_callback_name = on_failure.name
job._failure_callback_timeout = on_failure.timeout
job._success_callback_params = on_failure.params
if on_stopped:
if not isinstance(on_stopped, Callback):
@ -275,6 +276,7 @@ class Job:
on_stopped = Callback(on_stopped) # backward compatibility
job._stopped_callback_name = on_stopped.name
job._stopped_callback_timeout = on_stopped.timeout
job._success_callback_params = on_stopped.params
# Extra meta data
job.description = description or job.get_call_string()
@ -641,8 +643,10 @@ class Job:
self._success_callback_params = None
self._success_callback = UNEVALUATED
self._failure_callback_name = None
self._failure_callback_params = None
self._failure_callback = UNEVALUATED
self._stopped_callback_name = None
self._stopped_callback_params = None
self._stopped_callback = UNEVALUATED
self.description: Optional[str] = None
self.origin: str = ''
@ -953,12 +957,18 @@ class Job:
if obj.get('failure_callback_name'):
self._failure_callback_name = obj.get('failure_callback_name').decode()
if obj.get('failure_callback_params'):
self._failure_callback_params = json.loads(obj.get('failure_callback_params').decode())
if 'failure_callback_timeout' in obj:
self._failure_callback_timeout = int(obj.get('failure_callback_timeout'))
if obj.get('stopped_callback_name'):
self._stopped_callback_name = obj.get('stopped_callback_name').decode()
if obj.get('stopped_callback_params'):
self._stopped_callback_params = json.loads(obj.get('stopped_callback_params').decode())
if 'stopped_callback_timeout' in obj:
self._stopped_callback_timeout = int(obj.get('stopped_callback_timeout'))
@ -1023,8 +1033,6 @@ class Job:
'worker_name': self.worker_name or '',
}
if self._success_callback_params is not None:
obj['success_callback_params'] = json.dumps(self._success_callback_params)
if self.retries_left is not None:
obj['retries_left'] = self.retries_left
if self.retry_intervals is not None:
@ -1047,10 +1055,16 @@ class Job:
obj['timeout'] = self.timeout
if self._success_callback_timeout is not None:
obj['success_callback_timeout'] = self._success_callback_timeout
if self._success_callback_params is not None:
obj['success_callback_params'] = json.dumps(self._success_callback_params)
if self._failure_callback_timeout is not None:
obj['failure_callback_timeout'] = self._failure_callback_timeout
if self._failure_callback_params is not None:
obj['failure_callback_params'] = json.dumps(self._failure_callback_params)
if self._stopped_callback_timeout is not None:
obj['stopped_callback_timeout'] = self._stopped_callback_timeout
if self._stopped_callback_params is not None:
obj['stopped_callback_params'] = json.dumps(self._stopped_callback_params)
if self.result_ttl is not None:
obj['result_ttl'] = self.result_ttl
if self.failure_ttl is not None:
@ -1435,7 +1449,7 @@ class Job:
logger.debug('Running failure callbacks for %s', self.id)
try:
with death_penalty_class(self.failure_callback_timeout, JobTimeoutException, job_id=self.id):
self.failure_callback(self, self.connection, *exc_info)
self.failure_callback(self, self.connection, self._failure_callback_params, *exc_info)
except Exception: # noqa
logger.exception(f'Job {self.id}: error while executing failure callback')
raise
@ -1445,7 +1459,7 @@ class Job:
logger.debug('Running stopped callbacks for %s', self.id)
try:
with death_penalty_class(self.stopped_callback_timeout, JobTimeoutException, job_id=self.id):
self.stopped_callback(self, self.connection)
self.stopped_callback(self, self.connection self._stopped_callback_params)
except Exception: # noqa
logger.exception(f'Job {self.id}: error while executing stopped callback')
raise