"""Tests for the ``Dict`` and ``List`` structures of ``lightning_app``."""
import os
from copy import deepcopy

import pytest

from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.runners import MultiProcessRuntime, SingleProcessRuntime
from lightning_app.storage.payload import Payload
from lightning_app.structures import Dict, List
from lightning_app.testing.helpers import EmptyFlow
from lightning_app.utilities.enum import CacheCallsKeys, WorkStageStatus


def _default_work_vars():
    """Return the ``vars`` expected for an untouched ``WorkA`` (``port=1``, ``c = 0``).

    A fresh dict is built on every call so that no two comparisons share a mutable object.
    """
    return {
        "c": 0,
        "_url": "",
        "_future_url": "",
        "_port": 1,
        "_host": "127.0.0.1",
        "_paths": {},
        "_restarting": False,
        "_internal_ip": "",
        "_cloud_compute": {
            "type": "__cloud_compute__",
            "name": "default",
            "disk_size": 0,
            "idle_timeout": None,
            "mounts": None,
            "shm_size": 0,
            "_internal_id": "default",
        },
    }


def _assert_pristine_works(view, structure, structure_name, expected_keys, *, with_calls_and_changes=True):
    """Assert that ``view`` lists the works of ``structure`` in their initial state.

    Args:
        view: One of the flow's state mappings (``state``, ``state_vars`` or ``state_with_changes``).
        structure: The ``Dict`` / ``List`` attribute of the flow that holds the works.
        structure_name: Key under ``view["structures"]`` where the structure is stored.
        expected_keys: Work keys, in the order they are expected to appear in the state.
        with_calls_and_changes: Also check the ``"calls"`` and ``"changes"`` entries of every work.
            Disabled for ``state_vars``, for which only ``"vars"`` is checked.
    """
    works = view["structures"][structure_name]["works"]
    assert len(works) == len(structure) == len(expected_keys)
    assert list(works.keys()) == expected_keys
    # One assert per key (rather than ``assert all(...)``) so a failure names the offending work.
    for key in expected_keys:
        assert works[key]["vars"] == _default_work_vars()
        if with_calls_and_changes:
            assert works[key]["calls"] == {CacheCallsKeys.LATEST_CALL_HASH: None}
            assert works[key]["changes"] == {}


def test_dict():
    """Works stored in a ``Dict`` show up in every state view and can be updated through ``set_state``."""

    class WorkA(LightningWork):
        def __init__(self):
            super().__init__(port=1)
            self.c = 0

        def run(self):
            pass

    class A(LightningFlow):
        def __init__(self):
            super().__init__()
            self.dict = Dict(**{"work_a": WorkA(), "work_b": WorkA(), "work_c": WorkA(), "work_d": WorkA()})

        def run(self):
            pass

    flow = A()
    keys = ["work_a", "work_b", "work_c", "work_d"]

    # TODO: these assertions are wrong, the works are getting added under "flows" instead of "works"
    # state
    _assert_pristine_works(flow.state, flow.dict, "dict", keys)

    # state_vars
    _assert_pristine_works(flow.state_vars, flow.dict, "dict", keys, with_calls_and_changes=False)

    # state_with_changes
    _assert_pristine_works(flow.state_with_changes, flow.dict, "dict", keys)

    # set_state
    state = deepcopy(flow.state)
    state["structures"]["dict"]["works"]["work_b"]["vars"]["c"] = 1
    flow.set_state(state)
    assert flow.dict["work_b"].c == 1


def test_dict_name():
    """Children of a ``Dict`` are named ``<dict name>.<key>``, standalone or attached to a flow."""
    d = Dict(a=EmptyFlow(), b=EmptyFlow())
    assert d.name == "root"
    assert d["a"].name == "root.a"
    assert d["b"].name == "root.b"

    class RootFlow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.dict = Dict(x=EmptyFlow(), y=EmptyFlow())

        def run(self):
            pass

    root = RootFlow()
    assert root.name == "root"
    assert root.dict.name == "root.dict"
    assert root.dict["x"].name == "root.dict.x"
    assert root.dict["y"].name == "root.dict.y"


def test_list():
    """Works stored in a ``List`` show up in every state view and can be updated through ``set_state``."""

    class WorkA(LightningWork):
        def __init__(self):
            super().__init__(port=1)
            self.c = 0

        def run(self):
            pass

    class A(LightningFlow):
        def __init__(self):
            super().__init__()
            self.list = List(WorkA(), WorkA(), WorkA(), WorkA())

        def run(self):
            pass

    flow = A()
    # List items are keyed in the state by their stringified index.
    keys = ["0", "1", "2", "3"]

    # TODO: these assertions are wrong, the works are getting added under "flows" instead of "works"
    # state
    _assert_pristine_works(flow.state, flow.list, "list", keys)

    # state_vars
    _assert_pristine_works(flow.state_vars, flow.list, "list", keys, with_calls_and_changes=False)

    # state_with_changes
    _assert_pristine_works(flow.state_with_changes, flow.list, "list", keys)

    # set_state
    state = deepcopy(flow.state)
    state["structures"]["list"]["works"]["0"]["vars"]["c"] = 1
    flow.set_state(state)
    assert flow.list[0].c == 1


def test_list_name():
    """Children of a ``List`` are named ``<list name>.<index>``, standalone or attached to a flow."""
    lst = List(EmptyFlow(), EmptyFlow())
    assert lst.name == "root"
    assert lst[0].name == "root.0"
    assert lst[1].name == "root.1"

    class RootFlow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.list = List(EmptyFlow(), EmptyFlow())

        def run(self):
            pass

    root = RootFlow()
    assert root.name == "root"
    assert root.list.name == "root.list"
    assert root.list[0].name == "root.list.0"
    assert root.list[1].name == "root.list.1"


class CounterWork(LightningWork):
    """Work that increments ``counter`` by one on every ``run``."""

    def __init__(self, cache_calls, parallel=False):
        super().__init__(cache_calls=cache_calls, parallel=parallel)
        self.counter = 0

    def run(self):
        self.counter += 1


@pytest.mark.skip(reason="tchaton: Resolve this test.")
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime])
@pytest.mark.parametrize("run_once_iterable", [False, True])
@pytest.mark.parametrize("cache_calls", [False, True])
@pytest.mark.parametrize("use_list", [False, True])
def test_structure_with_iterate_and_fault_tolerance(runtime_cls, run_once_iterable, cache_calls, use_list):
    """Iterate over a structure with ``experimental_iterate``, exiting early once and then running to the end.

    Currently skipped unconditionally (see the ``skip`` marker).
    """

    class DummyFlow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.counter = 0

        def run(self):
            pass

    class RootFlow(LightningFlow):
        def __init__(self, use_list, run_once_iterable, cache_calls):
            super().__init__()
            self.looping = 0
            self.run_once_iterable = run_once_iterable
            self.restarting = False
            if use_list:
                self.iter = List(
                    CounterWork(cache_calls),
                    CounterWork(cache_calls),
                    CounterWork(cache_calls),
                    CounterWork(cache_calls),
                    DummyFlow(),
                )
            else:
                self.iter = Dict(
                    **{
                        "0": CounterWork(cache_calls),
                        "1": CounterWork(cache_calls),
                        "2": CounterWork(cache_calls),
                        "3": CounterWork(cache_calls),
                        "4": DummyFlow(),
                    }
                )

        def run(self):
            for work_idx, work in self.experimental_iterate(enumerate(self.iter), run_once=self.run_once_iterable):
                if not self.restarting and work_idx == 1:
                    # gives time to the delta to be sent.
                    self._exit()
                # When a string is yielded for a ``Dict``, it is the key: resolve it to the stored component.
                if isinstance(work, str) and isinstance(self.iter, Dict):
                    work = self.iter[work]
                work.run()
            if self.looping > 0:
                self._exit()
            self.looping += 1

    # The four ``CounterWork`` instances are addressed by index in a ``List`` and by string key in a ``Dict``.
    work_keys = [0, 1, 2, 3] if use_list else ["0", "1", "2", "3"]

    # First pass: ``restarting`` is False, so the flow exits as soon as the second item is reached.
    app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
    runtime_cls(app, start_server=False).dispatch()
    assert [app.root.iter[key].counter for key in work_keys] == [1, 0, 0, 0]

    # Second pass: a new app with ``restarting`` set, so the early exit above is not taken.
    app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
    app.root.restarting = True
    runtime_cls(app, start_server=False).dispatch()

    expected_value = 1 if (run_once_iterable or cache_calls) else 2
    assert [app.root.iter[key].counter for key in work_keys] == [expected_value] * len(work_keys)


class CheckpointCounter(LightningWork):
    """Work created with ``cache_calls=False`` that increments ``counter`` on every ``run``."""

    def __init__(self):
        super().__init__(cache_calls=False)
        self.counter = 0

    def run(self):
        self.counter += 1


class CheckpointFlow(LightningFlow):
    """Chain of nested flows (depth 0 to 4) whose innermost flow holds and runs ``collection``.

    Only the outermost flow (``depth == 0``) owns a ``counter``; it calls ``_exit`` once that counter
    reaches ``exit``.
    """

    def __init__(self, collection, depth=0, exit=11):
        super().__init__()
        self.depth = depth
        self.exit = exit
        if depth == 0:
            self.counter = 0

        if depth >= 4:
            self.collection = collection
        else:
            # Nested flows are built with the default ``exit``; they never read it as they have no ``counter``.
            self.flow = CheckpointFlow(collection, depth + 1)

    def run(self):
        # ``counter`` only exists on the outermost flow, hence the ``hasattr`` guard.
        if hasattr(self, "counter"):
            self.counter += 1
            if self.counter >= self.exit:
                self._exit()
        if self.depth >= 4:
            self.collection.run()
        else:
            self.flow.run()


class SimpleCounterWork(LightningWork):
    """Work with default settings that increments ``counter`` on every ``run``."""

    def __init__(self):
        super().__init__()
        self.counter = 0

    def run(self):
        self.counter += 1


class FlowDict(LightningFlow):
    """Flow that adds a work to an initially empty ``Dict`` at run time and exits once it has succeeded."""

    def __init__(self):
        super().__init__()
        self.dict = Dict()

    def run(self):
        if "w" not in self.dict:
            self.dict["w"] = SimpleCounterWork()

        if self.dict["w"].status.stage == WorkStageStatus.SUCCEEDED:
            self._exit()

        self.dict["w"].run()


def test_dict_with_queues():
    """Dispatch ``FlowDict`` with ``MultiProcessRuntime``; the test passes when the dispatch returns."""
    app = LightningApp(FlowDict())
    MultiProcessRuntime(app, start_server=False).dispatch()


class FlowList(LightningFlow):
    """Flow that appends a work to an initially empty ``List`` at run time and exits once it has succeeded."""

    def __init__(self):
        super().__init__()
        self.list = List()

    def run(self):
        if not len(self.list):
            self.list.append(SimpleCounterWork())

        if self.list[-1].status.stage == WorkStageStatus.SUCCEEDED:
            self._exit()

        self.list[-1].run()


def test_list_with_queues():
    """Dispatch ``FlowList`` with ``MultiProcessRuntime``; the test passes when the dispatch returns."""
    app = LightningApp(FlowList())
    MultiProcessRuntime(app, start_server=False).dispatch()


class WorkS(LightningWork):
    """Source work: ``run`` stores ``Payload(2)`` in ``self.payload``."""

    def __init__(self):
        super().__init__()
        self.payload = None

    def run(self):
        self.payload = Payload(2)


class WorkD(LightningWork):
    """Destination work: ``run`` asserts that the received payload carries the value ``2``."""

    def run(self, payload):
        assert payload.value == 2


class FlowPayload(LightningFlow):
    """Flow that hands the payload of ``src`` to two parallel works kept in a ``Dict``."""

    def __init__(self):
        super().__init__()
        self.src = WorkS()
        self.dst = Dict(**{"0": WorkD(parallel=True), "1": WorkD(parallel=True)})

    def run(self):
        self.src.run()
        # The payload is ``None`` until ``src`` has run; only then is it forwarded.
        if self.src.payload:
            for work in self.dst.values():
                work.run(self.src.payload)
        if all(w.has_succeeded for w in self.dst.values()):
            self._exit()


def test_structures_with_payload():
    """Dispatch ``FlowPayload`` with ``MultiProcessRuntime`` and remove the ``payload`` file afterwards."""
    app = LightningApp(FlowPayload(), debug=True)
    MultiProcessRuntime(app, start_server=False).dispatch()
    # NOTE(review): presumably this file is written to the working directory when the payload is
    # stored -- confirm. The removal raises if the file is missing, exactly as before.
    os.remove("payload")