[App] Remove `SingleProcessRuntime` (#15933)
* Remove SingleProcessRuntime * Remove unused queues * Docs
This commit is contained in:
parent
06163e6db5
commit
e250dfe2b3
|
@ -18,5 +18,4 @@ ______________
|
|||
:template: classtemplate.rst
|
||||
|
||||
~cloud.CloudRuntime
|
||||
~singleprocess.SingleProcessRuntime
|
||||
~multiprocess.MultiProcessRuntime
|
||||
|
|
|
@ -89,7 +89,6 @@ _______
|
|||
:template: classtemplate_no_index.rst
|
||||
|
||||
~cloud.CloudRuntime
|
||||
~singleprocess.SingleProcessRuntime
|
||||
~multiprocess.MultiProcessRuntime
|
||||
|
||||
----
|
||||
|
|
|
@ -120,7 +120,6 @@ We provide ``application_testing`` as a helper funtion to get your application u
|
|||
os.path.join(_PROJECT_ROOT, "examples/app_v0/app.py"),
|
||||
"--blocking",
|
||||
"False",
|
||||
"--multiprocess",
|
||||
"--open-ui",
|
||||
"False",
|
||||
]
|
||||
|
@ -129,9 +128,7 @@ First in the list for ``command_line`` is the location of your script. It is an
|
|||
|
||||
Next there are a couple of options you can leverage:
|
||||
|
||||
|
||||
* ``blocking`` - Blocking is an app status that says "Do not run until I click run in the UI". For our integration test, since we are not using the UI, we are setting this to "False".
|
||||
* ``multiprocess/singleprocess`` - This is the runtime your app is expected to run under.
|
||||
* ``open-ui`` - We set this to false since this is the routine that opens a browser for your local execution.
|
||||
|
||||
Once you have your commandline ready, you will then be able to kick off the test and gather results:
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import logging
|
||||
|
||||
from lightning_app import LightningApp, LightningFlow
|
||||
from lightning_app.frontend import StreamlitFrontend
|
||||
from lightning_app.utilities.state import AppState
|
||||
from lightning.app import LightningApp, LightningFlow
|
||||
from lightning.app.frontend import StreamlitFrontend
|
||||
from lightning.app.utilities.state import AppState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -45,4 +45,4 @@ class HelloWorld(LightningFlow):
|
|||
return [{"name": "StreamLitUI", "content": self.streamlit_ui}]
|
||||
|
||||
|
||||
app = LightningApp(HelloWorld(), log_level="debug")
|
||||
app = LightningApp(HelloWorld())
|
||||
|
|
|
@ -102,7 +102,6 @@ module = [
|
|||
"lightning_app.runners.cloud",
|
||||
"lightning_app.runners.multiprocess",
|
||||
"lightning_app.runners.runtime",
|
||||
"lightning_app.runners.singleprocess",
|
||||
"lightning_app.source_code.copytree",
|
||||
"lightning_app.source_code.hashing",
|
||||
"lightning_app.source_code.local",
|
||||
|
|
|
@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Removed
|
||||
|
||||
-
|
||||
- Removed the `SingleProcessRuntime` ([#15933](https://github.com/Lightning-AI/lightning/pull/15933))
|
||||
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -25,7 +25,6 @@ from lightning_app.utilities.app_helpers import Logger
|
|||
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
|
||||
|
||||
logger = Logger(__name__)
|
||||
lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _raise_granular_exception(exception: Exception) -> None:
|
||||
|
@ -209,6 +208,7 @@ class _LoadBalancer(LightningWork):
|
|||
def run(self):
|
||||
|
||||
logger.info(f"servers: {self.servers}")
|
||||
lock = asyncio.Lock()
|
||||
|
||||
self._iter = cycle(self.servers)
|
||||
self._last_batch_sent = time.time()
|
||||
|
|
|
@ -21,7 +21,7 @@ from lightning_app.core.constants import (
|
|||
FRONTEND_DIR,
|
||||
STATE_ACCUMULATE_WAIT,
|
||||
)
|
||||
from lightning_app.core.queues import BaseQueue, SingleProcessQueue
|
||||
from lightning_app.core.queues import BaseQueue
|
||||
from lightning_app.core.work import LightningWork
|
||||
from lightning_app.frontend import Frontend
|
||||
from lightning_app.storage import Drive, Path, Payload
|
||||
|
@ -549,8 +549,6 @@ class LightningApp:
|
|||
def _should_snapshot(self) -> bool:
|
||||
if len(self.works) == 0:
|
||||
return True
|
||||
elif isinstance(self.delta_queue, SingleProcessQueue):
|
||||
return True
|
||||
elif self._has_updated:
|
||||
work_finished_status = self._collect_work_finish_status()
|
||||
if work_finished_status:
|
||||
|
|
|
@ -49,7 +49,6 @@ FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT = "FLOW_TO_WORKS_DELTA_QUEUE"
|
|||
|
||||
|
||||
class QueuingSystem(Enum):
|
||||
SINGLEPROCESS = "singleprocess"
|
||||
MULTIPROCESS = "multiprocess"
|
||||
REDIS = "redis"
|
||||
HTTP = "http"
|
||||
|
@ -59,10 +58,8 @@ class QueuingSystem(Enum):
|
|||
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
elif self == QueuingSystem.REDIS:
|
||||
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
|
||||
elif self == QueuingSystem.HTTP:
|
||||
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
else:
|
||||
return SingleProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
|
||||
|
||||
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
|
||||
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
|
||||
|
@ -179,21 +176,6 @@ class BaseQueue(ABC):
|
|||
return True
|
||||
|
||||
|
||||
class SingleProcessQueue(BaseQueue):
|
||||
def __init__(self, name: str, default_timeout: float):
|
||||
self.name = name
|
||||
self.default_timeout = default_timeout
|
||||
self.queue = queue.Queue()
|
||||
|
||||
def put(self, item):
|
||||
self.queue.put(item)
|
||||
|
||||
def get(self, timeout: int = None):
|
||||
if timeout == 0:
|
||||
timeout = self.default_timeout
|
||||
return self.queue.get(timeout=timeout, block=(timeout is None))
|
||||
|
||||
|
||||
class MultiProcessQueue(BaseQueue):
|
||||
def __init__(self, name: str, default_timeout: float):
|
||||
self.name = name
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from lightning_app.runners.cloud import CloudRuntime
|
||||
from lightning_app.runners.multiprocess import MultiProcessRuntime
|
||||
from lightning_app.runners.runtime import dispatch, Runtime
|
||||
from lightning_app.runners.singleprocess import SingleProcessRuntime
|
||||
from lightning_app.utilities.app_commands import run_app_commands
|
||||
from lightning_app.utilities.load_app import load_app_from_file
|
||||
|
||||
|
@ -11,6 +10,5 @@ __all__ = [
|
|||
"run_app_commands",
|
||||
"Runtime",
|
||||
"MultiProcessRuntime",
|
||||
"SingleProcessRuntime",
|
||||
"CloudRuntime",
|
||||
]
|
||||
|
|
|
@ -1,21 +1,18 @@
|
|||
from enum import Enum
|
||||
from typing import Type, TYPE_CHECKING
|
||||
|
||||
from lightning_app.runners import CloudRuntime, MultiProcessRuntime, SingleProcessRuntime
|
||||
from lightning_app.runners import CloudRuntime, MultiProcessRuntime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lightning_app.runners.runtime import Runtime
|
||||
|
||||
|
||||
class RuntimeType(Enum):
|
||||
SINGLEPROCESS = "singleprocess"
|
||||
MULTIPROCESS = "multiprocess"
|
||||
CLOUD = "cloud"
|
||||
|
||||
def get_runtime(self) -> Type["Runtime"]:
|
||||
if self == RuntimeType.SINGLEPROCESS:
|
||||
return SingleProcessRuntime
|
||||
elif self == RuntimeType.MULTIPROCESS:
|
||||
if self == RuntimeType.MULTIPROCESS:
|
||||
return MultiProcessRuntime
|
||||
elif self == RuntimeType.CLOUD:
|
||||
return CloudRuntime
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
import multiprocessing as mp
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
from lightning_app.core.api import start_server
|
||||
from lightning_app.core.queues import QueuingSystem
|
||||
from lightning_app.runners.runtime import Runtime
|
||||
from lightning_app.utilities.app_helpers import _is_headless
|
||||
from lightning_app.utilities.load_app import extract_metadata_from_app
|
||||
|
||||
|
||||
class SingleProcessRuntime(Runtime):
|
||||
"""Runtime to launch the LightningApp into a single process."""
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
def dispatch(self, *args, open_ui: bool = True, **kwargs: Any):
|
||||
"""Method to dispatch and run the LightningApp."""
|
||||
queue = QueuingSystem.SINGLEPROCESS
|
||||
|
||||
self.app.delta_queue = queue.get_delta_queue()
|
||||
self.app.state_update_queue = queue.get_caller_queue(work_name="single_worker")
|
||||
self.app.error_queue = queue.get_error_queue()
|
||||
|
||||
if self.start_server:
|
||||
self.app.should_publish_changes_to_api = True
|
||||
self.app.api_publish_state_queue = QueuingSystem.MULTIPROCESS.get_api_state_publish_queue()
|
||||
self.app.api_delta_queue = QueuingSystem.MULTIPROCESS.get_api_delta_queue()
|
||||
has_started_queue = QueuingSystem.MULTIPROCESS.get_has_server_started_queue()
|
||||
kwargs = dict(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
api_publish_state_queue=self.app.api_publish_state_queue,
|
||||
api_delta_queue=self.app.api_delta_queue,
|
||||
has_started_queue=has_started_queue,
|
||||
spec=extract_metadata_from_app(self.app),
|
||||
root_path=self.app.root_path,
|
||||
)
|
||||
server_proc = mp.Process(target=start_server, kwargs=kwargs)
|
||||
self.processes["server"] = server_proc
|
||||
server_proc.start()
|
||||
|
||||
# wait for server to be ready.
|
||||
has_started_queue.get()
|
||||
|
||||
if open_ui and not _is_headless(self.app):
|
||||
click.launch(self._get_app_url())
|
||||
|
||||
try:
|
||||
self.app._run()
|
||||
except KeyboardInterrupt:
|
||||
self.terminate()
|
||||
raise
|
||||
finally:
|
||||
self.terminate()
|
||||
|
||||
@staticmethod
|
||||
def _get_app_url() -> str:
|
||||
return os.getenv("APP_SERVER_HOST", "http://127.0.0.1:7501/view")
|
|
@ -130,13 +130,6 @@ class InMemoryStateStore(StateStore):
|
|||
self.store[k].session_id = v
|
||||
|
||||
|
||||
class DistributedMode(enum.Enum):
|
||||
SINGLEPROCESS = enum.auto()
|
||||
MULTIPROCESS = enum.auto()
|
||||
CONTAINER = enum.auto()
|
||||
GRID = enum.auto()
|
||||
|
||||
|
||||
class _LightningAppRef:
|
||||
_app_instance: Optional["LightningApp"] = None
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import enum
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from time import sleep
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from deepdiff import DeepDiff
|
||||
|
@ -149,16 +150,26 @@ class AppState:
|
|||
return
|
||||
app_url = f"{self._url}/api/v1/state"
|
||||
headers = headers_for(self._plugin.get_context()) if self._plugin else {}
|
||||
try:
|
||||
response = self._session.get(app_url, headers=headers, timeout=1)
|
||||
except ConnectionError as e:
|
||||
raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from e
|
||||
|
||||
self._authorized = response.status_code
|
||||
if self._authorized != 200:
|
||||
return
|
||||
logger.debug(f"GET STATE {response} {response.json()}")
|
||||
self._store_state(response.json())
|
||||
response_json = {}
|
||||
|
||||
# Sometimes the state URL can return an empty JSON when things are being set-up,
|
||||
# so we wait for it to be ready here.
|
||||
while response_json == {}:
|
||||
sleep(0.5)
|
||||
try:
|
||||
response = self._session.get(app_url, headers=headers, timeout=1)
|
||||
except ConnectionError as e:
|
||||
raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from e
|
||||
|
||||
self._authorized = response.status_code
|
||||
if self._authorized != 200:
|
||||
return
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
logger.debug(f"GET STATE {response} {response_json}")
|
||||
self._store_state(response_json)
|
||||
|
||||
def __getattr__(self, name: str) -> Union[Any, "AppState"]:
|
||||
if name in self._APP_PRIVATE_KEYS:
|
||||
|
|
|
@ -28,7 +28,7 @@ from lightning_app.core.api import (
|
|||
UIRefresher,
|
||||
)
|
||||
from lightning_app.core.constants import APP_SERVER_PORT
|
||||
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
|
||||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.storage.drive import Drive
|
||||
from lightning_app.testing.helpers import _MockQueue
|
||||
from lightning_app.utilities.component import _set_frontend_context, _set_work_context
|
||||
|
@ -71,12 +71,10 @@ class _A(LightningFlow):
|
|||
self.work_a.run()
|
||||
|
||||
|
||||
# TODO: Resolve singleprocess - idea: explore frame calls recursively.
|
||||
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime])
|
||||
def test_app_state_api(runtime_cls):
|
||||
def test_app_state_api():
|
||||
"""This test validates the AppState can properly broadcast changes from work within its own process."""
|
||||
app = LightningApp(_A(), log_level="debug")
|
||||
runtime_cls(app, start_server=True).dispatch()
|
||||
MultiProcessRuntime(app, start_server=True).dispatch()
|
||||
assert app.root.work_a.var_a == -1
|
||||
_set_work_context()
|
||||
assert app.root.work_a.drive.list(".") == ["test_app_state_api.txt"]
|
||||
|
@ -105,13 +103,10 @@ class A2(LightningFlow):
|
|||
self._exit()
|
||||
|
||||
|
||||
# TODO: Find why this test is flaky.
|
||||
@pytest.mark.skip(reason="flaky test.")
|
||||
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime])
|
||||
def test_app_state_api_with_flows(runtime_cls, tmpdir):
|
||||
def test_app_state_api_with_flows(tmpdir):
|
||||
"""This test validates the AppState can properly broadcast changes from flows."""
|
||||
app = LightningApp(A2(), log_level="debug")
|
||||
runtime_cls(app, start_server=True).dispatch()
|
||||
MultiProcessRuntime(app, start_server=True).dispatch()
|
||||
assert app.root.var_a == -1
|
||||
|
||||
|
||||
|
@ -181,13 +176,12 @@ class AppStageTestingApp(LightningApp):
|
|||
|
||||
# FIXME: This test doesn't assert anything
|
||||
@pytest.mark.skip(reason="TODO: Resolve flaky test.")
|
||||
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime])
|
||||
def test_app_stage_from_frontend(runtime_cls):
|
||||
def test_app_stage_from_frontend():
|
||||
"""This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would
|
||||
start and stop the app."""
|
||||
app = AppStageTestingApp(FlowA(), log_level="debug")
|
||||
app.stage = AppStage.BLOCKING
|
||||
runtime_cls(app, start_server=True).dispatch()
|
||||
MultiProcessRuntime(app, start_server=True).dispatch()
|
||||
|
||||
|
||||
def test_update_publish_state_and_maybe_refresh_ui():
|
||||
|
|
|
@ -4,7 +4,6 @@ import pickle
|
|||
from re import escape
|
||||
from time import sleep
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
from deepdiff import Delta
|
||||
|
@ -19,9 +18,9 @@ from lightning_app.core.constants import (
|
|||
REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
|
||||
STATE_UPDATE_TIMEOUT,
|
||||
)
|
||||
from lightning_app.core.queues import BaseQueue, MultiProcessQueue, RedisQueue, SingleProcessQueue
|
||||
from lightning_app.core.queues import BaseQueue, MultiProcessQueue, RedisQueue
|
||||
from lightning_app.frontend import StreamlitFrontend
|
||||
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
|
||||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.storage import Path
|
||||
from lightning_app.storage.path import _storage_root_dir
|
||||
from lightning_app.testing.helpers import _RunIf
|
||||
|
@ -82,7 +81,7 @@ class Work(LightningWork):
|
|||
self.has_finished = False
|
||||
|
||||
def run(self):
|
||||
self.counter += 1
|
||||
self.counter = self.counter + 1
|
||||
if self.cache_calls:
|
||||
self.has_finished = True
|
||||
elif self.counter >= 3:
|
||||
|
@ -96,40 +95,60 @@ class SimpleFlow(LightningFlow):
|
|||
self.work_b = Work(cache_calls=False)
|
||||
|
||||
def run(self):
|
||||
self.work_a.run()
|
||||
self.work_b.run()
|
||||
if self.work_a.has_finished and self.work_b.has_finished:
|
||||
self._exit()
|
||||
self.work_a.run()
|
||||
self.work_b.run()
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.parametrize("component_cls", [SimpleFlow])
|
||||
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime])
|
||||
def test_simple_app(component_cls, runtime_cls, tmpdir):
|
||||
comp = component_cls()
|
||||
def test_simple_app(tmpdir):
|
||||
comp = SimpleFlow()
|
||||
app = LightningApp(comp, log_level="debug")
|
||||
assert app.root == comp
|
||||
expected = {
|
||||
"app_state": ANY,
|
||||
"vars": {"_layout": ANY, "_paths": {}},
|
||||
"app_state": mock.ANY,
|
||||
"vars": {"_layout": mock.ANY, "_paths": {}},
|
||||
"calls": {},
|
||||
"flows": {},
|
||||
"structures": {},
|
||||
"works": {
|
||||
"work_b": {
|
||||
"vars": {"has_finished": False, "counter": 0, "_urls": {}, "_paths": {}},
|
||||
"calls": {},
|
||||
"vars": {
|
||||
"has_finished": False,
|
||||
"counter": 0,
|
||||
"_cloud_compute": mock.ANY,
|
||||
"_host": mock.ANY,
|
||||
"_url": "",
|
||||
"_future_url": "",
|
||||
"_internal_ip": "",
|
||||
"_paths": {},
|
||||
"_port": None,
|
||||
"_restarting": False,
|
||||
},
|
||||
"calls": {"latest_call_hash": None},
|
||||
"changes": {},
|
||||
},
|
||||
"work_a": {
|
||||
"vars": {"has_finished": False, "counter": 0, "_urls": {}, "_paths": {}},
|
||||
"calls": {},
|
||||
"vars": {
|
||||
"has_finished": False,
|
||||
"counter": 0,
|
||||
"_cloud_compute": mock.ANY,
|
||||
"_host": mock.ANY,
|
||||
"_url": "",
|
||||
"_future_url": "",
|
||||
"_internal_ip": "",
|
||||
"_paths": {},
|
||||
"_port": None,
|
||||
"_restarting": False,
|
||||
},
|
||||
"calls": {"latest_call_hash": None},
|
||||
"changes": {},
|
||||
},
|
||||
},
|
||||
"changes": {},
|
||||
}
|
||||
assert app.state == expected
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
|
||||
assert comp.work_a.has_finished
|
||||
assert comp.work_b.has_finished
|
||||
|
@ -357,11 +376,10 @@ class SimpleApp2(LightningApp):
|
|||
return True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime])
|
||||
def test_app_restarting_move_to_blocking(runtime_cls, tmpdir):
|
||||
def test_app_restarting_move_to_blocking(tmpdir):
|
||||
"""Validates sending restarting move the app to blocking again."""
|
||||
app = SimpleApp2(CounterFlow(), log_level="debug")
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
|
||||
|
||||
class FlowWithFrontend(LightningFlow):
|
||||
|
@ -411,7 +429,6 @@ class EmptyFlow(LightningFlow):
|
|||
@pytest.mark.parametrize(
|
||||
"queue_type_cls, default_timeout",
|
||||
[
|
||||
(SingleProcessQueue, STATE_UPDATE_TIMEOUT),
|
||||
(MultiProcessQueue, STATE_UPDATE_TIMEOUT),
|
||||
pytest.param(
|
||||
RedisQueue,
|
||||
|
@ -463,7 +480,7 @@ def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQu
|
|||
assert generated > expect
|
||||
|
||||
|
||||
class SimpleFlow(LightningFlow):
|
||||
class SimpleFlow2(LightningFlow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.counter = 0
|
||||
|
@ -476,8 +493,8 @@ class SimpleFlow(LightningFlow):
|
|||
def test_maybe_apply_changes_from_flow():
|
||||
"""This test validates the app `_updated` is set to True only if the state was changed in the flow."""
|
||||
|
||||
app = LightningApp(SimpleFlow())
|
||||
app.delta_queue = SingleProcessQueue("a", 0)
|
||||
app = LightningApp(SimpleFlow2())
|
||||
app.delta_queue = MultiProcessQueue("a", 0)
|
||||
assert app._has_updated
|
||||
app.maybe_apply_changes()
|
||||
app.root.run()
|
||||
|
|
|
@ -13,7 +13,7 @@ from deepdiff import DeepDiff, Delta
|
|||
from lightning_app import LightningApp
|
||||
from lightning_app.core.flow import LightningFlow
|
||||
from lightning_app.core.work import LightningWork
|
||||
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
|
||||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.storage import Path
|
||||
from lightning_app.storage.path import _storage_root_dir
|
||||
from lightning_app.structures import Dict as LDict
|
||||
|
@ -237,7 +237,7 @@ def _run_state_transformation(tmpdir, attribute, update_fn, inplace=False):
|
|||
flow = StateTransformationTest()
|
||||
assert flow.x == attribute
|
||||
app = LightningApp(flow)
|
||||
SingleProcessRuntime(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
return app.state["vars"]["x"]
|
||||
|
||||
|
||||
|
@ -519,11 +519,10 @@ class CFlow(LightningFlow):
|
|||
self._exit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime])
|
||||
@pytest.mark.parametrize("run_once", [False, True])
|
||||
def test_lightning_flow_iterate(tmpdir, runtime_cls, run_once):
|
||||
def test_lightning_flow_iterate(tmpdir, run_once):
|
||||
app = LightningApp(CFlow(run_once))
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
assert app.root.looping == 0
|
||||
assert app.root.tracker == 4
|
||||
call_hash = list(v for v in app.root._calls if "experimental_iterate" in v)[0]
|
||||
|
@ -537,7 +536,7 @@ def test_lightning_flow_iterate(tmpdir, runtime_cls, run_once):
|
|||
app.root.restarting = True
|
||||
assert app.root.looping == 0
|
||||
assert app.root.tracker == 4
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
assert app.root.looping == 2
|
||||
assert app.root.tracker == 10 if run_once else 20
|
||||
iterate_call = app.root._calls[call_hash]
|
||||
|
@ -555,12 +554,11 @@ class FlowCounter(LightningFlow):
|
|||
self.counter += 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime])
|
||||
def test_lightning_flow_counter(runtime_cls, tmpdir):
|
||||
def test_lightning_flow_counter(tmpdir):
|
||||
|
||||
app = LightningApp(FlowCounter())
|
||||
app.checkpointing = True
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
assert app.root.counter == 3
|
||||
|
||||
checkpoint_dir = os.path.join(_storage_root_dir(), "checkpoints")
|
||||
|
@ -571,7 +569,7 @@ def test_lightning_flow_counter(runtime_cls, tmpdir):
|
|||
with open(checkpoint_path, "rb") as f:
|
||||
app = LightningApp(FlowCounter())
|
||||
app.set_state(pickle.load(f))
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
assert app.root.counter == 3
|
||||
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from lightning_app.utilities.redis import check_if_redis_running
|
|||
|
||||
|
||||
@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
|
||||
@pytest.mark.parametrize("queue_type", [QueuingSystem.REDIS, QueuingSystem.MULTIPROCESS, QueuingSystem.SINGLEPROCESS])
|
||||
@pytest.mark.parametrize("queue_type", [QueuingSystem.REDIS, QueuingSystem.MULTIPROCESS])
|
||||
def test_queue_api(queue_type, monkeypatch):
|
||||
"""Test the Queue API.
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@ from lightning_app.runners.runtime_type import RuntimeType
|
|||
@pytest.mark.parametrize(
|
||||
"runtime_type",
|
||||
[
|
||||
RuntimeType.SINGLEPROCESS,
|
||||
RuntimeType.MULTIPROCESS,
|
||||
RuntimeType.CLOUD,
|
||||
],
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from lightning_app import LightningFlow
|
||||
from lightning_app.core.app import LightningApp
|
||||
from lightning_app.runners import SingleProcessRuntime
|
||||
|
||||
|
||||
class Flow(LightningFlow):
|
||||
def run(self):
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
def on_before_run():
|
||||
pass
|
||||
|
||||
|
||||
def test_single_process_runtime(tmpdir):
|
||||
|
||||
app = LightningApp(Flow())
|
||||
SingleProcessRuntime(app, start_server=False).dispatch(on_before_run=on_before_run)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env,expected_url",
|
||||
[
|
||||
({}, "http://127.0.0.1:7501/view"),
|
||||
({"APP_SERVER_HOST": "http://test"}, "http://test"),
|
||||
],
|
||||
)
|
||||
def test_get_app_url(env, expected_url):
|
||||
with mock.patch.dict(os.environ, env):
|
||||
assert SingleProcessRuntime._get_app_url() == expected_url
|
|
@ -4,7 +4,7 @@ from copy import deepcopy
|
|||
import pytest
|
||||
|
||||
from lightning_app import LightningApp, LightningFlow, LightningWork
|
||||
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
|
||||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.storage.payload import Payload
|
||||
from lightning_app.structures import Dict, List
|
||||
from lightning_app.testing.helpers import EmptyFlow
|
||||
|
@ -309,11 +309,10 @@ class CounterWork(LightningWork):
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="tchaton: Resolve this test.")
|
||||
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime])
|
||||
@pytest.mark.parametrize("run_once_iterable", [False, True])
|
||||
@pytest.mark.parametrize("cache_calls", [False, True])
|
||||
@pytest.mark.parametrize("use_list", [False, True])
|
||||
def test_structure_with_iterate_and_fault_tolerance(runtime_cls, run_once_iterable, cache_calls, use_list):
|
||||
def test_structure_with_iterate_and_fault_tolerance(run_once_iterable, cache_calls, use_list):
|
||||
class DummyFlow(LightningFlow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -360,7 +359,7 @@ def test_structure_with_iterate_and_fault_tolerance(runtime_cls, run_once_iterab
|
|||
self.looping += 1
|
||||
|
||||
app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
assert app.root.iter[0 if use_list else "0"].counter == 1
|
||||
assert app.root.iter[1 if use_list else "1"].counter == 0
|
||||
assert app.root.iter[2 if use_list else "2"].counter == 0
|
||||
|
@ -368,7 +367,7 @@ def test_structure_with_iterate_and_fault_tolerance(runtime_cls, run_once_iterab
|
|||
|
||||
app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
|
||||
app.root.restarting = True
|
||||
runtime_cls(app, start_server=False).dispatch()
|
||||
MultiProcessRuntime(app, start_server=False).dispatch()
|
||||
|
||||
if run_once_iterable:
|
||||
expected_value = 1
|
||||
|
|
Loading…
Reference in New Issue