Update Job.get_status and Job.restore to consistently return JobStatus Enum (#2039)

* fix #2038

* raise exception if there is no status in redis

* add get_status failure test
This commit is contained in:
Joe Carey 2024-05-25 20:37:55 -06:00 committed by GitHub
parent fc4884a0f2
commit 180c9afba0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 6 deletions

View File

@ -330,12 +330,17 @@ class Job:
Args:
refresh (bool, optional): Whether to refresh the Job. Defaults to True.
Raises:
InvalidJobOperation: If refreshing and nothing is returned from the `HGET` operation.
Returns:
status (JobStatus): The Job Status
"""
if refresh:
status = self.connection.hget(self.key, 'status')
self._status = as_text(status) if status else None
if not status:
raise InvalidJobOperation(f"Failed to retrieve status for job: {self.id}")
self._status = JobStatus(as_text(status))
return self._status
def set_status(self, status: JobStatus, pipeline: Optional['Pipeline'] = None) -> None:
@ -950,7 +955,7 @@ class Job:
self.timeout = parse_timeout(obj.get('timeout')) if obj.get('timeout') else None
self.result_ttl = int(obj.get('result_ttl')) if obj.get('result_ttl') else None
self.failure_ttl = int(obj.get('failure_ttl')) if obj.get('failure_ttl') else None
self._status = obj.get('status').decode() if obj.get('status') else None
self._status = JobStatus(as_text(obj.get('status'))) if obj.get('status') else None
if obj.get('success_callback_name'):
self._success_callback_name = obj.get('success_callback_name').decode()

View File

@ -48,7 +48,7 @@ def as_text(v: Union[bytes, str]) -> str:
ValueError: If the value is not bytes or string
Returns:
value (Optional[str]): Either the decoded string or None
value (str): The decoded string
"""
if isinstance(v, bytes):
return v.decode('utf-8')

View File

@ -554,6 +554,12 @@ class TestJob(RQTestCase):
self.assertIsNotNone(job.last_heartbeat)
self.assertIsNotNone(job.started_at)
def test_unset_job_status_fails(self):
"""None is an invalid status for Job."""
job = Job.create(func=fixtures.say_hello, connection=self.connection)
job.save()
self.assertRaises(InvalidJobOperation, job.get_status)
def test_job_access_outside_job_fails(self):
"""The current job is accessible only within a job context."""
self.assertIsNone(get_current_job())
@ -625,7 +631,7 @@ class TestJob(RQTestCase):
def test_cleanup(self):
"""Test that jobs and results are expired properly."""
job = Job.create(func=fixtures.say_hello, connection=self.connection)
job = Job.create(func=fixtures.say_hello, connection=self.connection, status=JobStatus.QUEUED)
job.save()
# Jobs with negative TTLs don't expire
@ -837,7 +843,11 @@ class TestJob(RQTestCase):
queue = Queue(connection=self.connection, serializer=JSONSerializer)
job = queue.enqueue(fixtures.say_hello)
job2 = Job.create(
func=fixtures.say_hello, depends_on=job, serializer=JSONSerializer, connection=self.connection
func=fixtures.say_hello,
depends_on=job,
serializer=JSONSerializer,
connection=self.connection,
status=JobStatus.QUEUED,
)
job2.register_dependency()
job2.save()
@ -866,7 +876,11 @@ class TestJob(RQTestCase):
queue = Queue(connection=self.connection, serializer=JSONSerializer)
dependency_job = queue.enqueue(fixtures.say_hello)
dependent_job = Job.create(
func=fixtures.say_hello, depends_on=dependency_job, serializer=JSONSerializer, connection=self.connection
func=fixtures.say_hello,
depends_on=dependency_job,
serializer=JSONSerializer,
connection=self.connection,
status=JobStatus.QUEUED,
)
dependent_job.register_dependency()

View File

@ -6,6 +6,7 @@ from redis import Redis
from rq.exceptions import TimeoutFormatError
from rq.utils import (
as_text,
backend_class,
ceildiv,
ensure_list,
@ -57,6 +58,15 @@ class TestUtils(RQTestCase):
self.assertEqual(True, is_nonstring_iterable({}))
self.assertEqual(True, is_nonstring_iterable(()))
def test_as_text(self):
"""Ensure function as_text works correctly"""
bad_texts = [3, None, 'test\xd0']
self.assertEqual('test', as_text(b'test'))
self.assertEqual('test', as_text('test'))
with self.assertRaises(ValueError):
for text in bad_texts:
as_text(text)
def test_ensure_list(self):
"""Ensure function ensure_list works correctly"""
self.assertEqual([], ensure_list([]))