155 lines
4.6 KiB
Python
155 lines
4.6 KiB
Python
import os
|
|
import pathlib
|
|
import pickle
|
|
from copy import deepcopy
|
|
from unittest import mock
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
|
|
from lightning_app import LightningApp, LightningFlow, LightningWork
|
|
from lightning_app.runners.multiprocess import MultiProcessRuntime
|
|
from lightning_app.storage.payload import Payload
|
|
from lightning_app.storage.requests import _GetRequest
|
|
|
|
|
|
def test_payload_copy():
    """Test that ``deepcopy`` of a Payload carries over its origin, consumer and queue attributes."""
    original = Payload(None)
    original._origin = "origin"
    original._consumer = "consumer"
    original._request_queue = "MockQueue"
    original._response_queue = "MockQueue"

    duplicate = deepcopy(original)

    # Every attribute set above must compare equal on the deep copy.
    for attr in ("_origin", "_consumer", "_request_queue", "_response_queue"):
        assert getattr(duplicate, attr) == getattr(original, attr)
|
|
|
|
|
|
def test_payload_pickable():
    """Round-trip a Payload through ``pickle`` and check the restored object.

    Origin, consumer and name must survive the round trip. The restored object is
    expected to have no request/response queue attached.
    """
    payload = Payload("MyObject")
    payload._origin = "root.x.y.z"
    payload._consumer = "root.p.q.r"
    payload._name = "var_a"

    serialized = pickle.dumps(payload)
    restored = pickle.loads(serialized)

    assert isinstance(restored, Payload)
    for attr in ("_origin", "_consumer", "_name"):
        assert getattr(restored, attr) == getattr(payload, attr)
    assert restored._request_queue is None
    assert restored._response_queue is None
|
|
|
|
|
|
def test_path_attach_queues():
    """``Payload._attach_queues`` must store the exact queue objects it is given."""
    request_queue, response_queue = Mock(), Mock()

    payload = Payload(None)
    payload._attach_queues(request_queue=request_queue, response_queue=response_queue)

    # Identity, not equality: the very same queue objects must be attached.
    assert payload._request_queue is request_queue
    assert payload._response_queue is response_queue
|
|
|
|
|
|
class Work(LightningWork):
    """Work that assigns a Payload in ``__init__``; ``test_payload_in_init`` expects construction to fail."""

    def __init__(self):
        super().__init__()
        # Assigning a Payload outside of ``run`` is the condition under test.
        self.var_a = Payload(None)

    def run(self):
        """No-op; the tests only exercise construction of this work."""
|
|
|
|
|
|
def test_payload_in_init():
    """Creating a Payload attribute inside a work's ``__init__`` must raise an AttributeError."""
    expected_message = "The Payload object should be set only within the run method of the work."
    with pytest.raises(AttributeError, match=expected_message):
        Work()
|
|
|
|
|
|
class WorkRun(LightningWork):
    """Work that creates a Payload in ``run`` and then serves a get request for it.

    Args:
        tmpdir: Directory in which the requested payload file is expected to be written.
    """

    def __init__(self, tmpdir):
        super().__init__()
        self.var_a = None
        self.tmpdir = tmpdir

    def run(self):
        self.var_a = Payload("something")
        # After assignment inside ``run``, the payload is expected to be named after the
        # attribute and to have this work (named "root.a" by the test) as its origin.
        assert self.var_a.name == "var_a"
        assert self.var_a._origin == "root.a"
        assert self.var_a.hash == "9bd514ad51fc33d895c50657acd0f0582301cf3e"

        target = pathlib.Path(self.tmpdir, self.var_a.name)
        assert not target.exists()

        request = _GetRequest(
            name="var_a",
            hash=self.var_a.hash,
            source="root.a",
            path=str(target),
            destination="root",
        )
        response = self.var_a._handle_get_request(self, request)

        # Handling the request must materialize the payload at ``target`` without error.
        assert target.exists()
        assert self.var_a.load(str(target)) == "something"
        assert not response.exception
|
|
|
|
|
|
def test_payload_in_run(tmpdir):
    """Call ``WorkRun.run`` directly; the assertions inside ``run`` do the checking."""
    work = WorkRun(str(tmpdir))
    # ``WorkRun.run`` compares the payload origin and request source against this name.
    work._name = "root.a"
    work.run()
|
|
|
|
|
|
class Sender(LightningWork):
    """Parallel work that publishes three Payload attributes when it runs."""

    def __init__(self):
        super().__init__(parallel=True)
        # Placeholders; the Payloads themselves may only be created inside ``run``.
        self.value_all = None
        self.value_b = None
        self.value_c = None

    def run(self):
        outputs = (
            ("value_all", ["A", "B", "C"]),
            ("value_b", "B"),
            ("value_c", "C"),
        )
        for attr_name, raw_value in outputs:
            setattr(self, attr_name, Payload(raw_value))
|
|
|
|
|
|
class WorkReceive(LightningWork):
    """Parallel work that checks the value carried by the Payload it receives.

    Args:
        expected: Value that ``generated.value`` must equal when ``run`` is called.
    """

    def __init__(self, expected):
        super().__init__(parallel=True)
        self.expected = expected

    def run(self, generated):
        received = generated.value
        assert received == self.expected
|
|
|
|
|
|
class Flow(LightningFlow):
    """Flow that forwards each ``Sender`` Payload to its matching ``WorkReceive``.

    The flow exits once every receiver has succeeded.
    """

    def __init__(self):
        super().__init__()
        self.sender = Sender()
        self.receiver_all = WorkReceive(["A", "B", "C"])
        self.receiver_b = WorkReceive("B")
        self.receiver_c = WorkReceive("C")

    def run(self):
        self.sender.run()

        # (receiver attribute, sender attribute it consumes), processed in this order.
        routes = (
            ("receiver_all", "value_all"),
            ("receiver_b", "value_b"),
            ("receiver_c", "value_c"),
        )
        for receiver_name, value_name in routes:
            # A receiver only runs once the sender has produced a truthy value.
            if getattr(self.sender, value_name):
                getattr(self, receiver_name).run(getattr(self.sender, value_name))

        if all(getattr(self, receiver_name).has_succeeded for receiver_name, _ in routes):
            self._exit()
|
|
|
|
|
|
def test_payload_works(tmpdir):
    """This test validates that the Payload API can be used to transfer return values from one work to another."""
    # A successful run is expected to leave one file per ``Sender`` Payload attribute in
    # the current working directory. Cleanup happens in ``finally`` so that a failing
    # run does not leave these files behind to pollute later tests.
    payload_files = ("value_all", "value_b", "value_c")
    try:
        with mock.patch("lightning_app.storage.path._storage_root_dir", lambda: pathlib.Path(tmpdir)):
            app = LightningApp(Flow(), log_level="debug")
            MultiProcessRuntime(app, start_server=False).dispatch()
        for filename in payload_files:
            assert os.path.exists(filename), f"expected payload file {filename!r} to be written"
    finally:
        for filename in payload_files:
            if os.path.exists(filename):
                os.remove(filename)
|