[App] Implement `ready` for components (#16129)
This commit is contained in:
parent
ae14f9d1b3
commit
711aec5397
|
@ -212,6 +212,8 @@ class _LoadBalancer(LightningWork):
|
|||
else:
|
||||
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
|
||||
|
||||
self.ready = False
|
||||
|
||||
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
|
||||
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
|
||||
batch_request_data = _BatchRequestModel(inputs=request_data)
|
||||
|
@ -410,6 +412,7 @@ class _LoadBalancer(LightningWork):
|
|||
)
|
||||
|
||||
logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
|
||||
self.ready = True
|
||||
|
||||
uvicorn.run(
|
||||
fastapi_app,
|
||||
|
@ -641,6 +644,10 @@ class AutoScaler(LightningFlow):
|
|||
def workers(self) -> List[LightningWork]:
|
||||
return [self.get_work(i) for i in range(self.num_replicas)]
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
return self.load_balancer.ready
|
||||
|
||||
def create_work(self) -> LightningWork:
|
||||
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
|
||||
cloud_compute = self._work_kwargs.get("cloud_compute", None)
|
||||
|
|
|
@ -42,6 +42,8 @@ class ServeGradio(LightningWork, abc.ABC):
|
|||
assert self.outputs
|
||||
self._model = None
|
||||
|
||||
self.ready = False
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model
|
||||
|
@ -62,6 +64,7 @@ class ServeGradio(LightningWork, abc.ABC):
|
|||
self._model = self.build_model()
|
||||
fn = partial(self.predict, *args, **kwargs)
|
||||
fn.__name__ = self.predict.__name__
|
||||
self.ready = True
|
||||
gradio.Interface(
|
||||
fn=fn,
|
||||
inputs=self.inputs,
|
||||
|
|
|
@ -193,6 +193,8 @@ class PythonServer(LightningWork, abc.ABC):
|
|||
self._input_type = input_type
|
||||
self._output_type = output_type
|
||||
|
||||
self.ready = False
|
||||
|
||||
def setup(self, *args, **kwargs) -> None:
|
||||
"""This method is called before the server starts. Override this if you need to download the model or
|
||||
initialize the weights, setting up pipelines etc.
|
||||
|
@ -300,6 +302,7 @@ class PythonServer(LightningWork, abc.ABC):
|
|||
fastapi_app = FastAPI()
|
||||
self._attach_predict_fn(fastapi_app)
|
||||
|
||||
self.ready = True
|
||||
logger.info(
|
||||
f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
|
||||
)
|
||||
|
|
|
@ -64,6 +64,8 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
|
|||
self.workers = workers
|
||||
self._model = None
|
||||
|
||||
self.ready = False
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model
|
||||
|
@ -108,9 +110,11 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
|
|||
"serve:fastapi_service",
|
||||
]
|
||||
process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
|
||||
self.ready = True
|
||||
process.wait()
|
||||
else:
|
||||
self._populate_app(fastapi_service)
|
||||
self.ready = True
|
||||
self._launch_server(fastapi_service)
|
||||
|
||||
def _populate_app(self, fastapi_service: FastAPI):
|
||||
|
|
|
@ -797,7 +797,7 @@ class _RootFlow(LightningFlow):
|
|||
@property
|
||||
def ready(self) -> bool:
|
||||
ready = getattr(self.work, "ready", None)
|
||||
if ready:
|
||||
if ready is not None:
|
||||
return ready
|
||||
return self.work.url != ""
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from deepdiff import DeepDiff, Delta
|
|||
|
||||
import lightning_app
|
||||
from lightning_app import CloudCompute, LightningApp
|
||||
from lightning_app.core.flow import LightningFlow
|
||||
from lightning_app.core.flow import _RootFlow, LightningFlow
|
||||
from lightning_app.core.work import LightningWork
|
||||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.storage import Path
|
||||
|
@ -868,10 +868,10 @@ def test_lightning_flow_flows_and_works():
|
|||
class WorkReady(LightningWork):
|
||||
def __init__(self):
|
||||
super().__init__(parallel=True)
|
||||
self.counter = 0
|
||||
self.ready = False
|
||||
|
||||
def run(self):
|
||||
self.counter += 1
|
||||
self.ready = True
|
||||
|
||||
|
||||
class FlowReady(LightningFlow):
|
||||
|
@ -890,7 +890,13 @@ class FlowReady(LightningFlow):
|
|||
self._exit()
|
||||
|
||||
|
||||
def test_flow_ready():
|
||||
class RootFlowReady(_RootFlow):
|
||||
def __init__(self):
|
||||
super().__init__(WorkReady())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("flow", [FlowReady, RootFlowReady])
|
||||
def test_flow_ready(flow):
|
||||
"""This test validates that the app status queue is populated correctly."""
|
||||
|
||||
mock_queue = _MockQueue("api_publish_state_queue")
|
||||
|
@ -910,7 +916,7 @@ def test_flow_ready():
|
|||
state["done"] = new_done
|
||||
return False
|
||||
|
||||
app = LightningApp(FlowReady())
|
||||
app = LightningApp(flow())
|
||||
app._run = partial(run_patch, method=app._run)
|
||||
app.run_once = partial(lagged_run_once, method=app.run_once)
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
|
|
Loading…
Reference in New Issue