diff --git a/src/lightning_app/components/serve/auto_scaler.py b/src/lightning_app/components/serve/auto_scaler.py
index 0c2b322498..4ba662603e 100644
--- a/src/lightning_app/components/serve/auto_scaler.py
+++ b/src/lightning_app/components/serve/auto_scaler.py
@@ -212,6 +212,8 @@ class _LoadBalancer(LightningWork):
         else:
             raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
 
+        self.ready = False
+
     async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
         request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
         batch_request_data = _BatchRequestModel(inputs=request_data)
@@ -410,6 +412,7 @@ class _LoadBalancer(LightningWork):
         )
         logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
 
+        self.ready = True
 
         uvicorn.run(
             fastapi_app,
@@ -641,6 +644,10 @@ class AutoScaler(LightningFlow):
     def workers(self) -> List[LightningWork]:
         return [self.get_work(i) for i in range(self.num_replicas)]
 
+    @property
+    def ready(self) -> bool:
+        return self.load_balancer.ready
+
     def create_work(self) -> LightningWork:
         """Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
         cloud_compute = self._work_kwargs.get("cloud_compute", None)
diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py
index 7c07129d39..29af372e03 100644
--- a/src/lightning_app/components/serve/gradio.py
+++ b/src/lightning_app/components/serve/gradio.py
@@ -42,6 +42,8 @@ class ServeGradio(LightningWork, abc.ABC):
         assert self.outputs
         self._model = None
 
+        self.ready = False
+
     @property
     def model(self):
         return self._model
@@ -62,6 +64,7 @@
         self._model = self.build_model()
         fn = partial(self.predict, *args, **kwargs)
         fn.__name__ = self.predict.__name__
+        self.ready = True
         gradio.Interface(
             fn=fn,
             inputs=self.inputs,
diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py
index 6b7322c58c..55760bd06e 100644
--- a/src/lightning_app/components/serve/python_server.py
+++ b/src/lightning_app/components/serve/python_server.py
@@ -193,6 +193,8 @@ class PythonServer(LightningWork, abc.ABC):
         self._input_type = input_type
         self._output_type = output_type
 
+        self.ready = False
+
     def setup(self, *args, **kwargs) -> None:
         """This method is called before the server starts. Override this if you need to download the model or
         initialize the weights, setting up pipelines etc.
@@ -300,6 +302,7 @@
         fastapi_app = FastAPI()
         self._attach_predict_fn(fastapi_app)
 
+        self.ready = True
         logger.info(
             f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
         )
diff --git a/src/lightning_app/components/serve/serve.py b/src/lightning_app/components/serve/serve.py
index 8b6f35364c..50caca0079 100644
--- a/src/lightning_app/components/serve/serve.py
+++ b/src/lightning_app/components/serve/serve.py
@@ -64,6 +64,8 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
         self.workers = workers
         self._model = None
 
+        self.ready = False
+
     @property
     def model(self):
         return self._model
@@ -108,9 +110,11 @@
                 "serve:fastapi_service",
             ]
             process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
+            self.ready = True
             process.wait()
         else:
             self._populate_app(fastapi_service)
+            self.ready = True
             self._launch_server(fastapi_service)
 
     def _populate_app(self, fastapi_service: FastAPI):
diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py
index 5987425713..ee2931a6af 100644
--- a/src/lightning_app/core/flow.py
+++ b/src/lightning_app/core/flow.py
@@ -797,7 +797,7 @@ class _RootFlow(LightningFlow):
     @property
     def ready(self) -> bool:
         ready = getattr(self.work, "ready", None)
-        if ready:
+        if ready is not None:
             return ready
         return self.work.url != ""
 
diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py
index ac671299bc..7e547d75c5 100644
--- a/tests/tests_app/core/test_lightning_flow.py
+++ b/tests/tests_app/core/test_lightning_flow.py
@@ -12,7 +12,7 @@ from deepdiff import DeepDiff, Delta
 
 import lightning_app
 from lightning_app import CloudCompute, LightningApp
-from lightning_app.core.flow import LightningFlow
+from lightning_app.core.flow import _RootFlow, LightningFlow
 from lightning_app.core.work import LightningWork
 from lightning_app.runners import MultiProcessRuntime
 from lightning_app.storage import Path
@@ -868,10 +868,10 @@ def test_lightning_flow_flows_and_works():
 class WorkReady(LightningWork):
     def __init__(self):
         super().__init__(parallel=True)
-        self.counter = 0
+        self.ready = False
 
     def run(self):
-        self.counter += 1
+        self.ready = True
 
 
 class FlowReady(LightningFlow):
@@ -890,7 +890,13 @@
         self._exit()
 
 
-def test_flow_ready():
+class RootFlowReady(_RootFlow):
+    def __init__(self):
+        super().__init__(WorkReady())
+
+
+@pytest.mark.parametrize("flow", [FlowReady, RootFlowReady])
+def test_flow_ready(flow):
     """This test validates that the app status queue is populated correctly."""
 
     mock_queue = _MockQueue("api_publish_state_queue")
@@ -910,7 +916,7 @@
         state["done"] = new_done
         return False
 
-    app = LightningApp(FlowReady())
+    app = LightningApp(flow())
     app._run = partial(run_patch, method=app._run)
     app.run_once = partial(lagged_run_once, method=app.run_once)
     MultiProcessRuntime(app, start_server=False).dispatch()
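
A minimal sketch (not part of the patch) of the readiness protocol this change settles on: a work opts in by declaring a boolean `ready` attribute, and `_RootFlow.ready` now returns that flag whenever it exists, instead of falling back to the URL check while the flag is still False. `DemoWork` and `DemoRootFlow` are hypothetical names for illustration; the import paths and constructor signatures come from the patch itself.

from lightning_app import LightningWork
from lightning_app.core.flow import _RootFlow


class DemoWork(LightningWork):  # hypothetical work, mirrors the serve components above
    def __init__(self):
        super().__init__(parallel=True)
        self.ready = False  # explicit opt-in readiness flag

    def run(self):
        self.ready = True  # flipped once the work is actually serving


class DemoRootFlow(_RootFlow):  # hypothetical root flow wrapping the work
    def __init__(self):
        super().__init__(DemoWork())


# With the old check (`if ready:`), an explicit `ready = False` was
# indistinguishable from a missing attribute, so the property fell through to
# `self.work.url != ""` and could report ready before run() had finished.
# `if ready is not None:` honors the explicit False.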