App: Limit rate of requests to http queue (#18981)
This commit is contained in:
parent
a9d427c4f8
commit
9085db4dd3
|
@ -53,6 +53,7 @@ REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
|
||||||
HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
|
HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
|
||||||
HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
|
HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
|
||||||
HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
|
HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
|
||||||
|
HTTP_QUEUE_REQUESTS_PER_SECOND = float(os.getenv("LIGHTNING_HTTP_QUEUE_REQUESTS_PER_SECOND", "0.5"))
|
||||||
|
|
||||||
USER_ID = os.getenv("USER_ID", "1234")
|
USER_ID = os.getenv("USER_ID", "1234")
|
||||||
FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")
|
FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")
|
||||||
|
|
|
@ -29,6 +29,7 @@ from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
|
||||||
|
|
||||||
from lightning.app.core.constants import (
|
from lightning.app.core.constants import (
|
||||||
HTTP_QUEUE_REFRESH_INTERVAL,
|
HTTP_QUEUE_REFRESH_INTERVAL,
|
||||||
|
HTTP_QUEUE_REQUESTS_PER_SECOND,
|
||||||
HTTP_QUEUE_TOKEN,
|
HTTP_QUEUE_TOKEN,
|
||||||
HTTP_QUEUE_URL,
|
HTTP_QUEUE_URL,
|
||||||
LIGHTNING_DIR,
|
LIGHTNING_DIR,
|
||||||
|
@ -77,7 +78,9 @@ class QueuingSystem(Enum):
|
||||||
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||||
if self == QueuingSystem.REDIS:
|
if self == QueuingSystem.REDIS:
|
||||||
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
|
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
|
||||||
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
return RateLimitedQueue(
|
||||||
|
HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
|
||||||
|
)
|
||||||
|
|
||||||
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||||
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
|
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
|
||||||
|
@ -347,6 +350,45 @@ class RedisQueue(BaseQueue):
|
||||||
return cls(**state)
|
return cls(**state)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitedQueue(BaseQueue):
|
||||||
|
def __init__(self, queue: BaseQueue, requests_per_second: float):
|
||||||
|
"""This is a queue wrapper that will block on get or put calls if they are made too quickly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue: The queue to wrap.
|
||||||
|
requests_per_second: The target number of get or put requests per second.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.name = queue.name
|
||||||
|
self.default_timeout = queue.default_timeout
|
||||||
|
|
||||||
|
self._queue = queue
|
||||||
|
self._seconds_per_request = 1 / requests_per_second
|
||||||
|
|
||||||
|
self._last_get = 0.0
|
||||||
|
self._last_put = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return self._queue.is_running
|
||||||
|
|
||||||
|
def _wait_until_allowed(self, last_time: float) -> None:
|
||||||
|
t = time.time()
|
||||||
|
diff = t - last_time
|
||||||
|
if diff < self._seconds_per_request:
|
||||||
|
time.sleep(self._seconds_per_request - diff)
|
||||||
|
|
||||||
|
def get(self, timeout: Optional[float] = None) -> Any:
|
||||||
|
self._wait_until_allowed(self._last_get)
|
||||||
|
self._last_get = time.time()
|
||||||
|
return self._queue.get(timeout=timeout)
|
||||||
|
|
||||||
|
def put(self, item: Any) -> None:
|
||||||
|
self._wait_until_allowed(self._last_put)
|
||||||
|
self._last_put = time.time()
|
||||||
|
return self._queue.put(item)
|
||||||
|
|
||||||
|
|
||||||
class HTTPQueue(BaseQueue):
|
class HTTPQueue(BaseQueue):
|
||||||
def __init__(self, name: str, default_timeout: float) -> None:
|
def __init__(self, name: str, default_timeout: float) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -8,8 +8,15 @@ import pytest
|
||||||
import requests_mock
|
import requests_mock
|
||||||
from lightning.app import LightningFlow
|
from lightning.app import LightningFlow
|
||||||
from lightning.app.core import queues
|
from lightning.app.core import queues
|
||||||
from lightning.app.core.constants import HTTP_QUEUE_URL
|
from lightning.app.core.constants import HTTP_QUEUE_URL, STATE_UPDATE_TIMEOUT
|
||||||
from lightning.app.core.queues import READINESS_QUEUE_CONSTANT, BaseQueue, QueuingSystem, RedisQueue
|
from lightning.app.core.queues import (
|
||||||
|
READINESS_QUEUE_CONSTANT,
|
||||||
|
BaseQueue,
|
||||||
|
HTTPQueue,
|
||||||
|
QueuingSystem,
|
||||||
|
RateLimitedQueue,
|
||||||
|
RedisQueue,
|
||||||
|
)
|
||||||
from lightning.app.utilities.imports import _is_redis_available
|
from lightning.app.utilities.imports import _is_redis_available
|
||||||
from lightning.app.utilities.redis import check_if_redis_running
|
from lightning.app.utilities.redis import check_if_redis_running
|
||||||
|
|
||||||
|
@ -162,7 +169,7 @@ def test_redis_raises_error_if_failing(redis_mock):
|
||||||
|
|
||||||
class TestHTTPQueue:
|
class TestHTTPQueue:
|
||||||
def test_http_queue_failure_on_queue_name(self):
|
def test_http_queue_failure_on_queue_name(self):
|
||||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test")
|
test_queue = HTTPQueue("test", STATE_UPDATE_TIMEOUT)
|
||||||
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
|
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
|
||||||
test_queue.put("test")
|
test_queue.put("test")
|
||||||
|
|
||||||
|
@ -174,7 +181,7 @@ class TestHTTPQueue:
|
||||||
|
|
||||||
def test_http_queue_put(self, monkeypatch):
|
def test_http_queue_put(self, monkeypatch):
|
||||||
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
||||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
|
||||||
test_obj = LightningFlow()
|
test_obj = LightningFlow()
|
||||||
|
|
||||||
# mocking requests and responses
|
# mocking requests and responses
|
||||||
|
@ -200,8 +207,7 @@ class TestHTTPQueue:
|
||||||
|
|
||||||
def test_http_queue_get(self, monkeypatch):
|
def test_http_queue_get(self, monkeypatch):
|
||||||
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
||||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
|
||||||
|
|
||||||
adapter = requests_mock.Adapter()
|
adapter = requests_mock.Adapter()
|
||||||
test_queue.client.session.mount("http://", adapter)
|
test_queue.client.session.mount("http://", adapter)
|
||||||
|
|
||||||
|
@ -218,7 +224,7 @@ class TestHTTPQueue:
|
||||||
def test_unreachable_queue(monkeypatch):
|
def test_unreachable_queue(monkeypatch):
|
||||||
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
|
||||||
|
|
||||||
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
|
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
|
||||||
|
|
||||||
resp1 = mock.MagicMock()
|
resp1 = mock.MagicMock()
|
||||||
resp1.status_code = 204
|
resp1.status_code = 204
|
||||||
|
@ -235,3 +241,25 @@ def test_unreachable_queue(monkeypatch):
|
||||||
# Test backoff on queue.put
|
# Test backoff on queue.put
|
||||||
test_queue.put("foo")
|
test_queue.put("foo")
|
||||||
assert test_queue.client.post.call_count == 3
|
assert test_queue.client.post.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("lightning.app.core.queues.time.sleep")
|
||||||
|
def test_rate_limited_queue(mock_sleep):
|
||||||
|
sleeps = []
|
||||||
|
mock_sleep.side_effect = lambda sleep_time: sleeps.append(sleep_time)
|
||||||
|
|
||||||
|
mock_queue = mock.MagicMock()
|
||||||
|
|
||||||
|
mock_queue.name = "inner_queue"
|
||||||
|
mock_queue.default_timeout = 10.0
|
||||||
|
|
||||||
|
rate_limited_queue = RateLimitedQueue(mock_queue, requests_per_second=1)
|
||||||
|
|
||||||
|
assert rate_limited_queue.name == "inner_queue"
|
||||||
|
assert rate_limited_queue.default_timeout == 10.0
|
||||||
|
|
||||||
|
timeout = time.perf_counter() + 1
|
||||||
|
while time.perf_counter() + sum(sleeps) < timeout:
|
||||||
|
rate_limited_queue.get()
|
||||||
|
|
||||||
|
assert mock_queue.get.call_count == 2
|
||||||
|
|
Loading…
Reference in New Issue