mirror of https://github.com/rq/rq.git
190 lines
7.5 KiB
Python
190 lines
7.5 KiB
Python
import datetime
|
|
import re
|
|
from unittest.mock import Mock
|
|
|
|
from redis import Redis
|
|
|
|
from rq.exceptions import TimeoutFormatError
|
|
from rq.utils import (
|
|
as_text,
|
|
backend_class,
|
|
ceildiv,
|
|
ensure_list,
|
|
first,
|
|
get_call_string,
|
|
get_version,
|
|
import_attribute,
|
|
is_nonstring_iterable,
|
|
parse_timeout,
|
|
split_list,
|
|
truncate_long_string,
|
|
utcparse,
|
|
)
|
|
from rq.worker import SimpleWorker
|
|
from tests import RQTestCase, fixtures
|
|
|
|
|
|
class TestUtils(RQTestCase):
    """Unit tests for the helper functions imported from ``rq.utils``."""

    def test_parse_timeout(self):
        """Ensure function parse_timeout works correctly"""
        self.assertEqual(12, parse_timeout(12))
        self.assertEqual(12, parse_timeout('12'))
        self.assertEqual(12, parse_timeout('12s'))
        self.assertEqual(720, parse_timeout('12m'))
        self.assertEqual(3600, parse_timeout('1h'))
        self.assertEqual(3600, parse_timeout('1H'))

    def test_parse_timeout_coverage_scenarios(self):
        """Test parse_timeout edge cases for coverage"""
        timeouts = ['h12', 'h', 'm', 's', '10k']

        self.assertIsNone(parse_timeout(None))
        # Open one assertRaises context per value. With the loop inside a single
        # context, the first exception ends the block and the remaining values
        # would never be checked.
        for timeout in timeouts:
            with self.assertRaises(TimeoutFormatError):
                parse_timeout(timeout)

    def test_first(self):
        """Ensure function first works correctly"""
        self.assertEqual(42, first([0, False, None, [], (), 42]))
        self.assertEqual(None, first([0, False, None, [], ()]))
        self.assertEqual('ohai', first([0, False, None, [], ()], default='ohai'))
        self.assertEqual('bc', first(re.match(regex, 'abc') for regex in ['b.*', 'a(.*)']).group(1))
        self.assertEqual(4, first([1, 1, 3, 4, 5], key=lambda x: x % 2 == 0))

    def test_is_nonstring_iterable(self):
        """Ensure function is_nonstring_iterable works correctly"""
        self.assertEqual(True, is_nonstring_iterable([]))
        self.assertEqual(False, is_nonstring_iterable('test'))
        self.assertEqual(True, is_nonstring_iterable({}))
        self.assertEqual(True, is_nonstring_iterable(()))

    def test_as_text(self):
        """Ensure function as_text works correctly"""
        bad_texts = [3, None, 'test\xd0']
        self.assertEqual('test', as_text(b'test'))
        self.assertEqual('test', as_text('test'))
        # NOTE(review): the loop sits inside a single assertRaises context, so the
        # context exits on the first ValueError and only bad_texts[0] is exercised.
        # Left as-is: as_text('test') above returns its str input, so the str entry
        # presumably would not raise if each value were checked on its own. Confirm
        # the intended contract before splitting this into per-value checks.
        with self.assertRaises(ValueError):
            for text in bad_texts:
                as_text(text)

    def test_ensure_list(self):
        """Ensure function ensure_list works correctly"""
        self.assertEqual([], ensure_list([]))
        self.assertEqual(['test'], ensure_list('test'))
        self.assertEqual({}, ensure_list({}))
        self.assertEqual((), ensure_list(()))

    def test_utcparse(self):
        """Ensure function utcparse parses a timestamp with microseconds"""
        utc_formatted_time = '2017-08-31T10:14:02.123456Z'
        self.assertEqual(datetime.datetime(2017, 8, 31, 10, 14, 2, 123456), utcparse(utc_formatted_time))

    def test_utcparse_legacy(self):
        """Ensure function utcparse parses a timestamp without microseconds"""
        utc_formatted_time = '2017-08-31T10:14:02Z'
        self.assertEqual(datetime.datetime(2017, 8, 31, 10, 14, 2), utcparse(utc_formatted_time))

    def test_backend_class(self):
        """Ensure function backend_class works correctly"""
        self.assertEqual(fixtures.DummyQueue, backend_class(fixtures, 'DummyQueue'))
        self.assertNotEqual(fixtures.say_pid, backend_class(fixtures, 'DummyQueue'))
        # The override may be given as the class itself or as a dotted path.
        self.assertEqual(fixtures.DummyQueue, backend_class(fixtures, 'DummyQueue', override=fixtures.DummyQueue))
        self.assertEqual(
            fixtures.DummyQueue, backend_class(fixtures, 'DummyQueue', override='tests.fixtures.DummyQueue')
        )

    def test_get_redis_version(self):
        """Ensure get_version works properly"""
        redis = Redis()
        self.assertIsInstance(get_version(redis), tuple)

        # Each DummyRedis below overrides info() to report a fixed version;
        # *args absorbs self and any positional arguments passed to info().

        # Parses 3 digit version numbers correctly
        class DummyRedis(Redis):
            def info(*args):
                return {'redis_version': '4.0.8'}

        self.assertEqual(get_version(DummyRedis()), (4, 0, 8))

        # Parses 4 digit version numbers correctly (only the first three parts are expected)
        class DummyRedis(Redis):
            def info(*args):
                return {'redis_version': '3.0.7.9'}

        self.assertEqual(get_version(DummyRedis()), (3, 0, 7))

        # Parses 2 digit version numbers correctly (Seen in AWS ElastiCache Redis)
        class DummyRedis(Redis):
            def info(*args):
                return {'redis_version': '7.1'}

        self.assertEqual(get_version(DummyRedis()), (7, 1, 0))

        # Parses 2 digit float version numbers correctly (Seen in AWS ElastiCache Redis)
        class DummyRedis(Redis):
            def info(*args):
                return {'redis_version': 7.1}

        self.assertEqual(get_version(DummyRedis()), (7, 1, 0))

    def test_get_redis_version_gets_cached(self):
        """Ensure repeated get_version calls on one connection call info() only once"""
        redis = Mock(spec=['info'])
        redis.info = Mock(return_value={'redis_version': '4.0.8'})
        self.assertEqual(get_version(redis), (4, 0, 8))
        self.assertEqual(get_version(redis), (4, 0, 8))
        # Two lookups above, but info() must have been called exactly once.
        redis.info.assert_called_once()

    def test_import_attribute(self):
        """Ensure import_attribute resolves dotted paths and rejects unknown ones"""
        self.assertEqual(import_attribute('rq.utils.get_version'), get_version)
        self.assertEqual(import_attribute('rq.worker.SimpleWorker'), SimpleWorker)
        self.assertRaises(ValueError, import_attribute, 'non.existent.module')
        self.assertRaises(ValueError, import_attribute, 'rq.worker.WrongWorker')

    def test_ceildiv_even(self):
        """When a number is evenly divisible by another ceildiv returns the quotient"""
        dividend = 12
        divisor = 4
        self.assertEqual(ceildiv(dividend, divisor), dividend // divisor)

    def test_ceildiv_uneven(self):
        """When a number is not evenly divisible by another ceildiv returns the quotient plus one"""
        dividend = 13
        divisor = 4
        self.assertEqual(ceildiv(dividend, divisor), dividend // divisor + 1)

    def test_split_list(self):
        """Ensure split_list works properly"""
        BIG_LIST_SIZE = 42
        SEGMENT_SIZE = 5

        big_list = ['1'] * BIG_LIST_SIZE
        small_lists = list(split_list(big_list, SEGMENT_SIZE))

        # 42 items in segments of 5 means a partial final segment, hence ceildiv.
        expected_small_list_count = ceildiv(BIG_LIST_SIZE, SEGMENT_SIZE)
        self.assertEqual(len(small_lists), expected_small_list_count)

    def test_truncate_long_string(self):
        """Ensure truncate_long_string works properly"""
        # self.assertEqual instead of bare assert: bare asserts are stripped
        # under ``python -O`` and give poorer failure messages.
        self.assertEqual(truncate_long_string("12", max_length=3), "12")
        self.assertEqual(truncate_long_string("123", max_length=3), "123")
        self.assertEqual(truncate_long_string("1234", max_length=3), "123...")
        self.assertEqual(truncate_long_string("12345", max_length=3), "123...")

        s = "long string but no max_length provided so no truncating should occur" * 10
        self.assertEqual(truncate_long_string(s), s)

    def test_get_call_string(self):
        """Ensure a case, when func_name, args and kwargs are not None, works properly"""
        cs = get_call_string("f", ('some', 'args', 42), {"key1": "value1", "key2": True})
        self.assertEqual(cs, "f('some', 'args', 42, key1='value1', key2=True)")

    def test_get_call_string_with_max_length(self):
        """Ensure get_call_string works properly when max_length is provided"""
        func_name = "f"
        args = (1234, 12345, 123456)
        kwargs = {"len4": 1234, "len5": 12345, "len6": 123456}
        cs = get_call_string(func_name, args, kwargs, max_length=5)
        # Only the 6-character values exceed max_length=5 and are expected truncated.
        self.assertEqual(cs, "f(1234, 12345, 12345..., len4=1234, len5=12345, len6=12345...)")
|