[App] Introduce auto scaler (#15769)
* Exlucde __pycache__ in setuptools
* Add load balancer example
* wip
* Update example
* rename
* remove prints
* _LoadBalancer -> LoadBalancer
* AutoScaler(work)
* change var name
* remove locust
* Update docs
* include autoscaler in api ref
* docs typo
* docs typo
* docs typo
* docs typo
* remove unused loadtest
* remove unused device_type
* clean up
* clean up
* clean up
* Add docstring
* type
* env vars to args
* expose an API for users to override to customise autoscaling logic
* update example
* comment
* udpate var name
* fix scale mechanism and clean up
* Update exampl
* ignore mypy
* Add test file
* .
* update impl and update tests
* Update changlog
* .
* revert docs
* update test
* update state to keep calling 'flow.run()'
  Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>
* Add aiohttp to base requirements
* Update docs
  Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
* Use deserializer utility
* fake trigger
* wip: protect /system/* with basic auth
* read password at runtime
* Change env var name
* import torch as optional
* Don't overcreate works
* simplify imports
* Update example
* aiohttp
* Add work_args work_kwargs
* More docs
* remove FIXME
* Apply Jirka's suggestions
  Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* clean example device
* add comment on init threshold value
* bad merge
* nit: logging format
* {in,out}put_schema -> {in,out}put_type
* lowercase
* docs on seconds
* process_time -> processing_time
* Dont modify work state from flow
* Update tests
* worker_url -> endpoint
* fix exampl
* Fix default scale logic
* Fix default scale logic
* Fix num_pending_works
* Update num_pending_works
* Fix bug creating too many works
* Remove up/downscale_threshold args
* Update example
* Add typing
* Fix example in docstring
* Fix default scale logic
* Update src/lightning_app/components/auto_scaler.py
  Co-authored-by: Noha Alon <nohalon@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* rename method
* rename locvar
* Add todo
* docs ci
* docs ci
* asdfafsdasdf pls docs
* Apply suggestions from code review
  Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* .
* doc
* Update src/lightning_app/components/auto_scaler.py
  Co-authored-by: Noha Alon <nohalon@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
  This reverts commit 24983a0a5a.
* Revert "Update src/lightning_app/components/auto_scaler.py"
  This reverts commit 56ea78b45f.
* Remove redefinition
* Remove load balancer run blocker
* raise RuntimeError
* remove has_sent
* lower the default timeout_batching from 10 to 1
* remove debug
* update the default timeout_batching
* .
* tighten condition
* fix endpoint
* typo in runtimeerror cond
* async lock update severs
* add a test
* {in,out}put_type typing
* Update examples/app_server_with_auto_scaler/app.py
  Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
* Update .actions/setup_tools.py
  Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>

Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Noha Alon <nohalon@gmail.com>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Akihiro Nitta <aki@pop-os.localdomain>
Co-authored-by: thomas chaton <thomas@grid.ai>
parent 6aaac8b910
commit 64b19fb16f
@@ -37,6 +37,7 @@ ___________________

    ~training.LightningTrainerScript
    ~serve.gradio.ServeGradio
    ~serve.serve.ModelInferenceAPI
    ~auto_scaler.AutoScaler

----
@@ -0,0 +1,86 @@
from typing import Any, List

import torch
import torchvision
from pydantic import BaseModel

import lightning as L


class RequestModel(BaseModel):
    image: str  # bytecode


class BatchRequestModel(BaseModel):
    inputs: List[RequestModel]


class BatchResponse(BaseModel):
    outputs: List[Any]


class PyTorchServer(L.app.components.PythonServer):
    def __init__(self, *args, **kwargs):
        super().__init__(
            port=L.app.utilities.network.find_free_network_port(),
            input_type=BatchRequestModel,
            output_type=BatchResponse,
            cloud_compute=L.CloudCompute("gpu"),
        )

    def setup(self):
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._model = torchvision.models.resnet18(pretrained=True).to(self._device)

    def predict(self, requests: BatchRequestModel):
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        images = []
        for request in requests.inputs:
            image = L.app.components.serve.types.image.Image.deserialize(request.image)
            image = transforms(image).unsqueeze(0)
            images.append(image)
        images = torch.cat(images)
        images = images.to(self._device)
        predictions = self._model(images)
        results = predictions.argmax(1).cpu().numpy().tolist()
        return BatchResponse(outputs=[{"prediction": pred} for pred in results])


class MyAutoScaler(L.app.components.AutoScaler):
    def scale(self, replicas: int, metrics: dict) -> int:
        """The default scaling logic that users can override."""
        # scale out if the number of pending requests exceeds max batch size.
        max_requests_per_work = self.max_batch_size
        pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
            replicas + metrics["pending_works"]
        )
        if pending_requests_per_running_or_pending_work >= max_requests_per_work:
            return replicas + 1

        # scale in if the number of pending requests is below 25% of max_requests_per_work
        min_requests_per_work = max_requests_per_work * 0.25
        pending_requests_per_running_work = metrics["pending_requests"] / replicas
        if pending_requests_per_running_work < min_requests_per_work:
            return replicas - 1

        return replicas


app = L.LightningApp(
    MyAutoScaler(
        PyTorchServer,
        min_replicas=2,
        max_replicas=4,
        autoscale_interval=10,
        endpoint="predict",
        input_type=RequestModel,
        output_type=Any,
        timeout_batching=1,
    )
)
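For reference, a minimal client sketch for the example app above. It assumes the app is running and that the load balancer is reachable at a placeholder URL (use the URL printed by `lightning run app app.py`); the payload mirrors `RequestModel`, i.e. a single base64-encoded image string posted to the endpoint configured via `endpoint="predict"`.

# Hypothetical client for the example app above (not part of this diff).
# Assumptions: the app is running and http://127.0.0.1:7501 is a placeholder
# for the load balancer URL; "cat.png" is any local image file.
import base64

import requests

url = "http://127.0.0.1:7501/predict"  # matches endpoint="predict" in MyAutoScaler
with open("cat.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# The body mirrors RequestModel: one base64-encoded image string.
response = requests.post(url, json={"image": encoded}, timeout=60)
print(response.json())  # e.g. {"prediction": 207}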
@@ -80,6 +80,7 @@ module = [
    "lightning_app.components.serve.types.type",
    "lightning_app.components.serve.python_server",
    "lightning_app.components.training",
    "lightning_app.components.auto_scaler",
    "lightning_app.core.api",
    "lightning_app.core.app",
    "lightning_app.core.flow",
@@ -12,3 +12,4 @@ beautifulsoup4>=4.8.0, <4.11.2
inquirer>=2.10.0
psutil<5.9.4
click<=8.1.3
aiohttp>=3.8.0, <=3.8.3
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the CLI command `lightning run model` to launch a `LightningLite` accelerated script ([#15506](https://github.com/Lightning-AI/lightning/pull/15506))

- Added the CLI command `lightning delete app` to delete a lightning app on the cloud ([#15783](https://github.com/Lightning-AI/lightning/pull/15783))

- Show a message when `BuildConfig(requirements=[...])` is passed but a `requirements.txt` file is already present in the Work ([#15799](https://github.com/Lightning-AI/lightning/pull/15799))

@@ -17,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800))

- Added `AutoScaler` component ([#15769](https://github.com/Lightning-AI/lightning/pull/15769))

- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))

- Added private work attribute `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))
@@ -1,3 +1,4 @@
from lightning_app.components.auto_scaler import AutoScaler
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import (

@@ -15,6 +16,7 @@ from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner

__all__ = [
    "AutoScaler",
    "DatabaseClient",
    "Database",
    "PopenPythonScript",
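With this export in place, the component is importable directly from `lightning_app.components`. A minimal usage sketch follows; `MyPythonServer` is a placeholder for any serving work (e.g. a `PythonServer` subclass) and is not defined in this PR.

# Minimal usage sketch of the newly exported component.
# `my_servers.MyPythonServer` is a hypothetical module/class, not part of this diff.
from lightning_app import LightningApp
from lightning_app.components import AutoScaler

from my_servers import MyPythonServer  # hypothetical serving LightningWork

app = LightningApp(
    AutoScaler(
        MyPythonServer,
        min_replicas=1,
        max_replicas=4,
        autoscale_interval=10,
    )
)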
@@ -0,0 +1,568 @@
import asyncio
import logging
import os
import secrets
import time
import uuid
from base64 import b64encode
from itertools import cycle
from typing import Any, Dict, List, Tuple, Type

import aiohttp
import aiohttp.client_exceptions
import requests
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from starlette.status import HTTP_401_UNAUTHORIZED

from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.packaging.cloud_compute import CloudCompute

logger = Logger(__name__)
lock = asyncio.Lock()


def _raise_granular_exception(exception: Exception) -> None:
    """Handle an exception from hitting the model servers."""
    if not isinstance(exception, Exception):
        return

    if isinstance(exception, HTTPException):
        raise exception

    if isinstance(exception, aiohttp.client_exceptions.ServerDisconnectedError):
        raise HTTPException(500, "Worker Server Disconnected") from exception

    if isinstance(exception, aiohttp.client_exceptions.ClientError):
        logging.exception(exception)
        raise HTTPException(500, "Worker Server error") from exception

    if isinstance(exception, asyncio.TimeoutError):
        raise HTTPException(408, "Request timed out") from exception

    if isinstance(exception, Exception):
        if exception.args[0] == "Server disconnected":
            raise HTTPException(500, "Worker Server disconnected") from exception

        logging.exception(exception)
        raise HTTPException(500, exception.args[0]) from exception


class _SysInfo(BaseModel):
    num_workers: int
    servers: List[str]
    num_requests: int
    processing_time: int
    global_request_count: int


class _BatchRequestModel(BaseModel):
    inputs: List[Any]


def _create_fastapi(title: str) -> FastAPI:
    fastapi_app = FastAPI(title=title)

    fastapi_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    fastapi_app.global_request_count = 0
    fastapi_app.num_current_requests = 0
    fastapi_app.last_processing_time = 0

    @fastapi_app.get("/", include_in_schema=False)
    async def docs():
        return RedirectResponse("/docs")

    @fastapi_app.get("/num-requests")
    async def num_requests() -> int:
        return fastapi_app.num_current_requests

    return fastapi_app


class _LoadBalancer(LightningWork):
    r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediction API
    asynchronously using round-robin scheduling. It also performs auto batching of the incoming requests.

    The LoadBalancer exposes its system endpoints with basic HTTP authentication. To activate the authentication,
    provide a system password via an environment variable::

        lightning run app app.py --env AUTO_SCALER_AUTH_PASSWORD=PASSWORD

    Once enabled, the username and password must be sent in the request headers when calling the private endpoints.

    Args:
        input_type: Input type.
        output_type: Output type.
        endpoint: The REST API path.
        max_batch_size: The number of requests processed at once.
        timeout_batching: The number of seconds to wait before sending the requests to process in order to allow for
            requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
        timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
        timeout_inference_request: The number of seconds to wait for inference.
        \**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
    """

    def __init__(
        self,
        input_type: BaseModel,
        output_type: BaseModel,
        endpoint: str,
        max_batch_size: int = 8,
        # all timeout args are in seconds
        timeout_batching: int = 1,
        timeout_keep_alive: int = 60,
        timeout_inference_request: int = 60,
        **kwargs: Any,
    ) -> None:
        super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
        self._input_type = input_type
        self._output_type = output_type
        self._timeout_keep_alive = timeout_keep_alive
        self._timeout_inference_request = timeout_inference_request
        self.servers = []
        self.max_batch_size = max_batch_size
        self.timeout_batching = timeout_batching
        self._iter = None
        self._batch = []
        self._responses = {}  # {request_id: response}
        self._last_batch_sent = 0

        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint

        self.endpoint = endpoint

    async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
        server = next(self._iter)  # round-robin
        request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
        batch_request_data = _BatchRequestModel(inputs=request_data)

        try:
            async with aiohttp.ClientSession() as session:
                headers = {
                    "accept": "application/json",
                    "Content-Type": "application/json",
                }
                async with session.post(
                    f"{server}{self.endpoint}",
                    json=batch_request_data.dict(),
                    timeout=self._timeout_inference_request,
                    headers=headers,
                ) as response:
                    if response.status == 408:
                        raise HTTPException(408, "Request timed out")
                    response.raise_for_status()
                    response = await response.json()
                    outputs = response["outputs"]
                    if len(batch) != len(outputs):
                        raise RuntimeError(f"result has {len(outputs)} items but batch is {len(batch)}")
                    result = {request[0]: r for request, r in zip(batch, outputs)}
                    self._responses.update(result)
        except Exception as ex:
            result = {request[0]: ex for request in batch}
            self._responses.update(result)

    async def consumer(self):
        while True:
            await asyncio.sleep(0.05)

            batch = self._batch[: self.max_batch_size]
            while batch and (
                (len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching)
            ):
                asyncio.create_task(self.send_batch(batch))

                self._batch = self._batch[self.max_batch_size :]
                batch = self._batch[: self.max_batch_size]
                self._last_batch_sent = time.time()

    async def process_request(self, data: BaseModel):
        if not self.servers:
            raise HTTPException(500, "None of the workers are healthy!")

        request_id = uuid.uuid4().hex
        request: Tuple = (request_id, data)
        self._batch.append(request)

        while True:
            await asyncio.sleep(0.05)

            if request_id in self._responses:
                result = self._responses[request_id]
                del self._responses[request_id]
                _raise_granular_exception(result)
                return result

    def run(self):

        logger.info(f"servers: {self.servers}")

        self._iter = cycle(self.servers)
        self._last_batch_sent = time.time()

        fastapi_app = _create_fastapi("Load Balancer")
        security = HTTPBasic()
        fastapi_app.SEND_TASK = None

        @fastapi_app.middleware("http")
        async def current_request_counter(request: Request, call_next):
            if not request.scope["path"] == self.endpoint:
                return await call_next(request)
            fastapi_app.global_request_count += 1
            fastapi_app.num_current_requests += 1
            start_time = time.time()
            response = await call_next(request)
            processing_time = time.time() - start_time
            fastapi_app.last_processing_time = processing_time
            fastapi_app.num_current_requests -= 1
            return response

        @fastapi_app.on_event("startup")
        async def startup_event():
            fastapi_app.SEND_TASK = asyncio.create_task(self.consumer())

        @fastapi_app.on_event("shutdown")
        def shutdown_event():
            fastapi_app.SEND_TASK.cancel()

        def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(security)):
            AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "")
            if len(AUTO_SCALER_AUTH_PASSWORD) == 0:
                logger.warn(
                    "You have not set a password for private endpoints! To set a password, add "
                    "`--env AUTO_SCALER_AUTH_PASSWORD=<your pass>` to your lightning run command."
                )
            current_password_bytes = credentials.password.encode("utf8")
            is_correct_password = secrets.compare_digest(
                current_password_bytes, AUTO_SCALER_AUTH_PASSWORD.encode("utf8")
            )
            if not is_correct_password:
                raise HTTPException(
                    status_code=401,
                    detail="Incorrect password",
                    headers={"WWW-Authenticate": "Basic"},
                )
            return True

        @fastapi_app.get("/system/info", response_model=_SysInfo)
        async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)):
            return _SysInfo(
                num_workers=len(self.servers),
                servers=self.servers,
                num_requests=fastapi_app.num_current_requests,
                processing_time=fastapi_app.last_processing_time,
                global_request_count=fastapi_app.global_request_count,
            )

        @fastapi_app.put("/system/update-servers")
        async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
            async with lock:
                self.servers = servers
            self._iter = cycle(self.servers)

        @fastapi_app.post(self.endpoint, response_model=self._output_type)
        async def balance_api(inputs: self._input_type):
            return await self.process_request(inputs)

        uvicorn.run(
            fastapi_app,
            host=self.host,
            port=self.port,
            loop="uvloop",
            timeout_keep_alive=self._timeout_keep_alive,
            access_log=False,
        )

    def update_servers(self, server_works: List[LightningWork]):
        """Updates works that load balancer distributes requests to.

        AutoScaler uses this method to increase/decrease the number of works.
        """
        old_servers = set(self.servers)
        server_urls: List[str] = [server.url for server in server_works if server.url]
        new_servers = set(server_urls)

        if new_servers == old_servers:
            return

        if new_servers - old_servers:
            logger.info(f"servers added: {new_servers - old_servers}")

        deleted_servers = old_servers - new_servers
        if deleted_servers:
            logger.info(f"servers deleted: {deleted_servers}")

        self.send_request_to_update_servers(server_urls)

    def send_request_to_update_servers(self, servers: List[str]):
        AUTHORIZATION_TYPE = "Basic"
        USERNAME = "lightning"
        AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "")

        try:
            param = f"{USERNAME}:{AUTO_SCALER_AUTH_PASSWORD}".encode()
            data = b64encode(param).decode("utf-8")
        except (ValueError, UnicodeDecodeError) as e:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication credentials",
                headers={"WWW-Authenticate": "Basic"},
            ) from e
        headers = {
            "accept": "application/json",
            "username": USERNAME,
            "Authorization": AUTHORIZATION_TYPE + " " + data,
        }
        response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10)
        response.raise_for_status()


class AutoScaler(LightningFlow):
    """The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
    response to changes in the number of incoming requests. Incoming requests will be batched and balanced across
    the replicas.

    Args:
        min_replicas: The number of works to start when app initializes.
        max_replicas: The max number of works to spawn to handle the incoming requests.
        autoscale_interval: The number of seconds to wait before checking whether to upscale or downscale the works.
        endpoint: Provide the REST API path.
        max_batch_size: (auto-batching) The number of requests to process at once.
        timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process.
        input_type: Input type.
        output_type: Output type.

    .. testcode::

        import lightning as L

        # Example 1: Auto-scaling serve component out-of-the-box
        app = L.LightningApp(
            L.app.components.AutoScaler(
                MyPythonServer,
                min_replicas=1,
                max_replicas=8,
                autoscale_interval=10,
            )
        )

        # Example 2: Customizing the scaling logic
        class MyAutoScaler(L.app.components.AutoScaler):
            def scale(self, replicas: int, metrics: dict) -> int:
                pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
                    replicas + metrics["pending_works"]
                )

                # upscale
                max_requests_per_work = self.max_batch_size
                if pending_requests_per_running_or_pending_work >= max_requests_per_work:
                    return replicas + 1

                # downscale
                min_requests_per_work = max_requests_per_work * 0.25
                if pending_requests_per_running_or_pending_work < min_requests_per_work:
                    return replicas - 1

                return replicas


        app = L.LightningApp(
            MyAutoScaler(
                MyPythonServer,
                min_replicas=1,
                max_replicas=8,
                autoscale_interval=10,
                max_batch_size=8,  # for auto batching
                timeout_batching=1,  # for auto batching
            )
        )
    """

    def __init__(
        self,
        work_cls: Type[LightningWork],
        min_replicas: int = 1,
        max_replicas: int = 4,
        autoscale_interval: int = 10,
        max_batch_size: int = 8,
        timeout_batching: float = 1,
        endpoint: str = "api/predict",
        input_type: BaseModel = Dict,
        output_type: BaseModel = Dict,
        *work_args: Any,
        **work_kwargs: Any,
    ) -> None:
        super().__init__()
        self.num_replicas = 0
        self._work_registry = {}

        self._work_cls = work_cls
        self._work_args = work_args
        self._work_kwargs = work_kwargs

        self._input_type = input_type
        self._output_type = output_type
        self.autoscale_interval = autoscale_interval
        self.max_batch_size = max_batch_size

        if max_replicas < min_replicas:
            raise ValueError(
                f"`max_replicas={max_replicas}` must be greater than or equal to `min_replicas={min_replicas}`."
            )
        self.max_replicas = max_replicas
        self.min_replicas = min_replicas
        self._last_autoscale = time.time()
        self.fake_trigger = 0

        self.load_balancer = _LoadBalancer(
            input_type=self._input_type,
            output_type=self._output_type,
            endpoint=endpoint,
            max_batch_size=max_batch_size,
            timeout_batching=timeout_batching,
            cache_calls=True,
            parallel=True,
        )
        for _ in range(min_replicas):
            work = self.create_work()
            self.add_work(work)

    @property
    def workers(self) -> List[LightningWork]:
        return [self.get_work(i) for i in range(self.num_replicas)]

    def create_work(self) -> LightningWork:
        """Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
        # TODO: Remove `start_with_flow=False` for faster initialization on the cloud
        return self._work_cls(*self._work_args, **self._work_kwargs, start_with_flow=False)

    def add_work(self, work) -> str:
        """Adds a new LightningWork instance.

        Returns:
            The name of the new work attribute.
        """
        work_attribute = uuid.uuid4().hex
        work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}"
        setattr(self, work_attribute, work)
        self._work_registry[self.num_replicas] = work_attribute
        self.num_replicas += 1
        return work_attribute

    def remove_work(self, index: int) -> str:
        """Removes the ``index``-th LightningWork instance."""
        work_attribute = self._work_registry[index]
        del self._work_registry[index]
        work = getattr(self, work_attribute)
        work.stop()
        self.num_replicas -= 1
        return work_attribute

    def get_work(self, index: int) -> LightningWork:
        """Returns the ``LightningWork`` instance with the given index."""
        work_attribute = self._work_registry[index]
        work = getattr(self, work_attribute)
        return work

    def run(self):
        if not self.load_balancer.is_running:
            self.load_balancer.run()

        for work in self.workers:
            work.run()

        if self.load_balancer.url:
            self.fake_trigger += 1  # Note: change state to keep calling `run`.
            self.autoscale()

    def scale(self, replicas: int, metrics: dict) -> int:
        """The default scaling logic that users can override.

        Args:
            replicas: The number of running works.
            metrics: ``metrics['pending_requests']`` is the total number of requests that are currently pending.
                ``metrics['pending_works']`` is the number of pending works.

        Returns:
            The target number of running works. The value will be adjusted after this method runs
            so that it satisfies ``min_replicas<=replicas<=max_replicas``.
        """
        pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
            replicas + metrics["pending_works"]
        )

        # scale out if the number of pending requests exceeds max batch size.
        max_requests_per_work = self.max_batch_size
        if pending_requests_per_running_or_pending_work >= max_requests_per_work:
            return replicas + 1

        # scale in if the number of pending requests is below 25% of max_requests_per_work
        min_requests_per_work = max_requests_per_work * 0.25
        if pending_requests_per_running_or_pending_work < min_requests_per_work:
            return replicas - 1

        return replicas

    @property
    def num_pending_requests(self) -> int:
        """Fetches the number of pending requests via load balancer."""
        return int(requests.get(f"{self.load_balancer.url}/num-requests").json())

    @property
    def num_pending_works(self) -> int:
        """The number of pending works."""
        return sum(work.is_pending for work in self.workers)

    def autoscale(self) -> None:
        """Adjust the number of works based on the target number returned by ``self.scale``."""
        if time.time() - self._last_autoscale < self.autoscale_interval:
            return

        self.load_balancer.update_servers(self.workers)

        metrics = {
            "pending_requests": self.num_pending_requests,
            "pending_works": self.num_pending_works,
        }

        # ensure min_replicas <= num_replicas <= max_replicas
        num_target_workers = max(
            self.min_replicas,
            min(self.max_replicas, self.scale(self.num_replicas, metrics)),
        )

        # upscale
        num_workers_to_add = num_target_workers - self.num_replicas
        for _ in range(num_workers_to_add):
            logger.info(f"Upscaling from {self.num_replicas} to {self.num_replicas + 1}")
            work = self.create_work()
            new_work_id = self.add_work(work)
            logger.info(f"Work created: '{new_work_id}'")

        # downscale
        num_workers_to_remove = self.num_replicas - num_target_workers
        for _ in range(num_workers_to_remove):
            logger.info(f"Downscaling from {self.num_replicas} to {self.num_replicas - 1}")
            removed_work_id = self.remove_work(self.num_replicas - 1)
            logger.info(f"Work removed: '{removed_work_id}'")

        self.load_balancer.update_servers(self.workers)
        self._last_autoscale = time.time()

    def configure_layout(self):
        tabs = [{"name": "Swagger", "content": self.load_balancer.url}]
        return tabs
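A sketch of how a client could query the protected ``/system/*`` endpoints once ``AUTO_SCALER_AUTH_PASSWORD`` is set. The base URL below is a placeholder, and the username ``lightning`` mirrors the constant used in ``send_request_to_update_servers``.

# Hypothetical client for the protected system endpoints (not part of this diff).
# Assumptions: the app was started with
#   lightning run app app.py --env AUTO_SCALER_AUTH_PASSWORD=<your pass>
# and LOAD_BALANCER_URL is a placeholder for the load balancer's actual URL.
import os

import requests

LOAD_BALANCER_URL = "http://127.0.0.1:7501"  # placeholder
password = os.environ["AUTO_SCALER_AUTH_PASSWORD"]

# HTTP basic auth; the username "lightning" mirrors the server-side constant.
response = requests.get(f"{LOAD_BALANCER_URL}/system/info", auth=("lightning", password), timeout=10)
response.raise_for_status()
print(response.json())  # num_workers, servers, num_requests, processing_time, ...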
@@ -19,6 +19,8 @@ from lightning_app.utilities.packaging.build_config import BuildConfig

if _is_sqlmodel_available():
    from sqlmodel import SQLModel
else:
    SQLModel = object


# Required to avoid Uvicorn Server overriding Lightning App signal handlers.
@@ -0,0 +1,92 @@
import time
from unittest.mock import patch

import pytest

from lightning_app import LightningWork
from lightning_app.components import AutoScaler


class EmptyWork(LightningWork):
    def run(self):
        pass


class AutoScaler1(AutoScaler):
    def scale(self, replicas: int, metrics) -> int:
        # only upscale
        return replicas + 1


class AutoScaler2(AutoScaler):
    def scale(self, replicas: int, metrics) -> int:
        # only downscale
        return replicas - 1


def test_num_replicas_after_init():
    """Test the number of works is the same as min_replicas after initialization."""
    min_replicas = 2
    auto_scaler = AutoScaler(EmptyWork, min_replicas=min_replicas)
    assert auto_scaler.num_replicas == min_replicas


@patch("uvicorn.run")
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
def test_num_replicas_not_above_max_replicas(*_):
    """Test self.num_replicas doesn't exceed max_replicas."""
    max_replicas = 6
    auto_scaler = AutoScaler1(
        EmptyWork,
        min_replicas=1,
        max_replicas=max_replicas,
        autoscale_interval=0.001,
    )

    for _ in range(max_replicas + 1):
        time.sleep(0.002)
        auto_scaler.run()

    assert auto_scaler.num_replicas == max_replicas


@patch("uvicorn.run")
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
def test_num_replicas_not_below_min_replicas(*_):
    """Test self.num_replicas doesn't go below min_replicas."""
    min_replicas = 1
    auto_scaler = AutoScaler2(
        EmptyWork,
        min_replicas=min_replicas,
        max_replicas=4,
        autoscale_interval=0.001,
    )

    for _ in range(3):
        time.sleep(0.002)
        auto_scaler.run()

    assert auto_scaler.num_replicas == min_replicas


@pytest.mark.parametrize(
    "replicas, metrics, expected_replicas",
    [
        pytest.param(1, {"pending_requests": 1, "pending_works": 0}, 2, id="increase if no pending work"),
        pytest.param(1, {"pending_requests": 1, "pending_works": 1}, 1, id="dont increase if pending works"),
        pytest.param(8, {"pending_requests": 1, "pending_works": 0}, 7, id="reduce if requests < 25% capacity"),
        pytest.param(8, {"pending_requests": 2, "pending_works": 0}, 8, id="dont reduce if requests >= 25% capacity"),
    ],
)
def test_scale(replicas, metrics, expected_replicas):
    """Test `scale()`, the default scaling strategy."""
    auto_scaler = AutoScaler(
        EmptyWork,
        min_replicas=1,
        max_replicas=8,
        max_batch_size=1,
    )

    assert auto_scaler.scale(replicas, metrics) == expected_replicas
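The same pattern extends to custom scaling policies. A hedged sketch of an additional test follows; ``KeepScale`` is hypothetical and not part of this PR.

# Hypothetical extra test following the pattern of `test_scale` above.
# `KeepScale` is illustrative only and not included in this diff.
class KeepScale(AutoScaler):
    def scale(self, replicas: int, metrics) -> int:
        # never change the number of replicas, regardless of the metrics
        return replicas


def test_custom_scale_keeps_replicas():
    auto_scaler = KeepScale(EmptyWork, min_replicas=1, max_replicas=8)
    assert auto_scaler.scale(3, {"pending_requests": 100, "pending_works": 0}) == 3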