[App] Introduce auto scaler (#15769)

* Exclude __pycache__ in setuptools

* Add load balancer example

* wip

* Update example

* rename

* remove prints

* _LoadBalancer -> LoadBalancer

* AutoScaler(work)

* change var name

* remove locust

* Update docs

* include autoscaler in api ref

* docs typo

* docs typo

* docs typo

* docs typo

* remove unused loadtest

* remove unused device_type

* clean up

* clean up

* clean up

* Add docstring

* type

* env vars to args

* expose an API for users to override to customise autoscaling logic

* update example

* comment

* update var name

* fix scale mechanism and clean up

* Update example

* ignore mypy

* Add test file

* .

* update impl and update tests

* Update changelog

* .

* revert docs

* update test

* update state to keep calling 'flow.run()'

Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>

* Add aiohttp to base requirements

* Update docs

Co-authored-by: Luca Antiga <luca.antiga@gmail.com>

* Use deserializer utility

* fake trigger

* wip: protect /system/* with basic auth

* read password at runtime

* Change env var name

* import torch as optional

* Don't overcreate works

* simplify imports

* Update example

* aiohttp

* Add work_args work_kwargs

* More docs

* remove FIXME

* Apply Jirka's suggestions

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean example device

* add comment on init threshold value

* fix bad merge

* nit: logging format

* {in,out}put_schema -> {in,out}put_type

* lowercase

* docs on seconds

* process_time -> processing_time

* Don't modify work state from flow

* Update tests

* worker_url -> endpoint

* fix example

* Fix default scale logic

* Fix default scale logic

* Fix num_pending_works

* Update num_pending_works

* Fix bug creating too many works

* Remove up/downscale_threshold args

* Update example

* Add typing

* Fix example in docstring

* Fix default scale logic

* Update src/lightning_app/components/auto_scaler.py

Co-authored-by: Noha Alon <nohalon@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename method

* rename locvar

* Add todo

* docs ci

* docs ci

* retrigger docs ci

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ethanwharris@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* .

* doc

* Update src/lightning_app/components/auto_scaler.py

Co-authored-by: Noha Alon <nohalon@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit 24983a0a5a.

* Revert "Update src/lightning_app/components/auto_scaler.py"

This reverts commit 56ea78b45f.

* Remove redefinition

* Remove load balancer run blocker

* raise RuntimeError

* remove has_sent

* lower the default timeout_batching from 10 to 1

* remove debug

* update the default timeout_batching

* .

* tighten condition

* fix endpoint

* fix typo in RuntimeError condition

* async lock when updating servers

* add a test

* {in,out}put_type typing

* Update examples/app_server_with_auto_scaler/app.py

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* Update .actions/setup_tools.py

Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Noha Alon <nohalon@gmail.com>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Akihiro Nitta <aki@pop-os.localdomain>
Co-authored-by: thomas chaton <thomas@grid.ai>
commit 64b19fb16f (parent 6aaac8b910)
Akihiro Nitta, 2022-12-07 22:27:44 +09:00, committed by GitHub
9 changed files with 756 additions and 0 deletions


@@ -37,6 +37,7 @@ ___________________
~training.LightningTrainerScript
~serve.gradio.ServeGradio
~serve.serve.ModelInferenceAPI
~auto_scaler.AutoScaler
----

examples/app_server_with_auto_scaler/app.py

@@ -0,0 +1,86 @@
from typing import Any, List
import torch
import torchvision
from pydantic import BaseModel
import lightning as L
class RequestModel(BaseModel):
image: str  # base64-encoded image
class BatchRequestModel(BaseModel):
inputs: List[RequestModel]
class BatchResponse(BaseModel):
outputs: List[Any]
class PyTorchServer(L.app.components.PythonServer):
def __init__(self, *args, **kwargs):
super().__init__(
port=L.app.utilities.network.find_free_network_port(),
input_type=BatchRequestModel,
output_type=BatchResponse,
cloud_compute=L.CloudCompute("gpu"),
)
def setup(self):
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model = torchvision.models.resnet18(pretrained=True).to(self._device)
def predict(self, requests: BatchRequestModel):
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
images = []
for request in requests.inputs:
image = L.app.components.serve.types.image.Image.deserialize(request.image)
image = transforms(image).unsqueeze(0)
images.append(image)
images = torch.cat(images)
images = images.to(self._device)
predictions = self._model(images)
results = predictions.argmax(1).cpu().numpy().tolist()
return BatchResponse(outputs=[{"prediction": pred} for pred in results])
class MyAutoScaler(L.app.components.AutoScaler):
def scale(self, replicas: int, metrics: dict) -> int:
"""The default scaling logic that users can override."""
# scale out if the number of pending requests exceeds max batch size.
max_requests_per_work = self.max_batch_size
pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
replicas + metrics["pending_works"]
)
if pending_requests_per_running_or_pending_work >= max_requests_per_work:
return replicas + 1
# scale in if the number of pending requests is below 25% of max_requests_per_work
min_requests_per_work = max_requests_per_work * 0.25
pending_requests_per_running_work = metrics["pending_requests"] / replicas
if pending_requests_per_running_work < min_requests_per_work:
return replicas - 1
return replicas
app = L.LightningApp(
MyAutoScaler(
PyTorchServer,
min_replicas=2,
max_replicas=4,
autoscale_interval=10,
endpoint="predict",
input_type=RequestModel,
output_type=Any,
timeout_batching=1,
)
)
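
For reference, a minimal client sketch for this example, assuming the app runs locally and the load balancer is reachable at a placeholder URL (the real URL and port are printed when the app starts; the image path is also hypothetical):

import base64

import requests

URL = "http://127.0.0.1:7501"  # hypothetical load balancer URL; replace with the one the app prints

# RequestModel.image carries the image content as a base64-encoded string.
with open("image.jpg", "rb") as f:
    image = base64.b64encode(f.read()).decode("utf-8")

# The path comes from `endpoint="predict"` in the AutoScaler config above.
response = requests.post(f"{URL}/predict", json={"image": image})
print(response.json())  # e.g. {"prediction": 207}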


@@ -80,6 +80,7 @@ module = [
"lightning_app.components.serve.types.type",
"lightning_app.components.serve.python_server",
"lightning_app.components.training",
"lightning_app.components.auto_scaler",
"lightning_app.core.api",
"lightning_app.core.app",
"lightning_app.core.flow",


@@ -12,3 +12,4 @@ beautifulsoup4>=4.8.0, <4.11.2
inquirer>=2.10.0
psutil<5.9.4
click<=8.1.3
aiohttp>=3.8.0, <=3.8.3


@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added the CLI command `lightning run model` to launch a `LightningLite` accelerated script ([#15506](https://github.com/Lightning-AI/lightning/pull/15506))
- Added the CLI command `lightning delete app` to delete a lightning app on the cloud ([#15783](https://github.com/Lightning-AI/lightning/pull/15783))
- Show a message when `BuildConfig(requirements=[...])` is passed but a `requirements.txt` file is already present in the Work ([#15799](https://github.com/Lightning-AI/lightning/pull/15799))
@@ -17,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800))
- Added `AutoScaler` component ([#15769](https://github.com/Lightning-AI/lightning/pull/15769))
- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))
- Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))


@@ -1,3 +1,4 @@
from lightning_app.components.auto_scaler import AutoScaler
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import (
@@ -15,6 +16,7 @@ from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner
__all__ = [
"AutoScaler",
"DatabaseClient",
"Database",
"PopenPythonScript",

src/lightning_app/components/auto_scaler.py

@@ -0,0 +1,568 @@
import asyncio
import logging
import os
import secrets
import time
import uuid
from base64 import b64encode
from itertools import cycle
from typing import Any, Dict, List, Tuple, Type
import aiohttp
import aiohttp.client_exceptions
import requests
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from starlette.status import HTTP_401_UNAUTHORIZED
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
logger = Logger(__name__)
lock = asyncio.Lock()
def _raise_granular_exception(exception: Exception) -> None:
"""Handle an exception from hitting the model servers."""
if not isinstance(exception, Exception):
return
if isinstance(exception, HTTPException):
raise exception
if isinstance(exception, aiohttp.client_exceptions.ServerDisconnectedError):
raise HTTPException(500, "Worker Server Disconnected") from exception
if isinstance(exception, aiohttp.client_exceptions.ClientError):
logging.exception(exception)
raise HTTPException(500, "Worker Server error") from exception
if isinstance(exception, asyncio.TimeoutError):
raise HTTPException(408, "Request timed out") from exception
if isinstance(exception, Exception):
if exception.args[0] == "Server disconnected":
raise HTTPException(500, "Worker Server disconnected") from exception
logging.exception(exception)
raise HTTPException(500, exception.args[0]) from exception
class _SysInfo(BaseModel):
num_workers: int
servers: List[str]
num_requests: int
processing_time: int
global_request_count: int
class _BatchRequestModel(BaseModel):
inputs: List[Any]
def _create_fastapi(title: str) -> FastAPI:
fastapi_app = FastAPI(title=title)
fastapi_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
fastapi_app.global_request_count = 0
fastapi_app.num_current_requests = 0
fastapi_app.last_processing_time = 0
@fastapi_app.get("/", include_in_schema=False)
async def docs():
return RedirectResponse("/docs")
@fastapi_app.get("/num-requests")
async def num_requests() -> int:
return fastapi_app.num_current_requests
return fastapi_app
class _LoadBalancer(LightningWork):
r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API
asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
The LoadBalancer exposes system endpoints with a basic HTTP authentication, in order to activate the authentication
you need to provide a system password from environment variable::
lightning run app app.py --env AUTO_SCALER_AUTH_PASSWORD=PASSWORD
After enabling you will require to send username and password from the request header for the private endpoints.
Args:
input_type: Input type.
output_type: Output type.
endpoint: The REST API path.
max_batch_size: The number of requests processed at once.
timeout_batching: The number of seconds to wait before sending the requests to process in order to allow for
requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
timeout_inference_request: The number of seconds to wait for inference.
\**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
"""
def __init__(
self,
input_type: BaseModel,
output_type: BaseModel,
endpoint: str,
max_batch_size: int = 8,
# all timeout args are in seconds
timeout_batching: int = 1,
timeout_keep_alive: int = 60,
timeout_inference_request: int = 60,
**kwargs: Any,
) -> None:
super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
self._input_type = input_type
self._output_type = output_type
self._timeout_keep_alive = timeout_keep_alive
self._timeout_inference_request = timeout_inference_request
self.servers = []
self.max_batch_size = max_batch_size
self.timeout_batching = timeout_batching
self._iter = None
self._batch = []
self._responses = {} # {request_id: response}
self._last_batch_sent = 0
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
self.endpoint = endpoint
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
server = next(self._iter) # round-robin
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
batch_request_data = _BatchRequestModel(inputs=request_data)
try:
async with aiohttp.ClientSession() as session:
headers = {
"accept": "application/json",
"Content-Type": "application/json",
}
async with session.post(
f"{server}{self.endpoint}",
json=batch_request_data.dict(),
timeout=self._timeout_inference_request,
headers=headers,
) as response:
if response.status == 408:
raise HTTPException(408, "Request timed out")
response.raise_for_status()
response = await response.json()
outputs = response["outputs"]
if len(batch) != len(outputs):
raise RuntimeError(f"result has {len(outputs)} items but batch is {len(batch)}")
result = {request[0]: r for request, r in zip(batch, outputs)}
self._responses.update(result)
except Exception as ex:
result = {request[0]: ex for request in batch}
self._responses.update(result)
async def consumer(self):
while True:
await asyncio.sleep(0.05)
batch = self._batch[: self.max_batch_size]
while batch and (
(len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching)
):
asyncio.create_task(self.send_batch(batch))
self._batch = self._batch[self.max_batch_size :]
batch = self._batch[: self.max_batch_size]
self._last_batch_sent = time.time()
async def process_request(self, data: BaseModel):
if not self.servers:
raise HTTPException(500, "None of the workers are healthy!")
request_id = uuid.uuid4().hex
request: Tuple = (request_id, data)
self._batch.append(request)
while True:
await asyncio.sleep(0.05)
if request_id in self._responses:
result = self._responses[request_id]
del self._responses[request_id]
_raise_granular_exception(result)
return result
def run(self):
logger.info(f"servers: {self.servers}")
self._iter = cycle(self.servers)
self._last_batch_sent = time.time()
fastapi_app = _create_fastapi("Load Balancer")
security = HTTPBasic()
fastapi_app.SEND_TASK = None
@fastapi_app.middleware("http")
async def current_request_counter(request: Request, call_next):
if not request.scope["path"] == self.endpoint:
return await call_next(request)
fastapi_app.global_request_count += 1
fastapi_app.num_current_requests += 1
start_time = time.time()
response = await call_next(request)
processing_time = time.time() - start_time
fastapi_app.last_processing_time = processing_time
fastapi_app.num_current_requests -= 1
return response
@fastapi_app.on_event("startup")
async def startup_event():
fastapi_app.SEND_TASK = asyncio.create_task(self.consumer())
@fastapi_app.on_event("shutdown")
def shutdown_event():
fastapi_app.SEND_TASK.cancel()
def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(security)):
AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "")
if len(AUTO_SCALER_AUTH_PASSWORD) == 0:
logger.warn(
"You have not set a password for private endpoints! To set a password, add "
"`--env AUTO_SCALER_AUTH_PASSWORD=<your pass>` to your lightning run command."
)
current_password_bytes = credentials.password.encode("utf8")
is_correct_password = secrets.compare_digest(
current_password_bytes, AUTO_SCALER_AUTH_PASSWORD.encode("utf8")
)
if not is_correct_password:
raise HTTPException(
status_code=401,
detail="Incorrect password",
headers={"WWW-Authenticate": "Basic"},
)
return True
@fastapi_app.get("/system/info", response_model=_SysInfo)
async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)):
return _SysInfo(
num_workers=len(self.servers),
servers=self.servers,
num_requests=fastapi_app.num_current_requests,
processing_time=fastapi_app.last_processing_time,
global_request_count=fastapi_app.global_request_count,
)
@fastapi_app.put("/system/update-servers")
async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
async with lock:
self.servers = servers
self._iter = cycle(self.servers)
@fastapi_app.post(self.endpoint, response_model=self._output_type)
async def balance_api(inputs: self._input_type):
return await self.process_request(inputs)
uvicorn.run(
fastapi_app,
host=self.host,
port=self.port,
loop="uvloop",
timeout_keep_alive=self._timeout_keep_alive,
access_log=False,
)
def update_servers(self, server_works: List[LightningWork]):
"""Updates works that load balancer distributes requests to.
AutoScaler uses this method to increase/decrease the number of works.
"""
old_servers = set(self.servers)
server_urls: List[str] = [server.url for server in server_works if server.url]
new_servers = set(server_urls)
if new_servers == old_servers:
return
if new_servers - old_servers:
logger.info(f"servers added: {new_servers - old_servers}")
deleted_servers = old_servers - new_servers
if deleted_servers:
logger.info(f"servers deleted: {deleted_servers}")
self.send_request_to_update_servers(server_urls)
def send_request_to_update_servers(self, servers: List[str]):
AUTHORIZATION_TYPE = "Basic"
USERNAME = "lightning"
AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "")
try:
param = f"{USERNAME}:{AUTO_SCALER_AUTH_PASSWORD}".encode()
data = b64encode(param).decode("utf-8")
except (ValueError, UnicodeDecodeError) as e:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Basic"},
) from e
headers = {
"accept": "application/json",
"username": USERNAME,
"Authorization": AUTHORIZATION_TYPE + " " + data,
}
response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10)
response.raise_for_status()
class AutoScaler(LightningFlow):
"""The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
response to changes in the number of incoming requests. Incoming requests will be batched and balanced across
the replicas.
Args:
min_replicas: The number of works to start when the app initializes.
max_replicas: The max number of works to spawn to handle the incoming requests.
autoscale_interval: The number of seconds to wait before checking whether to upscale or downscale the works.
endpoint: Provide the REST API path.
max_batch_size: (auto-batching) The number of requests to process at once.
timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process.
input_type: Input type.
output_type: Output type.
.. testcode::
import lightning as L
# Example 1: Auto-scaling serve component out-of-the-box
app = L.LightningApp(
L.app.components.AutoScaler(
MyPythonServer,
min_replicas=1,
max_replicas=8,
autoscale_interval=10,
)
)
# Example 2: Customizing the scaling logic
class MyAutoScaler(L.app.components.AutoScaler):
def scale(self, replicas: int, metrics: dict) -> int:
pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
replicas + metrics["pending_works"]
)
# upscale
max_requests_per_work = self.max_batch_size
if pending_requests_per_running_or_pending_work >= max_requests_per_work:
return replicas + 1
# downscale
min_requests_per_work = max_requests_per_work * 0.25
if pending_requests_per_running_or_pending_work < min_requests_per_work:
return replicas - 1
return replicas
app = L.LightningApp(
MyAutoScaler(
MyPythonServer,
min_replicas=1,
max_replicas=8,
autoscale_interval=10,
max_batch_size=8, # for auto batching
timeout_batching=1, # for auto batching
)
)
"""
def __init__(
self,
work_cls: Type[LightningWork],
min_replicas: int = 1,
max_replicas: int = 4,
autoscale_interval: int = 10,
max_batch_size: int = 8,
timeout_batching: float = 1,
endpoint: str = "api/predict",
input_type: BaseModel = Dict,
output_type: BaseModel = Dict,
*work_args: Any,
**work_kwargs: Any,
) -> None:
super().__init__()
self.num_replicas = 0
self._work_registry = {}
self._work_cls = work_cls
self._work_args = work_args
self._work_kwargs = work_kwargs
self._input_type = input_type
self._output_type = output_type
self.autoscale_interval = autoscale_interval
self.max_batch_size = max_batch_size
if max_replicas < min_replicas:
raise ValueError(
f"`max_replicas={max_replicas}` must be less than or equal to `min_replicas={min_replicas}`."
)
self.max_replicas = max_replicas
self.min_replicas = min_replicas
self._last_autoscale = time.time()
self.fake_trigger = 0
self.load_balancer = _LoadBalancer(
input_type=self._input_type,
output_type=self._output_type,
endpoint=endpoint,
max_batch_size=max_batch_size,
timeout_batching=timeout_batching,
cache_calls=True,
parallel=True,
)
for _ in range(min_replicas):
work = self.create_work()
self.add_work(work)
@property
def workers(self) -> List[LightningWork]:
return [self.get_work(i) for i in range(self.num_replicas)]
def create_work(self) -> LightningWork:
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
# TODO: Remove `start_with_flow=False` for faster initialization on the cloud
return self._work_cls(*self._work_args, **self._work_kwargs, start_with_flow=False)
def add_work(self, work) -> str:
"""Adds a new LightningWork instance.
Returns:
The name of the new work attribute.
"""
work_attribute = uuid.uuid4().hex
work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}"
setattr(self, work_attribute, work)
self._work_registry[self.num_replicas] = work_attribute
self.num_replicas += 1
return work_attribute
def remove_work(self, index: int) -> str:
"""Removes the ``index`` th LightningWork instance."""
work_attribute = self._work_registry[index]
del self._work_registry[index]
work = getattr(self, work_attribute)
work.stop()
self.num_replicas -= 1
return work_attribute
def get_work(self, index: int) -> LightningWork:
"""Returns the ``LightningWork`` instance with the given index."""
work_attribute = self._work_registry[index]
work = getattr(self, work_attribute)
return work
def run(self):
if not self.load_balancer.is_running:
self.load_balancer.run()
for work in self.workers:
work.run()
if self.load_balancer.url:
self.fake_trigger += 1 # Note: change state to keep calling `run`.
self.autoscale()
def scale(self, replicas: int, metrics: dict) -> int:
"""The default scaling logic that users can override.
Args:
replicas: The number of running works.
metrics: ``metrics['pending_requests']`` is the total number of requests that are currently pending.
``metrics['pending_works']`` is the number of pending works.
Returns:
The target number of running works. The value will be adjusted after this method runs
so that it satisfies ``min_replicas<=replicas<=max_replicas``.
"""
pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
replicas + metrics["pending_works"]
)
# scale out if the number of pending requests exceeds max batch size.
max_requests_per_work = self.max_batch_size
if pending_requests_per_running_or_pending_work >= max_requests_per_work:
return replicas + 1
# scale in if the number of pending requests is below 25% of max_requests_per_work
min_requests_per_work = max_requests_per_work * 0.25
if pending_requests_per_running_or_pending_work < min_requests_per_work:
return replicas - 1
return replicas
@property
def num_pending_requests(self) -> int:
"""Fetches the number of pending requests via load balancer."""
return int(requests.get(f"{self.load_balancer.url}/num-requests").json())
@property
def num_pending_works(self) -> int:
"""The number of pending works."""
return sum(work.is_pending for work in self.workers)
def autoscale(self) -> None:
"""Adjust the number of works based on the target number returned by ``self.scale``."""
if time.time() - self._last_autoscale < self.autoscale_interval:
return
self.load_balancer.update_servers(self.workers)
metrics = {
"pending_requests": self.num_pending_requests,
"pending_works": self.num_pending_works,
}
# ensure min_replicas <= num_replicas <= max_replicas
num_target_workers = max(
self.min_replicas,
min(self.max_replicas, self.scale(self.num_replicas, metrics)),
)
# upscale
num_workers_to_add = num_target_workers - self.num_replicas
for _ in range(num_workers_to_add):
logger.info(f"Upscaling from {self.num_replicas} to {self.num_replicas + 1}")
work = self.create_work()
new_work_id = self.add_work(work)
logger.info(f"Work created: '{new_work_id}'")
# downscale
num_workers_to_remove = self.num_replicas - num_target_workers
for _ in range(num_workers_to_remove):
logger.info(f"Downscaling from {self.num_replicas} to {self.num_replicas - 1}")
removed_work_id = self.remove_work(self.num_replicas - 1)
logger.info(f"Work removed: '{removed_work_id}'")
self.load_balancer.update_servers(self.workers)
self._last_autoscale = time.time()
def configure_layout(self):
tabs = [{"name": "Swagger", "content": self.load_balancer.url}]
return tabs
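
For completeness, a minimal sketch of querying the monitoring endpoints defined above, assuming the app was started with `--env AUTO_SCALER_AUTH_PASSWORD=<your pass>` and that the placeholder URL points at the load balancer:

import os

import requests

URL = "http://127.0.0.1:7501"  # hypothetical load balancer URL

# /num-requests is public and returns the number of in-flight requests.
print(requests.get(f"{URL}/num-requests").json())

# /system/* endpoints use HTTP basic auth. Only the password is verified
# (against AUTO_SCALER_AUTH_PASSWORD); "lightning" matches the username the
# component itself sends in `send_request_to_update_servers`.
auth = ("lightning", os.environ.get("AUTO_SCALER_AUTH_PASSWORD", ""))
print(requests.get(f"{URL}/system/info", auth=auth).json())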


@@ -19,6 +19,8 @@ from lightning_app.utilities.packaging.build_config import BuildConfig
if _is_sqlmodel_available():
from sqlmodel import SQLModel
else:
SQLModel = object
# Required to avoid Uvicorn Server overriding Lightning App signal handlers.


@@ -0,0 +1,92 @@
import time
from unittest.mock import patch
import pytest
from lightning_app import LightningWork
from lightning_app.components import AutoScaler
class EmptyWork(LightningWork):
def run(self):
pass
class AutoScaler1(AutoScaler):
def scale(self, replicas: int, metrics) -> int:
# only upscale
return replicas + 1
class AutoScaler2(AutoScaler):
def scale(self, replicas: int, metrics) -> int:
# only downscale
return replicas - 1
def test_num_replicas_after_init():
"""Test the number of works is the same as min_replicas after initialization."""
min_replicas = 2
auto_scaler = AutoScaler(EmptyWork, min_replicas=min_replicas)
assert auto_scaler.num_replicas == min_replicas
@patch("uvicorn.run")
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
def test_num_replicas_not_above_max_replicas(*_):
"""Test self.num_replicas doesn't exceed max_replicas."""
max_replicas = 6
auto_scaler = AutoScaler1(
EmptyWork,
min_replicas=1,
max_replicas=max_replicas,
autoscale_interval=0.001,
)
for _ in range(max_replicas + 1):
time.sleep(0.002)
auto_scaler.run()
assert auto_scaler.num_replicas == max_replicas
@patch("uvicorn.run")
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
def test_num_replicas_not_below_min_replicas(*_):
"""Test self.num_replicas doesn't go below min_replicas."""
min_replicas = 1
auto_scaler = AutoScaler2(
EmptyWork,
min_replicas=min_replicas,
max_replicas=4,
autoscale_interval=0.001,
)
for _ in range(3):
time.sleep(0.002)
auto_scaler.run()
assert auto_scaler.num_replicas == min_replicas
@pytest.mark.parametrize(
"replicas, metrics, expected_replicas",
[
pytest.param(1, {"pending_requests": 1, "pending_works": 0}, 2, id="increase if no pending work"),
pytest.param(1, {"pending_requests": 1, "pending_works": 1}, 1, id="dont increase if pending works"),
pytest.param(8, {"pending_requests": 1, "pending_works": 0}, 7, id="reduce if requests < 25% capacity"),
pytest.param(8, {"pending_requests": 2, "pending_works": 0}, 8, id="dont reduce if requests >= 25% capacity"),
],
)
def test_scale(replicas, metrics, expected_replicas):
"""Test `scale()`, the default scaling strategy."""
auto_scaler = AutoScaler(
EmptyWork,
min_replicas=1,
max_replicas=8,
max_batch_size=1,
)
assert auto_scaler.scale(replicas, metrics) == expected_replicas
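
For clarity, the arithmetic behind the parametrized cases above: with `max_batch_size=1`, the default `scale()` uses `max_requests_per_work = 1` and `min_requests_per_work = 0.25`, both compared against `pending_requests / (replicas + pending_works)`:

# 1 / (1 + 0) = 1.0   >= 1    -> upscale:   1 -> 2
# 1 / (1 + 1) = 0.5   <  1    -> unchanged: 1 (0.5 is not below 0.25 either)
# 1 / (8 + 0) = 0.125 <  0.25 -> downscale: 8 -> 7
# 2 / (8 + 0) = 0.25  >= 0.25 -> unchanged: 8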