[App/Improvement] Cleaning up Queue abstraction (#14977)

[App/Improvement] Cleaning up Queue abstraction (#14977)
This commit is contained in:
Sherin Thomas 2022-10-04 22:07:31 +05:30 committed by GitHub
parent ce919ee7d6
commit 53694eb93d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 50 deletions

View File

@ -22,12 +22,13 @@ from websockets.exceptions import ConnectionClosed
from lightning_app.api.http_methods import HttpMethod
from lightning_app.api.request_types import DeltaRequest
from lightning_app.core.constants import ENABLE_STATE_WEBSOCKET, FRONTEND_DIR
from lightning_app.core.queues import RedisQueue
from lightning_app.core.constants import CLOUD_QUEUE_TYPE, ENABLE_STATE_WEBSOCKET, FRONTEND_DIR
from lightning_app.core.queues import QueuingSystem
from lightning_app.storage import Drive
from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
from lightning_app.utilities.cloud import is_running_in_cloud
from lightning_app.utilities.enum import OpenAPITags
from lightning_app.utilities.imports import _is_redis_available, _is_starsessions_available
from lightning_app.utilities.imports import _is_starsessions_available
if _is_starsessions_available():
from starsessions import SessionMiddleware
@ -261,17 +262,13 @@ async def upload_file(filename: str, uploaded_file: UploadFile = File(...)):
@fastapi_service.get("/healthz", status_code=200)
async def healthz(response: Response):
"""Health check endpoint used in the cloud FastAPI servers to check the status periodically. This requires
Redis to be installed for it to work.
# TODO - Once the state store abstraction is in, check that too
"""
if not _is_redis_available():
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
return {"status": "failure", "reason": "Redis is not available"}
if not RedisQueue(name="ping", default_timeout=1).ping():
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
return {"status": "failure", "reason": "Redis is not available"}
"""Health check endpoint used in the cloud FastAPI servers to check the status periodically."""
# check the queue status only if running in cloud
if is_running_in_cloud():
queue_obj = QueuingSystem(CLOUD_QUEUE_TYPE).get_queue(queue_name="healthz")
if not queue_obj.is_running:
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
return {"status": "failure", "reason": "Redis is not available"}
x_lightning_session_uuid = TEST_SESSION_UUID
state = global_app_state_store.get_app_state(x_lightning_session_uuid)
global_app_state_store.set_served_state(x_lightning_session_uuid, state)

View File

@ -5,6 +5,12 @@ import lightning_cloud.env
import lightning_app
def get_lightning_cloud_url() -> str:
# DO NOT CHANGE!
return os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
SUPPORTED_PRIMITIVE_TYPES = (type(None), str, int, float, bool)
STATE_UPDATE_TIMEOUT = 0.001
STATE_ACCUMULATE_WAIT = 0.05
@ -17,11 +23,16 @@ FLOW_DURATION_SAMPLES = 5
APP_SERVER_HOST = os.getenv("LIGHTNING_APP_STATE_URL", "http://127.0.0.1")
APP_SERVER_PORT = 7501
APP_STATE_MAX_SIZE_BYTES = 1024 * 1024 # 1 MB
CLOUD_QUEUE_TYPE = os.getenv("LIGHTNING_CLOUD_QUEUE_TYPE", "redis")
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
REDIS_WARNING_QUEUE_SIZE = 1000
WARNING_QUEUE_SIZE = 1000
USER_ID = os.getenv("USER_ID", "1234")
FRONTEND_DIR = os.path.join(os.path.dirname(lightning_app.__file__), "ui")
PACKAGE_LIGHTNING = os.getenv("PACKAGE_LIGHTNING", None)
@ -38,8 +49,3 @@ LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "0")))
DEBUG: bool = lightning_cloud.env.DEBUG
def get_lightning_cloud_url() -> str:
# DO NOT CHANGE!
return os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")

View File

@ -11,8 +11,8 @@ from lightning_app.core.constants import (
REDIS_PASSWORD,
REDIS_PORT,
REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
REDIS_WARNING_QUEUE_SIZE,
STATE_UPDATE_TIMEOUT,
WARNING_QUEUE_SIZE,
)
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_redis_available, requires
@ -44,7 +44,7 @@ class QueuingSystem(Enum):
MULTIPROCESS = "multiprocess"
REDIS = "redis"
def _get_queue(self, queue_name: str) -> "BaseQueue":
def get_queue(self, queue_name: str) -> "BaseQueue":
if self == QueuingSystem.MULTIPROCESS:
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
elif self == QueuingSystem.REDIS:
@ -54,37 +54,37 @@ class QueuingSystem(Enum):
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{READINESS_QUEUE_CONSTANT}" if queue_id else READINESS_QUEUE_CONSTANT
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{DELTA_QUEUE_CONSTANT}" if queue_id else DELTA_QUEUE_CONSTANT
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_error_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{ERROR_QUEUE_CONSTANT}" if queue_id else ERROR_QUEUE_CONSTANT
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_has_server_started_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{HAS_SERVER_STARTED_CONSTANT}" if queue_id else HAS_SERVER_STARTED_CONSTANT
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_caller_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = (
f"{queue_id}_{CALLER_QUEUE_CONSTANT}_{work_name}" if queue_id else f"{CALLER_QUEUE_CONSTANT}_{work_name}"
)
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_api_state_publish_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{API_STATE_PUBLISH_QUEUE_CONSTANT}" if queue_id else API_STATE_PUBLISH_QUEUE_CONSTANT
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_api_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = f"{queue_id}_{API_DELTA_QUEUE_CONSTANT}" if queue_id else API_DELTA_QUEUE_CONSTANT
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_orchestrator_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = (
@ -92,7 +92,7 @@ class QueuingSystem(Enum):
if queue_id
else f"{ORCHESTRATOR_REQUEST_CONSTANT}_{work_name}"
)
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_orchestrator_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = (
@ -100,7 +100,7 @@ class QueuingSystem(Enum):
if queue_id
else f"{ORCHESTRATOR_RESPONSE_CONSTANT}_{work_name}"
)
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_orchestrator_copy_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = (
@ -108,7 +108,7 @@ class QueuingSystem(Enum):
if queue_id
else f"{ORCHESTRATOR_COPY_REQUEST_CONSTANT}_{work_name}"
)
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_orchestrator_copy_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = (
@ -116,13 +116,13 @@ class QueuingSystem(Enum):
if queue_id
else f"{ORCHESTRATOR_COPY_RESPONSE_CONSTANT}_{work_name}"
)
return self._get_queue(queue_name)
return self.get_queue(queue_name)
def get_work_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
queue_name = (
f"{queue_id}_{WORK_QUEUE_CONSTANT}_{work_name}" if queue_id else f"{WORK_QUEUE_CONSTANT}_{work_name}"
)
return self._get_queue(queue_name)
return self.get_queue(queue_name)
class BaseQueue(ABC):
@ -149,6 +149,14 @@ class BaseQueue(ABC):
"""
pass
@property
def is_running(self) -> bool:
"""Returns True if the queue is running, False otherwise.
Child classes should override this property and implement custom logic as required
"""
return True
class SingleProcessQueue(BaseQueue):
def __init__(self, name: str, default_timeout: float):
@ -213,20 +221,13 @@ class RedisQueue(BaseQueue):
self.default_timeout = default_timeout
self.redis = redis.Redis(host=host, port=port, password=password)
def ping(self):
"""Ping the redis server to see if it is alive."""
try:
return self.redis.ping()
except redis.exceptions.ConnectionError:
return False
def put(self, item: Any) -> None:
value = pickle.dumps(item)
queue_len = self.length()
if queue_len >= REDIS_WARNING_QUEUE_SIZE:
if queue_len >= WARNING_QUEUE_SIZE:
warnings.warn(
f"The Redis Queue {self.name} length is larger than the "
f"recommended length of {REDIS_WARNING_QUEUE_SIZE}. "
f"recommended length of {WARNING_QUEUE_SIZE}. "
f"Found {queue_len}. This might cause your application to crash, "
"please investigate this."
)
@ -282,3 +283,11 @@ class RedisQueue(BaseQueue):
"Please try running your app again. "
"If the issue persists, please contact support@lightning.ai"
)
@property
def is_running(self) -> bool:
"""Pinging the redis server to see if it is alive."""
try:
return self.redis.ping()
except redis.exceptions.ConnectionError:
return False

View File

@ -324,7 +324,8 @@ async def test_health_endpoint_success():
check_if_redis_running(), reason="this is testing the failure condition " "for which the redis should not run"
)
@pytest.mark.anyio
async def test_health_endpoint_failure():
async def test_health_endpoint_failure(monkeypatch):
monkeypatch.setenv("LIGHTNING_APP_STATE_URL", "http://someurl") # adding this to make is_running_in_cloud pass
async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
# will respond 503 if redis is not running
response = await client.get("/healthz")

View File

@ -51,19 +51,19 @@ def test_redis_queue():
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
def test_redis_ping_success():
def test_redis_health_check_success():
redis_queue = QueuingSystem.REDIS.get_readiness_queue()
assert redis_queue.ping()
assert redis_queue.is_running
redis_queue = RedisQueue(name="test_queue", default_timeout=1)
assert redis_queue.ping()
assert redis_queue.is_running
@pytest.mark.skipif(not _is_redis_available(), reason="redis is required for this test.")
@pytest.mark.skipif(check_if_redis_running(), reason="This is testing the failure case when redis is not running")
def test_redis_ping_failure():
def test_redis_health_check_failure():
redis_queue = RedisQueue(name="test_queue", default_timeout=1)
assert not redis_queue.ping()
assert not redis_queue.is_running
@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")