[App] Implement `ready` for components (#16129)

This commit is contained in:
Ethan Harris 2022-12-20 14:42:45 +00:00 committed by GitHub
parent ae14f9d1b3
commit 711aec5397
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 29 additions and 6 deletions

View File

@ -212,6 +212,8 @@ class _LoadBalancer(LightningWork):
else:
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
self.ready = False
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
batch_request_data = _BatchRequestModel(inputs=request_data)
@ -410,6 +412,7 @@ class _LoadBalancer(LightningWork):
)
logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
self.ready = True
uvicorn.run(
fastapi_app,
@ -641,6 +644,10 @@ class AutoScaler(LightningFlow):
def workers(self) -> List[LightningWork]:
return [self.get_work(i) for i in range(self.num_replicas)]
@property
def ready(self) -> bool:
    """Report readiness by returning the load balancer's ``ready`` attribute."""
    return self.load_balancer.ready
def create_work(self) -> LightningWork:
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
cloud_compute = self._work_kwargs.get("cloud_compute", None)

View File

@ -42,6 +42,8 @@ class ServeGradio(LightningWork, abc.ABC):
assert self.outputs
self._model = None
self.ready = False
@property
def model(self):
    """Return the object currently stored in ``self._model``."""
    return self._model
@ -62,6 +64,7 @@ class ServeGradio(LightningWork, abc.ABC):
self._model = self.build_model()
fn = partial(self.predict, *args, **kwargs)
fn.__name__ = self.predict.__name__
self.ready = True
gradio.Interface(
fn=fn,
inputs=self.inputs,

View File

@ -193,6 +193,8 @@ class PythonServer(LightningWork, abc.ABC):
self._input_type = input_type
self._output_type = output_type
self.ready = False
def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
@ -300,6 +302,7 @@ class PythonServer(LightningWork, abc.ABC):
fastapi_app = FastAPI()
self._attach_predict_fn(fastapi_app)
self.ready = True
logger.info(
f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
)

View File

@ -64,6 +64,8 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
self.workers = workers
self._model = None
self.ready = False
@property
def model(self):
    """Return the object currently stored in ``self._model``."""
    return self._model
@ -108,9 +110,11 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
"serve:fastapi_service",
]
process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
self.ready = True
process.wait()
else:
self._populate_app(fastapi_service)
self.ready = True
self._launch_server(fastapi_service)
def _populate_app(self, fastapi_service: FastAPI):

View File

@ -797,7 +797,7 @@ class _RootFlow(LightningFlow):
@property
def ready(self) -> bool:
    """Report whether the wrapped work is ready.

    If ``self.work`` has a ``ready`` attribute whose value is not ``None``,
    that value is returned as-is. Otherwise readiness falls back to whether
    ``self.work.url`` is a non-empty string.
    """
    ready = getattr(self.work, "ready", None)
    # Compare against None rather than truthiness: an explicit ``ready = False``
    # must be reported as "not ready" instead of falling through to the URL check.
    if ready is not None:
        return ready
    return self.work.url != ""

View File

@ -12,7 +12,7 @@ from deepdiff import DeepDiff, Delta
import lightning_app
from lightning_app import CloudCompute, LightningApp
from lightning_app.core.flow import LightningFlow
from lightning_app.core.flow import _RootFlow, LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage import Path
@ -868,10 +868,10 @@ def test_lightning_flow_flows_and_works():
class WorkReady(LightningWork):
    """Parallel test work that flips its ``ready`` flag when ``run`` executes."""

    # NOTE(review): the ``counter`` and ``ready`` statements look like the
    # pre- and post-change versions of the same lines — confirm whether
    # ``counter`` is still needed by callers before dropping it.

    def __init__(self):
        super().__init__(parallel=True)
        self.counter = 0
        self.ready = False

    def run(self):
        """Increment ``counter`` and mark the work as ready."""
        self.counter += 1
        self.ready = True
class FlowReady(LightningFlow):
@ -890,7 +890,13 @@ class FlowReady(LightningFlow):
self._exit()
def test_flow_ready():
class RootFlowReady(_RootFlow):
    """A ``_RootFlow`` built around a fresh ``WorkReady`` instance."""

    def __init__(self):
        super().__init__(WorkReady())
@pytest.mark.parametrize("flow", [FlowReady, RootFlowReady])
def test_flow_ready(flow):
"""This test validates that the app status queue is populated correctly."""
mock_queue = _MockQueue("api_publish_state_queue")
@ -910,7 +916,7 @@ def test_flow_ready():
state["done"] = new_done
return False
app = LightningApp(FlowReady())
app = LightningApp(flow())
app._run = partial(run_patch, method=app._run)
app.run_once = partial(lagged_run_once, method=app.run_once)
MultiProcessRuntime(app, start_server=False).dispatch()