import os
from copy import deepcopy

import pytest

from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
from lightning_app.storage.payload import Payload
from lightning_app.structures import Dict, List
from lightning_app.testing.helpers import EmptyFlow
from lightning_app.utilities.enum import CacheCallsKeys, WorkStageStatus


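# These tests exercise the ``Dict`` and ``List`` structures: how the works they hold are
# reflected in the owning flow's state, how component names are composed, and how the
# structures behave when an app is dispatched through a runtime.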
def test_dict():
    class WorkA(LightningWork):
        def __init__(self):
            super().__init__(port=1)
            self.c = 0

        def run(self):
            pass

    class A(LightningFlow):
        def __init__(self):
            super().__init__()
            self.dict = Dict(**{"work_a": WorkA(), "work_b": WorkA(), "work_c": WorkA(), "work_d": WorkA()})

        def run(self):
            pass

    flow = A()

    # TODO: these assertions are wrong, the works are getting added under "flows" instead of "works"
    # state
    assert len(flow.state["structures"]["dict"]["works"]) == len(flow.dict) == 4
    assert list(flow.state["structures"]["dict"]["works"].keys()) == ["work_a", "work_b", "work_c", "work_d"]
    assert all(
        flow.state["structures"]["dict"]["works"][f"work_{k}"]["vars"]
        == {
            "c": 0,
            "_url": "",
            "_future_url": "",
            "_port": 1,
            "_host": "127.0.0.1",
            "_paths": {},
            "_restarting": False,
            "_internal_ip": "",
            "_cloud_compute": {
                "type": "__cloud_compute__",
                "name": "default",
                "disk_size": 0,
                "idle_timeout": None,
                "mounts": None,
                "shm_size": 0,
                "_internal_id": "default",
            },
        }
        for k in ("a", "b", "c", "d")
    )
    assert all(
        flow.state["structures"]["dict"]["works"][f"work_{k}"]["calls"] == {CacheCallsKeys.LATEST_CALL_HASH: None}
        for k in ("a", "b", "c", "d")
    )
    assert all(flow.state["structures"]["dict"]["works"][f"work_{k}"]["changes"] == {} for k in ("a", "b", "c", "d"))

    # state_vars
    assert len(flow.state_vars["structures"]["dict"]["works"]) == len(flow.dict) == 4
    assert list(flow.state_vars["structures"]["dict"]["works"].keys()) == ["work_a", "work_b", "work_c", "work_d"]
    assert all(
        flow.state_vars["structures"]["dict"]["works"][f"work_{k}"]["vars"]
        == {
            "c": 0,
            "_url": "",
            "_future_url": "",
            "_port": 1,
            "_host": "127.0.0.1",
            "_paths": {},
            "_restarting": False,
            "_internal_ip": "",
            "_cloud_compute": {
                "type": "__cloud_compute__",
                "name": "default",
                "disk_size": 0,
                "idle_timeout": None,
                "mounts": None,
                "shm_size": 0,
                "_internal_id": "default",
            },
        }
        for k in ("a", "b", "c", "d")
    )

    # state_with_changes
    assert len(flow.state_with_changes["structures"]["dict"]["works"]) == len(flow.dict) == 4
    assert list(flow.state_with_changes["structures"]["dict"]["works"].keys()) == [
        "work_a",
        "work_b",
        "work_c",
        "work_d",
    ]
    assert all(
        flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["vars"]
        == {
            "c": 0,
            "_url": "",
            "_future_url": "",
            "_port": 1,
            "_host": "127.0.0.1",
            "_paths": {},
            "_restarting": False,
            "_internal_ip": "",
            "_cloud_compute": {
                "type": "__cloud_compute__",
                "name": "default",
                "disk_size": 0,
                "idle_timeout": None,
                "mounts": None,
                "shm_size": 0,
                "_internal_id": "default",
            },
        }
        for k in ("a", "b", "c", "d")
    )
    assert all(
        flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["calls"]
        == {CacheCallsKeys.LATEST_CALL_HASH: None}
        for k in ("a", "b", "c", "d")
    )
    assert all(
        flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["changes"] == {}
        for k in ("a", "b", "c", "d")
    )

    # set_state
    state = deepcopy(flow.state)
    state["structures"]["dict"]["works"]["work_b"]["vars"]["c"] = 1
    flow.set_state(state)
    assert flow.dict["work_b"].c == 1


def test_dict_name():
    d = Dict(a=EmptyFlow(), b=EmptyFlow())
    assert d.name == "root"
    assert d["a"].name == "root.a"
    assert d["b"].name == "root.b"

    class RootFlow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.dict = Dict(x=EmptyFlow(), y=EmptyFlow())

        def run(self):
            pass

    root = RootFlow()
    assert root.name == "root"
    assert root.dict.name == "root.dict"
    assert root.dict["x"].name == "root.dict.x"
    assert root.dict["y"].name == "root.dict.y"


def test_list():
    class WorkA(LightningWork):
        def __init__(self):
            super().__init__(port=1)
            self.c = 0

        def run(self):
            pass

    class A(LightningFlow):
        def __init__(self):
            super().__init__()
            self.list = List(WorkA(), WorkA(), WorkA(), WorkA())

        def run(self):
            pass

    flow = A()

    # TODO: these assertions are wrong, the works are getting added under "flows" instead of "works"
    # state
    assert len(flow.state["structures"]["list"]["works"]) == len(flow.list) == 4
    assert list(flow.state["structures"]["list"]["works"].keys()) == ["0", "1", "2", "3"]
    assert all(
        flow.state["structures"]["list"]["works"][str(i)]["vars"]
        == {
            "c": 0,
            "_url": "",
            "_future_url": "",
            "_port": 1,
            "_host": "127.0.0.1",
            "_paths": {},
            "_restarting": False,
            "_internal_ip": "",
            "_cloud_compute": {
                "type": "__cloud_compute__",
                "name": "default",
                "disk_size": 0,
                "idle_timeout": None,
                "mounts": None,
                "shm_size": 0,
                "_internal_id": "default",
            },
        }
        for i in range(4)
    )
    assert all(
        flow.state["structures"]["list"]["works"][str(i)]["calls"] == {CacheCallsKeys.LATEST_CALL_HASH: None}
        for i in range(4)
    )
    assert all(flow.state["structures"]["list"]["works"][str(i)]["changes"] == {} for i in range(4))

    # state_vars
    assert len(flow.state_vars["structures"]["list"]["works"]) == len(flow.list) == 4
    assert list(flow.state_vars["structures"]["list"]["works"].keys()) == ["0", "1", "2", "3"]
    assert all(
        flow.state_vars["structures"]["list"]["works"][str(i)]["vars"]
        == {
            "c": 0,
            "_url": "",
            "_future_url": "",
            "_port": 1,
            "_host": "127.0.0.1",
            "_paths": {},
            "_restarting": False,
            "_internal_ip": "",
            "_cloud_compute": {
                "type": "__cloud_compute__",
                "name": "default",
                "disk_size": 0,
                "idle_timeout": None,
                "mounts": None,
                "shm_size": 0,
                "_internal_id": "default",
            },
        }
        for i in range(4)
    )

    # state_with_changes
    assert len(flow.state_with_changes["structures"]["list"]["works"]) == len(flow.list) == 4
    assert list(flow.state_with_changes["structures"]["list"]["works"].keys()) == ["0", "1", "2", "3"]
    assert all(
        flow.state_with_changes["structures"]["list"]["works"][str(i)]["vars"]
        == {
            "c": 0,
            "_url": "",
            "_future_url": "",
            "_port": 1,
            "_host": "127.0.0.1",
            "_paths": {},
            "_restarting": False,
            "_internal_ip": "",
            "_cloud_compute": {
                "type": "__cloud_compute__",
                "name": "default",
                "disk_size": 0,
                "idle_timeout": None,
                "mounts": None,
                "shm_size": 0,
                "_internal_id": "default",
            },
        }
        for i in range(4)
    )
    assert all(
        flow.state_with_changes["structures"]["list"]["works"][str(i)]["calls"]
        == {CacheCallsKeys.LATEST_CALL_HASH: None}
        for i in range(4)
    )
    assert all(flow.state_with_changes["structures"]["list"]["works"][str(i)]["changes"] == {} for i in range(4))

    # set_state
    state = deepcopy(flow.state)
    state["structures"]["list"]["works"]["0"]["vars"]["c"] = 1
    flow.set_state(state)
    assert flow.list[0].c == 1


def test_list_name():
    lst = List(EmptyFlow(), EmptyFlow())
    assert lst.name == "root"
    assert lst[0].name == "root.0"
    assert lst[1].name == "root.1"

    class RootFlow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.list = List(EmptyFlow(), EmptyFlow())

        def run(self):
            pass

    root = RootFlow()
    assert root.name == "root"
    assert root.list.name == "root.list"
    assert root.list[0].name == "root.list.0"
    assert root.list[1].name == "root.list.1"


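# Shared helper for the fault-tolerance test below: each ``run()`` call increments the counter.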
class CounterWork(LightningWork):
    def __init__(self, cache_calls, parallel=False):
        super().__init__(cache_calls=cache_calls, parallel=parallel)
        self.counter = 0

    def run(self):
        self.counter += 1


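# Roughly, the (currently skipped) test below checks fault tolerance of ``experimental_iterate``:
# the first dispatch exits after the first work has run, the second dispatch restarts the app and
# resumes iterating, and the final counters depend on ``run_once_iterable`` and ``cache_calls``
# (1 when the iteration or the calls are cached, 2 when every loop re-runs the works).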
@pytest.mark.skipif(True, reason="tchaton: Resolve this test.")
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime])
@pytest.mark.parametrize("run_once_iterable", [False, True])
@pytest.mark.parametrize("cache_calls", [False, True])
@pytest.mark.parametrize("use_list", [False, True])
def test_structure_with_iterate_and_fault_tolerance(runtime_cls, run_once_iterable, cache_calls, use_list):
    class DummyFlow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.counter = 0

        def run(self):
            pass

    class RootFlow(LightningFlow):
        def __init__(self, use_list, run_once_iterable, cache_calls):
            super().__init__()
            self.looping = 0
            self.run_once_iterable = run_once_iterable
            self.restarting = False
            if use_list:
                self.iter = List(
                    CounterWork(cache_calls),
                    CounterWork(cache_calls),
                    CounterWork(cache_calls),
                    CounterWork(cache_calls),
                    DummyFlow(),
                )
            else:
                self.iter = Dict(
                    **{
                        "0": CounterWork(cache_calls),
                        "1": CounterWork(cache_calls),
                        "2": CounterWork(cache_calls),
                        "3": CounterWork(cache_calls),
                        "4": DummyFlow(),
                    }
                )

        def run(self):
            for work_idx, work in self.experimental_iterate(enumerate(self.iter), run_once=self.run_once_iterable):
                if not self.restarting and work_idx == 1:
                    # Exit early to give the delta time to be sent.
                    self._exit()
                if isinstance(work, str) and isinstance(self.iter, Dict):
                    work = self.iter[work]
                work.run()
            if self.looping > 0:
                self._exit()
            self.looping += 1

    app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
    runtime_cls(app, start_server=False).dispatch()
    assert app.root.iter[0 if use_list else "0"].counter == 1
    assert app.root.iter[1 if use_list else "1"].counter == 0
    assert app.root.iter[2 if use_list else "2"].counter == 0
    assert app.root.iter[3 if use_list else "3"].counter == 0

    app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
    app.root.restarting = True
    runtime_cls(app, start_server=False).dispatch()

    if run_once_iterable:
        expected_value = 1
    else:
        expected_value = 1 if cache_calls else 2
    assert app.root.iter[0 if use_list else "0"].counter == expected_value
    assert app.root.iter[1 if use_list else "1"].counter == expected_value
    assert app.root.iter[2 if use_list else "2"].counter == expected_value
    assert app.root.iter[3 if use_list else "3"].counter == expected_value


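# ``CheckpointCounter`` is a plain counting work; ``CheckpointFlow`` nests itself recursively
# down to depth 4 and attaches the given ``collection`` at the leaf, while only the root
# (depth 0) owns the counter that eventually triggers ``_exit``.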
class CheckpointCounter(LightningWork):
    def __init__(self):
        super().__init__(cache_calls=False)
        self.counter = 0

    def run(self):
        self.counter += 1


class CheckpointFlow(LightningFlow):
    def __init__(self, collection, depth=0, exit=11):
        super().__init__()
        self.depth = depth
        self.exit = exit
        if depth == 0:
            self.counter = 0

        if depth >= 4:
            self.collection = collection
        else:
            self.flow = CheckpointFlow(collection, depth + 1)

    def run(self):
        if hasattr(self, "counter"):
            self.counter += 1
            if self.counter >= self.exit:
                self._exit()
        if self.depth >= 4:
            self.collection.run()
        else:
            self.flow.run()


class SimpleCounterWork(LightningWork):
    def __init__(self):
        super().__init__()
        self.counter = 0

    def run(self):
        self.counter += 1


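# ``FlowDict`` and ``FlowList`` populate their structures dynamically inside ``run()`` with a
# ``SimpleCounterWork`` and exit once that work has succeeded; the corresponding tests simply
# dispatch them with ``MultiProcessRuntime``.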
class FlowDict(LightningFlow):
    def __init__(self):
        super().__init__()
        self.dict = Dict()

    def run(self):
        if "w" not in self.dict:
            self.dict["w"] = SimpleCounterWork()

        if self.dict["w"].status.stage == WorkStageStatus.SUCCEEDED:
            self._exit()

        self.dict["w"].run()


def test_dict_with_queues():
    app = LightningApp(FlowDict())
    MultiProcessRuntime(app, start_server=False).dispatch()


class FlowList(LightningFlow):
    def __init__(self):
        super().__init__()
        self.list = List()

    def run(self):
        if not len(self.list):
            self.list.append(SimpleCounterWork())

        if self.list[-1].status.stage == WorkStageStatus.SUCCEEDED:
            self._exit()

        self.list[-1].run()


def test_list_with_queues():
    app = LightningApp(FlowList())
    MultiProcessRuntime(app, start_server=False).dispatch()


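# Payload transfer between works: ``WorkS`` produces ``Payload(2)`` and each parallel ``WorkD``
# consumer asserts the received value; ``test_structures_with_payload`` dispatches the flow and
# cleans up the ``payload`` file written during the transfer.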
class WorkS(LightningWork):
    def __init__(self):
        super().__init__()
        self.payload = None

    def run(self):
        self.payload = Payload(2)


class WorkD(LightningWork):
    def run(self, payload):
        assert payload.value == 2


class FlowPayload(LightningFlow):
    def __init__(self):
        super().__init__()
        self.src = WorkS()
        self.dst = Dict(**{"0": WorkD(parallel=True), "1": WorkD(parallel=True)})

    def run(self):
        self.src.run()
        if self.src.payload:
            for work in self.dst.values():
                work.run(self.src.payload)
        if all(w.has_succeeded for w in self.dst.values()):
            self._exit()


def test_structures_with_payload():
    app = LightningApp(FlowPayload(), debug=True)
    MultiProcessRuntime(app, start_server=False).dispatch()
    os.remove("payload")