231 lines
9.1 KiB
Python
231 lines
9.1 KiB
Python
import multiprocessing
|
|
import pickle
|
|
import queue
|
|
import time
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
import requests_mock
|
|
|
|
from lightning_app import LightningFlow
|
|
from lightning_app.core import queues
|
|
from lightning_app.core.constants import HTTP_QUEUE_URL
|
|
from lightning_app.core.queues import BaseQueue, QueuingSystem, READINESS_QUEUE_CONSTANT, RedisQueue
|
|
from lightning_app.utilities.imports import _is_redis_available
|
|
from lightning_app.utilities.redis import check_if_redis_running
|
|
|
|
|
|
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
|
|
@pytest.mark.parametrize("queue_type", [QueuingSystem.REDIS, QueuingSystem.MULTIPROCESS])
|
|
def test_queue_api(queue_type, monkeypatch):
|
|
"""Test the Queue API.
|
|
|
|
This test run all the Queue implementation but we monkeypatch the Redis Queues to avoid external interaction
|
|
"""
|
|
import redis
|
|
|
|
blpop_out = (b"entry-id", pickle.dumps("test_entry"))
|
|
|
|
monkeypatch.setattr(redis.Redis, "blpop", lambda *args, **kwargs: blpop_out)
|
|
monkeypatch.setattr(redis.Redis, "rpush", lambda *args, **kwargs: None)
|
|
monkeypatch.setattr(redis.Redis, "set", lambda *args, **kwargs: None)
|
|
monkeypatch.setattr(redis.Redis, "get", lambda *args, **kwargs: None)
|
|
|
|
test_queue = queue_type.get_readiness_queue()
|
|
assert test_queue.name == READINESS_QUEUE_CONSTANT
|
|
assert isinstance(test_queue, BaseQueue)
|
|
test_queue.put("test_entry")
|
|
assert test_queue.get() == "test_entry"
|
|
|
|
|
|
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
|
|
def test_redis_queue():
|
|
queue_id = int(time.time())
|
|
queue1 = QueuingSystem.REDIS.get_readiness_queue(queue_id=str(queue_id))
|
|
queue2 = QueuingSystem.REDIS.get_readiness_queue(queue_id=str(queue_id + 1))
|
|
queue1.put("test_entry1")
|
|
queue2.put("test_entry2")
|
|
assert queue1.get() == "test_entry1"
|
|
assert queue2.get() == "test_entry2"
|
|
with pytest.raises(queue.Empty):
|
|
queue2.get(timeout=1)
|
|
queue1.put("test_entry1")
|
|
assert queue1.length() == 1
|
|
queue1.clear()
|
|
with pytest.raises(queue.Empty):
|
|
queue1.get(timeout=1)
|
|
|
|
|
|
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
|
|
def test_redis_health_check_success():
|
|
redis_queue = QueuingSystem.REDIS.get_readiness_queue()
|
|
assert redis_queue.is_running
|
|
|
|
redis_queue = RedisQueue(name="test_queue", default_timeout=1)
|
|
assert redis_queue.is_running
|
|
|
|
|
|
@pytest.mark.skipif(not _is_redis_available(), reason="redis is required for this test.")
|
|
@pytest.mark.skipif(check_if_redis_running(), reason="This is testing the failure case when redis is not running")
|
|
def test_redis_health_check_failure():
|
|
redis_queue = RedisQueue(name="test_queue", default_timeout=1)
|
|
assert not redis_queue.is_running
|
|
|
|
|
|
@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
|
|
def test_redis_credential(monkeypatch):
|
|
monkeypatch.setattr(queues, "REDIS_HOST", "test-host")
|
|
monkeypatch.setattr(queues, "REDIS_PORT", "test-port")
|
|
monkeypatch.setattr(queues, "REDIS_PASSWORD", "test-password")
|
|
redis_queue = QueuingSystem.REDIS.get_readiness_queue()
|
|
assert redis_queue.redis.connection_pool.connection_kwargs["host"] == "test-host"
|
|
assert redis_queue.redis.connection_pool.connection_kwargs["port"] == "test-port"
|
|
assert redis_queue.redis.connection_pool.connection_kwargs["password"] == "test-password"
|
|
|
|
|
|
@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
|
|
@mock.patch("lightning_app.core.queues.redis.Redis")
|
|
def test_redis_queue_read_timeout(redis_mock):
|
|
redis_mock.return_value.blpop.return_value = (b"READINESS_QUEUE", pickle.dumps("test_entry"))
|
|
redis_queue = QueuingSystem.REDIS.get_readiness_queue()
|
|
|
|
# default timeout
|
|
assert redis_queue.get(timeout=0) == "test_entry"
|
|
assert redis_mock.return_value.blpop.call_args_list[0] == mock.call(["READINESS_QUEUE"], timeout=0.005)
|
|
|
|
# custom timeout
|
|
assert redis_queue.get(timeout=2) == "test_entry"
|
|
assert redis_mock.return_value.blpop.call_args_list[1] == mock.call(["READINESS_QUEUE"], timeout=2)
|
|
|
|
# blocking timeout
|
|
assert redis_queue.get() == "test_entry"
|
|
assert redis_mock.return_value.blpop.call_args_list[2] == mock.call(["READINESS_QUEUE"], timeout=0)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"queue_type, queue_process_mock",
|
|
[(QueuingSystem.MULTIPROCESS, multiprocessing)],
|
|
)
|
|
def test_process_queue_read_timeout(queue_type, queue_process_mock, monkeypatch):
|
|
|
|
context = mock.MagicMock()
|
|
queue_mocked = mock.MagicMock()
|
|
context.Queue = queue_mocked
|
|
monkeypatch.setattr(queue_process_mock, "get_context", mock.MagicMock(return_value=context))
|
|
my_queue = queue_type.get_readiness_queue()
|
|
|
|
# default timeout
|
|
my_queue.get(timeout=0)
|
|
assert queue_mocked.return_value.get.call_args_list[0] == mock.call(timeout=0.001, block=False)
|
|
|
|
# custom timeout
|
|
my_queue.get(timeout=2)
|
|
assert queue_mocked.return_value.get.call_args_list[1] == mock.call(timeout=2, block=False)
|
|
|
|
# blocking timeout
|
|
my_queue.get()
|
|
assert queue_mocked.return_value.get.call_args_list[2] == mock.call(timeout=None, block=True)
|
|
|
|
|
|
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
|
|
@mock.patch("lightning_app.core.queues.WARNING_QUEUE_SIZE", 2)
|
|
def test_redis_queue_warning():
|
|
my_queue = QueuingSystem.REDIS.get_api_delta_queue(queue_id="test_redis_queue_warning")
|
|
my_queue.clear()
|
|
with pytest.warns(UserWarning, match="is larger than the"):
|
|
my_queue.put(None)
|
|
my_queue.put(None)
|
|
my_queue.put(None)
|
|
|
|
|
|
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
|
|
@mock.patch("lightning_app.core.queues.redis.Redis")
|
|
def test_redis_raises_error_if_failing(redis_mock):
|
|
import redis
|
|
|
|
my_queue = QueuingSystem.REDIS.get_api_delta_queue(queue_id="test_redis_queue_warning")
|
|
redis_mock.return_value.rpush.side_effect = redis.exceptions.ConnectionError("EROOOR")
|
|
redis_mock.return_value.llen.side_effect = redis.exceptions.ConnectionError("EROOOR")
|
|
|
|
with pytest.raises(ConnectionError, match="Your app failed because it couldn't connect to Redis."):
|
|
redis_mock.return_value.blpop.side_effect = redis.exceptions.ConnectionError("EROOOR")
|
|
my_queue.get()
|
|
|
|
with pytest.raises(ConnectionError, match="Your app failed because it couldn't connect to Redis."):
|
|
redis_mock.return_value.rpush.side_effect = redis.exceptions.ConnectionError("EROOOR")
|
|
redis_mock.return_value.llen.return_value = 1
|
|
my_queue.put(1)
|
|
|
|
with pytest.raises(ConnectionError, match="Your app failed because it couldn't connect to Redis."):
|
|
redis_mock.return_value.llen.side_effect = redis.exceptions.ConnectionError("EROOOR")
|
|
my_queue.length()
|
|
|
|
|
|
class TestHTTPQueue:
|
|
def test_http_queue_failure_on_queue_name(self):
|
|
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test")
|
|
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
|
|
test_queue.put("test")
|
|
|
|
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
|
|
test_queue.get()
|
|
|
|
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
|
|
test_queue.length()
|
|
|
|
def test_http_queue_put(self, monkeypatch):
|
|
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
|
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
|
test_obj = LightningFlow()
|
|
|
|
# mocking requests and responses
|
|
adapter = requests_mock.Adapter()
|
|
test_queue.client.session.mount("http://", adapter)
|
|
adapter.register_uri(
|
|
"GET",
|
|
f"{HTTP_QUEUE_URL}/v1/test/http_queue/length",
|
|
request_headers={"Authorization": "Bearer test-token"},
|
|
status_code=200,
|
|
content=b"1",
|
|
)
|
|
adapter.register_uri(
|
|
"POST",
|
|
f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=push",
|
|
status_code=201,
|
|
additional_matcher=lambda req: pickle.dumps(test_obj) == req._request.body,
|
|
request_headers={"Authorization": "Bearer test-token"},
|
|
content=b"data pushed",
|
|
)
|
|
|
|
test_queue.put(test_obj)
|
|
|
|
def test_http_queue_get(self, monkeypatch):
|
|
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
|
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
|
|
|
adapter = requests_mock.Adapter()
|
|
test_queue.client.session.mount("http://", adapter)
|
|
|
|
adapter.register_uri(
|
|
"POST",
|
|
f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=pop",
|
|
request_headers={"Authorization": "Bearer test-token"},
|
|
status_code=200,
|
|
content=pickle.dumps("test"),
|
|
)
|
|
assert test_queue.get() == "test"
|
|
|
|
|
|
def test_unreachable_queue(monkeypatch):
|
|
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
|
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
|
|
|
resp = mock.MagicMock()
|
|
resp.status_code = 204
|
|
|
|
test_queue.client = mock.MagicMock()
|
|
test_queue.client.post.return_value = resp
|
|
|
|
with pytest.raises(queue.Empty):
|
|
test_queue._get()
|