[App] Resolve PythonServer on M1 (#15949)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
thomas chaton 2022-12-08 12:31:42 +00:00 committed by GitHub
parent 36aecde695
commit 904323b5b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 27 additions and 56 deletions

View File

@ -1,2 +1,2 @@
streamlit>=1.3.1, <=1.11.1
streamlit>=1.0.0, <=1.15.2
panel>=0.12.7, <=0.13.1

View File

@ -56,17 +56,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed MPS error for multinode component (defaults to cpu on mps devices now as distributed operations are not supported by pytorch on mps) ([#15748](https://github.com/Ligtning-AI/lightning/pull/15748))
- Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))
- Fixed the `enable_spawn` method of the `WorkRunExecutor` ([#15812](https://github.com/Lightning-AI/lightning/pull/15812)
- Fixed Sigterm Handler causing thread lock which caused KeyboardInterrupt to hang ([#15881](https://github.com/Lightning-AI/lightning/pull/15881))
- Fixed a bug where using `L.app.structures` would cause multiple apps to be opened and fail with an error in the cloud ([#15911](https://github.com/Lightning-AI/lightning/pull/15911))
- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949))
## [1.8.3] - 2022-11-22

View File

@ -206,7 +206,6 @@ class _LoadBalancer(LightningWork):
return result
def run(self):
logger.info(f"servers: {self.servers}")
lock = asyncio.Lock()
@ -271,7 +270,6 @@ class _LoadBalancer(LightningWork):
async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
async with lock:
self.servers = servers
self._iter = cycle(self.servers)
@fastapi_app.post(self.endpoint, response_model=self._output_type)

View File

@ -22,6 +22,9 @@ class Code(TypedDict):
class TracerPythonScript(LightningWork):
_start_method = "spawn"
def on_before_run(self):
"""Called before the python script is executed."""

View File

@ -1,10 +1,8 @@
import abc
import os
from functools import partial
from types import ModuleType
from typing import Any, List, Optional
from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
from lightning_app.core.work import LightningWork
from lightning_app.utilities.imports import _is_gradio_available, requires
@ -36,15 +34,13 @@ class ServeGradio(LightningWork, abc.ABC):
title: Optional[str] = None
description: Optional[str] = None
_start_method = "spawn"
def __init__(self, *args, **kwargs):
requires("gradio")(super().__init__(*args, **kwargs))
assert self.inputs
assert self.outputs
self._model = None
# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)
@property
def model(self):

View File

@ -1,19 +1,18 @@
import abc
import base64
import os
import platform
from pathlib import Path
from typing import Any, Dict, Optional
import uvicorn
from fastapi import FastAPI
from lightning_utilities.core.imports import module_available
from lightning_utilities.core.imports import compare_version, module_available
from pydantic import BaseModel
from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver
logger = Logger(__name__)
@ -27,44 +26,19 @@ if not _is_torch_available():
__doctest_skip__ += ["PythonServer", "PythonServer.*"]
class _PyTorchSpawnRunExecutor(WorkRunExecutor):
def _get_device():
import operator
"""This Executor enables to move PyTorch tensors on GPU.
import torch
Without this executor, it would raise the following exception:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
"""
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
enable_start_observer: bool = False
local_rank = int(os.getenv("LOCAL_RANK", "0"))
def __call__(self, *args: Any, **kwargs: Any):
import torch
with self.enable_spawn():
queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
torch.multiprocessing.spawn(
self.dispatch_run,
args=(self.__class__, self.work, queue, args, kwargs),
nprocs=1,
)
@staticmethod
def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
if local_rank == 0:
if isinstance(delta_queue, dict):
delta_queue = cls.process_queue(delta_queue)
work._request_queue = cls.process_queue(work._request_queue)
work._response_queue = cls.process_queue(work._response_queue)
state_observer = WorkStateObserver(work, delta_queue=delta_queue)
state_observer.start()
_proxy_setattr(work, delta_queue, state_observer)
unwrap(work.run)(*args, **kwargs)
if local_rank == 0:
state_observer.join(0)
if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
return torch.device("mps", local_rank)
else:
return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
class _DefaultInputData(BaseModel):
@ -95,6 +69,9 @@ class Number(BaseModel):
class PythonServer(LightningWork, abc.ABC):
_start_method = "spawn"
@requires(["torch", "lightning_api_access"])
def __init__( # type: ignore
self,
@ -160,11 +137,6 @@ class PythonServer(LightningWork, abc.ABC):
self._input_type = input_type
self._output_type = output_type
# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)
def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
@ -210,13 +182,16 @@ class PythonServer(LightningWork, abc.ABC):
return out
def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
from torch import inference_mode
from torch import inference_mode, no_grad
input_type: type = self.configure_input_type()
output_type: type = self.configure_output_type()
device = _get_device()
context = no_grad if device.type == "mps" else inference_mode
def predict_fn(request: input_type): # type: ignore
with inference_mode():
with context():
return self.predict(request)
fastapi_app.post("/predict", response_model=output_type)(predict_fn)