lightning/tests/tests_app/core/test_queues.py

231 lines
9.1 KiB
Python

import multiprocessing
import pickle
import queue
import time
from unittest import mock
import pytest
import requests_mock
from lightning_app import LightningFlow
from lightning_app.core import queues
from lightning_app.core.constants import HTTP_QUEUE_URL
from lightning_app.core.queues import BaseQueue, QueuingSystem, READINESS_QUEUE_CONSTANT, RedisQueue
from lightning_app.utilities.imports import _is_redis_available
from lightning_app.utilities.redis import check_if_redis_running
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
@pytest.mark.parametrize("queue_type", [QueuingSystem.REDIS, QueuingSystem.MULTIPROCESS])
def test_queue_api(queue_type, monkeypatch):
    """Run the same put/get round trip against every parametrized queue implementation.

    The ``blpop``, ``rpush``, ``set`` and ``get`` methods of ``redis.Redis`` are monkeypatched with
    stubs so that those calls return canned values.
    """
    import redis

    # Canned ``blpop`` result: a two-item tuple whose second item is the pickled entry.
    canned_blpop = (b"entry-id", pickle.dumps("test_entry"))

    def _return_none(*args, **kwargs):
        return None

    stubs = {
        "blpop": lambda *args, **kwargs: canned_blpop,
        "rpush": _return_none,
        "set": _return_none,
        "get": _return_none,
    }
    for method_name, stub in stubs.items():
        monkeypatch.setattr(redis.Redis, method_name, stub)

    readiness_queue = queue_type.get_readiness_queue()
    assert readiness_queue.name == READINESS_QUEUE_CONSTANT
    assert isinstance(readiness_queue, BaseQueue)

    # Whatever goes in must come back out unchanged.
    readiness_queue.put("test_entry")
    assert readiness_queue.get() == "test_entry"
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
def test_redis_queue():
    """Check that Redis readiness queues with different ids do not share entries.

    Also covers ``length`` and ``clear``, and that reading an empty queue with a timeout raises
    ``queue.Empty``.
    """
    # The current timestamp is used as the base for the two queue ids.
    base_id = int(time.time())
    first_queue = QueuingSystem.REDIS.get_readiness_queue(queue_id=str(base_id))
    second_queue = QueuingSystem.REDIS.get_readiness_queue(queue_id=str(base_id + 1))

    first_queue.put("test_entry1")
    second_queue.put("test_entry2")

    # Each queue only hands back the entry that was put on it.
    assert first_queue.get() == "test_entry1"
    assert second_queue.get() == "test_entry2"

    # The second queue has been drained.
    with pytest.raises(queue.Empty):
        second_queue.get(timeout=1)

    # ``clear`` drops the pending entry, so the next read finds nothing.
    first_queue.put("test_entry1")
    assert first_queue.length() == 1
    first_queue.clear()
    with pytest.raises(queue.Empty):
        first_queue.get(timeout=1)
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
def test_redis_health_check_success():
    """``is_running`` is truthy for a factory-built queue and for a directly constructed one."""
    factory_queue = QueuingSystem.REDIS.get_readiness_queue()
    assert factory_queue.is_running

    direct_queue = RedisQueue(name="test_queue", default_timeout=1)
    assert direct_queue.is_running
@pytest.mark.skipif(not _is_redis_available(), reason="redis is required for this test.")
@pytest.mark.skipif(check_if_redis_running(), reason="This is testing the failure case when redis is not running")
def test_redis_health_check_failure():
    """``is_running`` is falsy when the redis package is installed but no server is running."""
    assert not RedisQueue(name="test_queue", default_timeout=1).is_running
@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
def test_redis_credential(monkeypatch):
    """The ``REDIS_*`` constants of the ``queues`` module end up in the client's connection kwargs."""
    overrides = {
        "REDIS_HOST": "test-host",
        "REDIS_PORT": "test-port",
        "REDIS_PASSWORD": "test-password",
    }
    for constant_name, value in overrides.items():
        monkeypatch.setattr(queues, constant_name, value)

    redis_queue = QueuingSystem.REDIS.get_readiness_queue()

    connection_kwargs = redis_queue.redis.connection_pool.connection_kwargs
    assert connection_kwargs["host"] == "test-host"
    assert connection_kwargs["port"] == "test-port"
    assert connection_kwargs["password"] == "test-password"
@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
@mock.patch("lightning_app.core.queues.redis.Redis")
def test_redis_queue_read_timeout(redis_mock):
    """Check which ``timeout`` the queue forwards to ``blpop`` for each way of calling ``get``."""
    blpop = redis_mock.return_value.blpop
    blpop.return_value = (b"READINESS_QUEUE", pickle.dumps("test_entry"))

    redis_queue = QueuingSystem.REDIS.get_readiness_queue()

    # (kwargs passed to ``get``, timeout expected on the matching ``blpop`` call)
    cases = [
        ({"timeout": 0}, 0.005),  # default timeout
        ({"timeout": 2}, 2),  # custom timeout
        ({}, 0),  # blocking timeout
    ]
    for call_index, (get_kwargs, expected_timeout) in enumerate(cases):
        assert redis_queue.get(**get_kwargs) == "test_entry"
        assert blpop.call_args_list[call_index] == mock.call(["READINESS_QUEUE"], timeout=expected_timeout)
@pytest.mark.parametrize("queue_type, queue_process_mock", [(QueuingSystem.MULTIPROCESS, multiprocessing)])
def test_process_queue_read_timeout(queue_type, queue_process_mock, monkeypatch):
    """Check the ``timeout``/``block`` arguments forwarded to the underlying queue's ``get``."""
    # Replace ``get_context`` so the queue under test is built on top of a mock.
    queue_factory = mock.MagicMock()
    fake_context = mock.MagicMock()
    fake_context.Queue = queue_factory
    monkeypatch.setattr(queue_process_mock, "get_context", mock.MagicMock(return_value=fake_context))

    my_queue = queue_type.get_readiness_queue()
    underlying_get = queue_factory.return_value.get

    # (kwargs passed to ``get``, call expected on the underlying queue)
    cases = [
        ({"timeout": 0}, mock.call(timeout=0.001, block=False)),  # default timeout
        ({"timeout": 2}, mock.call(timeout=2, block=False)),  # custom timeout
        ({}, mock.call(timeout=None, block=True)),  # blocking timeout
    ]
    for call_index, (get_kwargs, expected_call) in enumerate(cases):
        my_queue.get(**get_kwargs)
        assert underlying_get.call_args_list[call_index] == expected_call
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
@mock.patch("lightning_app.core.queues.WARNING_QUEUE_SIZE", 2)
def test_redis_queue_warning():
    """A ``UserWarning`` is emitted while putting three items with ``WARNING_QUEUE_SIZE`` patched to 2."""
    my_queue = QueuingSystem.REDIS.get_api_delta_queue(queue_id="test_redis_queue_warning")
    # Start from an empty queue so entries left over from earlier runs do not count.
    my_queue.clear()

    with pytest.warns(UserWarning, match="is larger than the"):
        for _ in range(3):
            my_queue.put(None)
@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
@mock.patch("lightning_app.core.queues.redis.Redis")
def test_redis_raises_error_if_failing(redis_mock):
    """``get``, ``put`` and ``length`` each turn a redis connection error into a ``ConnectionError``.

    The Redis client is fully mocked, so only the ``redis`` package has to be importable; no server
    is required. Each failure is armed right before the call it is meant to break, so that every
    section exercises exactly one failing client method.
    """
    import redis

    expected_message = "Your app failed because it couldn't connect to Redis."
    client = redis_mock.return_value
    my_queue = QueuingSystem.REDIS.get_api_delta_queue(queue_id="test_redis_queue_warning")

    # get: ``blpop`` fails.
    client.blpop.side_effect = redis.exceptions.ConnectionError("EROOOR")
    with pytest.raises(ConnectionError, match=expected_message):
        my_queue.get()

    # put: ``llen`` succeeds and ``rpush`` fails. ``llen.side_effect`` must stay unset here because a
    # mock's ``side_effect`` takes precedence over its ``return_value``; otherwise a ``put`` that
    # checks the length first would fail before ever reaching ``rpush``.
    client.llen.return_value = 1
    client.rpush.side_effect = redis.exceptions.ConnectionError("EROOOR")
    with pytest.raises(ConnectionError, match=expected_message):
        my_queue.put(1)

    # length: ``llen`` fails.
    client.llen.side_effect = redis.exceptions.ConnectionError("EROOOR")
    with pytest.raises(ConnectionError, match=expected_message):
        my_queue.length()
class TestHTTPQueue:
    """Tests for the queue returned by ``QueuingSystem.HTTP.get_queue``."""

    def test_http_queue_failure_on_queue_name(self):
        """``put``, ``get`` and ``length`` all raise ``ValueError`` for the queue name ``"test"``."""
        test_queue = QueuingSystem.HTTP.get_queue(queue_name="test")

        operations = (
            lambda: test_queue.put("test"),
            lambda: test_queue.get(),
            lambda: test_queue.length(),
        )
        for operation in operations:
            with pytest.raises(ValueError, match="App ID couldn't be extracted"):
                operation()

    def test_http_queue_put(self, monkeypatch):
        """``put`` goes through against mocked length and push endpoints that require the bearer token."""
        monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
        test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
        flow = LightningFlow()

        def _body_is_pickled_flow(request):
            # Only match push requests whose body is exactly the pickled flow.
            return pickle.dumps(flow) == request._request.body

        # Route the client's HTTP traffic through a requests_mock adapter.
        # NOTE(review): the URLs suggest "test_http_queue" is split into app id "test" and queue
        # "http_queue" -- confirm against the queue implementation.
        adapter = requests_mock.Adapter()
        test_queue.client.session.mount("http://", adapter)
        auth_headers = {"Authorization": "Bearer test-token"}

        adapter.register_uri(
            "GET",
            f"{HTTP_QUEUE_URL}/v1/test/http_queue/length",
            request_headers=auth_headers,
            status_code=200,
            content=b"1",
        )
        adapter.register_uri(
            "POST",
            f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=push",
            status_code=201,
            additional_matcher=_body_is_pickled_flow,
            request_headers=auth_headers,
            content=b"data pushed",
        )

        test_queue.put(flow)

    def test_http_queue_get(self, monkeypatch):
        """``get`` returns the unpickled content served by the mocked pop endpoint."""
        monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
        test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")

        adapter = requests_mock.Adapter()
        test_queue.client.session.mount("http://", adapter)

        pop_url = f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=pop"
        adapter.register_uri(
            "POST",
            pop_url,
            request_headers={"Authorization": "Bearer test-token"},
            status_code=200,
            content=pickle.dumps("test"),
        )

        assert test_queue.get() == "test"
def test_unreachable_queue(monkeypatch):
    """``_get`` raises ``queue.Empty`` when the mocked client's ``post`` answers with status 204."""
    monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
    test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")

    # Swap the real client for a mock whose ``post`` always returns a 204 response.
    empty_response = mock.MagicMock()
    empty_response.status_code = 204
    test_queue.client = mock.MagicMock()
    test_queue.client.post.return_value = empty_response

    with pytest.raises(queue.Empty):
        test_queue._get()