App: Limit rate of requests to http queue (#18981)

This commit is contained in:
Ethan Harris 2023-11-10 10:26:58 +00:00 committed by GitHub
parent a9d427c4f8
commit 9085db4dd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 8 deletions

View File

@ -53,6 +53,7 @@ REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801") HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1")) HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None) HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
HTTP_QUEUE_REQUESTS_PER_SECOND = float(os.getenv("LIGHTNING_HTTP_QUEUE_REQUESTS_PER_SECOND", "0.5"))
USER_ID = os.getenv("USER_ID", "1234") USER_ID = os.getenv("USER_ID", "1234")
FRONTEND_DIR = str(Path(__file__).parent.parent / "ui") FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")

View File

@ -29,6 +29,7 @@ from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
from lightning.app.core.constants import ( from lightning.app.core.constants import (
HTTP_QUEUE_REFRESH_INTERVAL, HTTP_QUEUE_REFRESH_INTERVAL,
HTTP_QUEUE_REQUESTS_PER_SECOND,
HTTP_QUEUE_TOKEN, HTTP_QUEUE_TOKEN,
HTTP_QUEUE_URL, HTTP_QUEUE_URL,
LIGHTNING_DIR, LIGHTNING_DIR,
@ -77,7 +78,9 @@ class QueuingSystem(Enum):
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
if self == QueuingSystem.REDIS: if self == QueuingSystem.REDIS:
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT) return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT) return RateLimitedQueue(
HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
)
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue": def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
@ -347,6 +350,45 @@ class RedisQueue(BaseQueue):
return cls(**state) return cls(**state)
class RateLimitedQueue(BaseQueue):
def __init__(self, queue: BaseQueue, requests_per_second: float):
"""This is a queue wrapper that will block on get or put calls if they are made too quickly.
Args:
queue: The queue to wrap.
requests_per_second: The target number of get or put requests per second.
"""
self.name = queue.name
self.default_timeout = queue.default_timeout
self._queue = queue
self._seconds_per_request = 1 / requests_per_second
self._last_get = 0.0
self._last_put = 0.0
@property
def is_running(self) -> bool:
return self._queue.is_running
def _wait_until_allowed(self, last_time: float) -> None:
t = time.time()
diff = t - last_time
if diff < self._seconds_per_request:
time.sleep(self._seconds_per_request - diff)
def get(self, timeout: Optional[float] = None) -> Any:
self._wait_until_allowed(self._last_get)
self._last_get = time.time()
return self._queue.get(timeout=timeout)
def put(self, item: Any) -> None:
self._wait_until_allowed(self._last_put)
self._last_put = time.time()
return self._queue.put(item)
class HTTPQueue(BaseQueue): class HTTPQueue(BaseQueue):
def __init__(self, name: str, default_timeout: float) -> None: def __init__(self, name: str, default_timeout: float) -> None:
""" """

View File

@ -8,8 +8,15 @@ import pytest
import requests_mock import requests_mock
from lightning.app import LightningFlow from lightning.app import LightningFlow
from lightning.app.core import queues from lightning.app.core import queues
from lightning.app.core.constants import HTTP_QUEUE_URL from lightning.app.core.constants import HTTP_QUEUE_URL, STATE_UPDATE_TIMEOUT
from lightning.app.core.queues import READINESS_QUEUE_CONSTANT, BaseQueue, QueuingSystem, RedisQueue from lightning.app.core.queues import (
READINESS_QUEUE_CONSTANT,
BaseQueue,
HTTPQueue,
QueuingSystem,
RateLimitedQueue,
RedisQueue,
)
from lightning.app.utilities.imports import _is_redis_available from lightning.app.utilities.imports import _is_redis_available
from lightning.app.utilities.redis import check_if_redis_running from lightning.app.utilities.redis import check_if_redis_running
@ -162,7 +169,7 @@ def test_redis_raises_error_if_failing(redis_mock):
class TestHTTPQueue: class TestHTTPQueue:
def test_http_queue_failure_on_queue_name(self): def test_http_queue_failure_on_queue_name(self):
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test") test_queue = HTTPQueue("test", STATE_UPDATE_TIMEOUT)
with pytest.raises(ValueError, match="App ID couldn't be extracted"): with pytest.raises(ValueError, match="App ID couldn't be extracted"):
test_queue.put("test") test_queue.put("test")
@ -174,7 +181,7 @@ class TestHTTPQueue:
def test_http_queue_put(self, monkeypatch): def test_http_queue_put(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token") monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue") test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
test_obj = LightningFlow() test_obj = LightningFlow()
# mocking requests and responses # mocking requests and responses
@ -200,8 +207,7 @@ class TestHTTPQueue:
def test_http_queue_get(self, monkeypatch): def test_http_queue_get(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token") monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue") test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
adapter = requests_mock.Adapter() adapter = requests_mock.Adapter()
test_queue.client.session.mount("http://", adapter) test_queue.client.session.mount("http://", adapter)
@ -218,7 +224,7 @@ class TestHTTPQueue:
def test_unreachable_queue(monkeypatch): def test_unreachable_queue(monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token") monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue") test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
resp1 = mock.MagicMock() resp1 = mock.MagicMock()
resp1.status_code = 204 resp1.status_code = 204
@ -235,3 +241,25 @@ def test_unreachable_queue(monkeypatch):
# Test backoff on queue.put # Test backoff on queue.put
test_queue.put("foo") test_queue.put("foo")
assert test_queue.client.post.call_count == 3 assert test_queue.client.post.call_count == 3
@mock.patch("lightning.app.core.queues.time.sleep")
def test_rate_limited_queue(mock_sleep):
sleeps = []
mock_sleep.side_effect = lambda sleep_time: sleeps.append(sleep_time)
mock_queue = mock.MagicMock()
mock_queue.name = "inner_queue"
mock_queue.default_timeout = 10.0
rate_limited_queue = RateLimitedQueue(mock_queue, requests_per_second=1)
assert rate_limited_queue.name == "inner_queue"
assert rate_limited_queue.default_timeout == 10.0
timeout = time.perf_counter() + 1
while time.perf_counter() + sum(sleeps) < timeout:
rate_limited_queue.get()
assert mock_queue.get.call_count == 2