[App/Improvement] Cleaning up Queue abstraction (#14977)
[App/Improvement] Cleaning up Queue abstraction (#14977)
This commit is contained in:
parent
ce919ee7d6
commit
53694eb93d
|
@ -22,12 +22,13 @@ from websockets.exceptions import ConnectionClosed
|
|||
|
||||
from lightning_app.api.http_methods import HttpMethod
|
||||
from lightning_app.api.request_types import DeltaRequest
|
||||
from lightning_app.core.constants import ENABLE_STATE_WEBSOCKET, FRONTEND_DIR
|
||||
from lightning_app.core.queues import RedisQueue
|
||||
from lightning_app.core.constants import CLOUD_QUEUE_TYPE, ENABLE_STATE_WEBSOCKET, FRONTEND_DIR
|
||||
from lightning_app.core.queues import QueuingSystem
|
||||
from lightning_app.storage import Drive
|
||||
from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
|
||||
from lightning_app.utilities.cloud import is_running_in_cloud
|
||||
from lightning_app.utilities.enum import OpenAPITags
|
||||
from lightning_app.utilities.imports import _is_redis_available, _is_starsessions_available
|
||||
from lightning_app.utilities.imports import _is_starsessions_available
|
||||
|
||||
if _is_starsessions_available():
|
||||
from starsessions import SessionMiddleware
|
||||
|
@ -261,17 +262,13 @@ async def upload_file(filename: str, uploaded_file: UploadFile = File(...)):
|
|||
|
||||
@fastapi_service.get("/healthz", status_code=200)
|
||||
async def healthz(response: Response):
|
||||
"""Health check endpoint used in the cloud FastAPI servers to check the status periodically. This requires
|
||||
Redis to be installed for it to work.
|
||||
|
||||
# TODO - Once the state store abstraction is in, check that too
|
||||
"""
|
||||
if not _is_redis_available():
|
||||
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return {"status": "failure", "reason": "Redis is not available"}
|
||||
if not RedisQueue(name="ping", default_timeout=1).ping():
|
||||
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return {"status": "failure", "reason": "Redis is not available"}
|
||||
"""Health check endpoint used in the cloud FastAPI servers to check the status periodically."""
|
||||
# check the queue status only if running in cloud
|
||||
if is_running_in_cloud():
|
||||
queue_obj = QueuingSystem(CLOUD_QUEUE_TYPE).get_queue(queue_name="healthz")
|
||||
if not queue_obj.is_running:
|
||||
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return {"status": "failure", "reason": "Redis is not available"}
|
||||
x_lightning_session_uuid = TEST_SESSION_UUID
|
||||
state = global_app_state_store.get_app_state(x_lightning_session_uuid)
|
||||
global_app_state_store.set_served_state(x_lightning_session_uuid, state)
|
||||
|
|
|
@ -5,6 +5,12 @@ import lightning_cloud.env
|
|||
|
||||
import lightning_app
|
||||
|
||||
|
||||
def get_lightning_cloud_url() -> str:
|
||||
# DO NOT CHANGE!
|
||||
return os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
|
||||
|
||||
|
||||
SUPPORTED_PRIMITIVE_TYPES = (type(None), str, int, float, bool)
|
||||
STATE_UPDATE_TIMEOUT = 0.001
|
||||
STATE_ACCUMULATE_WAIT = 0.05
|
||||
|
@ -17,11 +23,16 @@ FLOW_DURATION_SAMPLES = 5
|
|||
APP_SERVER_HOST = os.getenv("LIGHTNING_APP_STATE_URL", "http://127.0.0.1")
|
||||
APP_SERVER_PORT = 7501
|
||||
APP_STATE_MAX_SIZE_BYTES = 1024 * 1024 # 1 MB
|
||||
|
||||
CLOUD_QUEUE_TYPE = os.getenv("LIGHTNING_CLOUD_QUEUE_TYPE", "redis")
|
||||
|
||||
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
|
||||
REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
|
||||
REDIS_WARNING_QUEUE_SIZE = 1000
|
||||
|
||||
WARNING_QUEUE_SIZE = 1000
|
||||
|
||||
USER_ID = os.getenv("USER_ID", "1234")
|
||||
FRONTEND_DIR = os.path.join(os.path.dirname(lightning_app.__file__), "ui")
|
||||
PACKAGE_LIGHTNING = os.getenv("PACKAGE_LIGHTNING", None)
|
||||
|
@ -38,8 +49,3 @@ LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
|
|||
ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0")))
|
||||
|
||||
DEBUG: bool = lightning_cloud.env.DEBUG
|
||||
|
||||
|
||||
def get_lightning_cloud_url() -> str:
|
||||
# DO NOT CHANGE!
|
||||
return os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
|
||||
|
|
|
@ -11,8 +11,8 @@ from lightning_app.core.constants import (
|
|||
REDIS_PASSWORD,
|
||||
REDIS_PORT,
|
||||
REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
|
||||
REDIS_WARNING_QUEUE_SIZE,
|
||||
STATE_UPDATE_TIMEOUT,
|
||||
WARNING_QUEUE_SIZE,
|
||||
)
|
||||
from lightning_app.utilities.app_helpers import Logger
|
||||
from lightning_app.utilities.imports import _is_redis_available, requires
|
||||
|
@ -44,7 +44,7 @@ class QueuingSystem(Enum):
|
|||
MULTIPROCESS = "multiprocess"
|
||||
REDIS = "redis"
|
||||
|
||||
def _get_queue(self, queue_name: str) -> "BaseQueue":
|
||||
def get_queue(self, queue_name: str) -> "BaseQueue":
|
||||
if self == QueuingSystem.MULTIPROCESS:
|
||||
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
elif self == QueuingSystem.REDIS:
|
||||
|
@ -54,37 +54,37 @@ class QueuingSystem(Enum):
|
|||
|
||||
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{READINESS_QUEUE_CONSTANT}" if queue_id else READINESS_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{DELTA_QUEUE_CONSTANT}" if queue_id else DELTA_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_error_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{ERROR_QUEUE_CONSTANT}" if queue_id else ERROR_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_has_server_started_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{HAS_SERVER_STARTED_CONSTANT}" if queue_id else HAS_SERVER_STARTED_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_caller_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = (
|
||||
f"{queue_id}_{CALLER_QUEUE_CONSTANT}_{work_name}" if queue_id else f"{CALLER_QUEUE_CONSTANT}_{work_name}"
|
||||
)
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_api_state_publish_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{API_STATE_PUBLISH_QUEUE_CONSTANT}" if queue_id else API_STATE_PUBLISH_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_api_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{API_DELTA_QUEUE_CONSTANT}" if queue_id else API_DELTA_QUEUE_CONSTANT
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_orchestrator_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = (
|
||||
|
@ -92,7 +92,7 @@ class QueuingSystem(Enum):
|
|||
if queue_id
|
||||
else f"{ORCHESTRATOR_REQUEST_CONSTANT}_{work_name}"
|
||||
)
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_orchestrator_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = (
|
||||
|
@ -100,7 +100,7 @@ class QueuingSystem(Enum):
|
|||
if queue_id
|
||||
else f"{ORCHESTRATOR_RESPONSE_CONSTANT}_{work_name}"
|
||||
)
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_orchestrator_copy_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = (
|
||||
|
@ -108,7 +108,7 @@ class QueuingSystem(Enum):
|
|||
if queue_id
|
||||
else f"{ORCHESTRATOR_COPY_REQUEST_CONSTANT}_{work_name}"
|
||||
)
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_orchestrator_copy_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = (
|
||||
|
@ -116,13 +116,13 @@ class QueuingSystem(Enum):
|
|||
if queue_id
|
||||
else f"{ORCHESTRATOR_COPY_RESPONSE_CONSTANT}_{work_name}"
|
||||
)
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
def get_work_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = (
|
||||
f"{queue_id}_{WORK_QUEUE_CONSTANT}_{work_name}" if queue_id else f"{WORK_QUEUE_CONSTANT}_{work_name}"
|
||||
)
|
||||
return self._get_queue(queue_name)
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
|
||||
class BaseQueue(ABC):
|
||||
|
@ -149,6 +149,14 @@ class BaseQueue(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Returns True if the queue is running, False otherwise.
|
||||
|
||||
Child classes should override this property and implement custom logic as required.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
class SingleProcessQueue(BaseQueue):
|
||||
def __init__(self, name: str, default_timeout: float):
|
||||
|
@ -213,20 +221,13 @@ class RedisQueue(BaseQueue):
|
|||
self.default_timeout = default_timeout
|
||||
self.redis = redis.Redis(host=host, port=port, password=password)
|
||||
|
||||
def ping(self):
|
||||
"""Ping the redis server to see if it is alive."""
|
||||
try:
|
||||
return self.redis.ping()
|
||||
except redis.exceptions.ConnectionError:
|
||||
return False
|
||||
|
||||
def put(self, item: Any) -> None:
|
||||
value = pickle.dumps(item)
|
||||
queue_len = self.length()
|
||||
if queue_len >= REDIS_WARNING_QUEUE_SIZE:
|
||||
if queue_len >= WARNING_QUEUE_SIZE:
|
||||
warnings.warn(
|
||||
f"The Redis Queue {self.name} length is larger than the "
|
||||
f"recommended length of {REDIS_WARNING_QUEUE_SIZE}. "
|
||||
f"recommended length of {WARNING_QUEUE_SIZE}. "
|
||||
f"Found {queue_len}. This might cause your application to crash, "
|
||||
"please investigate this."
|
||||
)
|
||||
|
@ -282,3 +283,11 @@ class RedisQueue(BaseQueue):
|
|||
"Please try running your app again. "
|
||||
"If the issue persists, please contact support@lightning.ai"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Ping the Redis server and report whether it is alive (False on a connection error)."""
|
||||
try:
|
||||
return self.redis.ping()
|
||||
except redis.exceptions.ConnectionError:
|
||||
return False
|
||||
|
|
|
@ -324,7 +324,8 @@ async def test_health_endpoint_success():
|
|||
check_if_redis_running(), reason="this is testing the failure condition " "for which the redis should not run"
|
||||
)
|
||||
@pytest.mark.anyio
|
||||
async def test_health_endpoint_failure():
|
||||
async def test_health_endpoint_failure(monkeypatch):
|
||||
monkeypatch.setenv("LIGHTNING_APP_STATE_URL", "http://someurl") # adding this to make is_running_in_cloud pass
|
||||
async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
|
||||
# will respond 500 if redis is not running
|
||||
response = await client.get("/healthz")
|
||||
|
|
|
@ -51,19 +51,19 @@ def test_redis_queue():
|
|||
|
||||
|
||||
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
|
||||
def test_redis_ping_success():
|
||||
def test_redis_health_check_success():
|
||||
redis_queue = QueuingSystem.REDIS.get_readiness_queue()
|
||||
assert redis_queue.ping()
|
||||
assert redis_queue.is_running
|
||||
|
||||
redis_queue = RedisQueue(name="test_queue", default_timeout=1)
|
||||
assert redis_queue.ping()
|
||||
assert redis_queue.is_running
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _is_redis_available(), reason="redis is required for this test.")
|
||||
@pytest.mark.skipif(check_if_redis_running(), reason="This is testing the failure case when redis is not running")
|
||||
def test_redis_ping_failure():
|
||||
def test_redis_health_check_failure():
|
||||
redis_queue = RedisQueue(name="test_queue", default_timeout=1)
|
||||
assert not redis_queue.ping()
|
||||
assert not redis_queue.is_running
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
|
||||
|
|
Loading…
Reference in New Issue