App: Limit rate of requests to http queue (#18981)
This commit is contained in:
parent
a9d427c4f8
commit
9085db4dd3
|
@ -53,6 +53,7 @@ REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
|
|||
HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
|
||||
HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
|
||||
HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
|
||||
HTTP_QUEUE_REQUESTS_PER_SECOND = float(os.getenv("LIGHTNING_HTTP_QUEUE_REQUESTS_PER_SECOND", "0.5"))
|
||||
|
||||
USER_ID = os.getenv("USER_ID", "1234")
|
||||
FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")
|
||||
|
|
|
@ -29,6 +29,7 @@ from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
|
|||
|
||||
from lightning.app.core.constants import (
|
||||
HTTP_QUEUE_REFRESH_INTERVAL,
|
||||
HTTP_QUEUE_REQUESTS_PER_SECOND,
|
||||
HTTP_QUEUE_TOKEN,
|
||||
HTTP_QUEUE_URL,
|
||||
LIGHTNING_DIR,
|
||||
|
@ -77,7 +78,9 @@ class QueuingSystem(Enum):
|
|||
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
if self == QueuingSystem.REDIS:
|
||||
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
|
||||
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
return RateLimitedQueue(
|
||||
HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
|
||||
)
|
||||
|
||||
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
|
||||
|
@ -347,6 +350,45 @@ class RedisQueue(BaseQueue):
|
|||
return cls(**state)
|
||||
|
||||
|
||||
class RateLimitedQueue(BaseQueue):
    """Queue wrapper that throttles ``get`` and ``put`` calls made to the wrapped queue.

    ``get`` and ``put`` are throttled independently: each keeps its own timestamp of the last call, and a call
    blocks (via ``time.sleep``) until the configured interval has passed since the previous call of the same kind.

    """

    def __init__(self, queue: BaseQueue, requests_per_second: float):
        """This is a queue wrapper that will block on get or put calls if they are made too quickly.

        Args:
            queue: The queue to wrap.
            requests_per_second: The target number of get or put requests per second. Must be greater than zero.

        Raises:
            ValueError: If ``requests_per_second`` is not a positive number.

        """
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN, which would otherwise silently disable throttling.
        if not requests_per_second > 0:
            raise ValueError(f"`requests_per_second` must be a positive number, got {requests_per_second!r}.")

        self.name = queue.name
        self.default_timeout = queue.default_timeout

        self._queue = queue
        self._seconds_per_request = 1 / requests_per_second

        # ``-inf`` means "never called": the first get / put is always allowed immediately, whatever the
        # current reading of the monotonic clock is.
        self._last_get = float("-inf")
        self._last_put = float("-inf")

    @property
    def is_running(self) -> bool:
        """Whether the wrapped queue reports itself as running."""
        return self._queue.is_running

    def _wait_until_allowed(self, last_time: float) -> None:
        """Sleep until at least ``self._seconds_per_request`` seconds have passed since ``last_time``.

        Args:
            last_time: The ``time.monotonic()`` reading taken at the previous call, or ``-inf`` if there was none.

        """
        # The monotonic clock is unaffected by system clock adjustments, so the measured interval is reliable.
        # A negative difference is treated as zero elapsed time so that the wait never exceeds one full interval.
        elapsed = max(0.0, time.monotonic() - last_time)
        remaining = self._seconds_per_request - elapsed
        if remaining > 0:
            time.sleep(remaining)

    def get(self, timeout: Optional[float] = None) -> Any:
        """Wait until a get is allowed, then return the result of the wrapped queue's ``get``.

        Args:
            timeout: Forwarded unchanged to the wrapped queue's ``get``.

        """
        self._wait_until_allowed(self._last_get)
        # The timestamp is taken before the request is issued, so the time spent inside the wrapped ``get``
        # counts towards the interval before the next call.
        self._last_get = time.monotonic()
        return self._queue.get(timeout=timeout)

    def put(self, item: Any) -> None:
        """Wait until a put is allowed, then pass ``item`` to the wrapped queue's ``put``.

        Args:
            item: The object to put; forwarded unchanged to the wrapped queue.

        """
        self._wait_until_allowed(self._last_put)
        self._last_put = time.monotonic()
        return self._queue.put(item)
|
||||
|
||||
|
||||
class HTTPQueue(BaseQueue):
|
||||
def __init__(self, name: str, default_timeout: float) -> None:
|
||||
"""
|
||||
|
|
|
@ -8,8 +8,15 @@ import pytest
|
|||
import requests_mock
|
||||
from lightning.app import LightningFlow
|
||||
from lightning.app.core import queues
|
||||
from lightning.app.core.constants import HTTP_QUEUE_URL
|
||||
from lightning.app.core.queues import READINESS_QUEUE_CONSTANT, BaseQueue, QueuingSystem, RedisQueue
|
||||
from lightning.app.core.constants import HTTP_QUEUE_URL, STATE_UPDATE_TIMEOUT
|
||||
from lightning.app.core.queues import (
|
||||
READINESS_QUEUE_CONSTANT,
|
||||
BaseQueue,
|
||||
HTTPQueue,
|
||||
QueuingSystem,
|
||||
RateLimitedQueue,
|
||||
RedisQueue,
|
||||
)
|
||||
from lightning.app.utilities.imports import _is_redis_available
|
||||
from lightning.app.utilities.redis import check_if_redis_running
|
||||
|
||||
|
@ -162,7 +169,7 @@ def test_redis_raises_error_if_failing(redis_mock):
|
|||
|
||||
class TestHTTPQueue:
|
||||
def test_http_queue_failure_on_queue_name(self):
|
||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test")
|
||||
test_queue = HTTPQueue("test", STATE_UPDATE_TIMEOUT)
|
||||
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
|
||||
test_queue.put("test")
|
||||
|
||||
|
@ -174,7 +181,7 @@ class TestHTTPQueue:
|
|||
|
||||
def test_http_queue_put(self, monkeypatch):
|
||||
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
||||
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
|
||||
test_obj = LightningFlow()
|
||||
|
||||
# mocking requests and responses
|
||||
|
@ -200,8 +207,7 @@ class TestHTTPQueue:
|
|||
|
||||
def test_http_queue_get(self, monkeypatch):
|
||||
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
||||
|
||||
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
|
||||
adapter = requests_mock.Adapter()
|
||||
test_queue.client.session.mount("http://", adapter)
|
||||
|
||||
|
@ -218,7 +224,7 @@ class TestHTTPQueue:
|
|||
def test_unreachable_queue(monkeypatch):
|
||||
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
||||
|
||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
||||
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
|
||||
|
||||
resp1 = mock.MagicMock()
|
||||
resp1.status_code = 204
|
||||
|
@ -235,3 +241,25 @@ def test_unreachable_queue(monkeypatch):
|
|||
# Test backoff on queue.put
|
||||
test_queue.put("foo")
|
||||
assert test_queue.client.post.call_count == 3
|
||||
|
||||
|
||||
@mock.patch("lightning.app.core.queues.time.sleep")
def test_rate_limited_queue(mock_sleep):
    """Check that ``RateLimitedQueue`` mirrors the wrapped queue's attributes and throttles ``get`` calls.

    ``time.sleep`` is patched to record the requested durations instead of blocking, so the loop below runs on
    "virtual" time: real elapsed time plus the total sleep that was requested.

    """
    requested_sleeps = []
    # Record every requested sleep duration rather than actually sleeping.
    mock_sleep.side_effect = requested_sleeps.append

    inner_queue = mock.MagicMock()
    inner_queue.name = "inner_queue"
    inner_queue.default_timeout = 10.0

    rate_limited_queue = RateLimitedQueue(inner_queue, requests_per_second=1)

    # The wrapper exposes the name and default timeout of the queue it wraps.
    assert rate_limited_queue.name == "inner_queue"
    assert rate_limited_queue.default_timeout == 10.0

    # Keep calling ``get`` until one second of virtual time has passed.
    deadline = time.perf_counter() + 1
    while time.perf_counter() + sum(requested_sleeps) < deadline:
        rate_limited_queue.get()

    # At one request per second, two calls are expected to reach the wrapped queue within that window.
    assert inner_queue.get.call_count == 2
|
||||
|
|
Loading…
Reference in New Issue