# lightning/tests/tests_app/utilities/test_proxies.py

import logging
import os
import pathlib
import sys
import time
import traceback
from copy import deepcopy
from queue import Empty
from unittest import mock
from unittest.mock import MagicMock, Mock
import pytest
from deepdiff import DeepDiff, Delta
from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage import Drive, Path
from lightning_app.storage.path import _artifacts_path
from lightning_app.storage.requests import _GetRequest
from lightning_app.testing.helpers import _MockQueue, EmptyFlow
from lightning_app.utilities.component import _convert_paths_after_init
from lightning_app.utilities.enum import AppStage, CacheCallsKeys, WorkFailureReasons, WorkStageStatus
from lightning_app.utilities.exceptions import CacheMissException, ExitAppException
from lightning_app.utilities.proxies import (
ComponentDelta,
LightningWorkSetAttrProxy,
persist_artifacts,
ProxyWorkRun,
WorkRunner,
WorkStateObserver,
)
logger = logging.getLogger(__name__)
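
# A minimal Work fixture: `run` flips a single counter, which is all the setattr
# proxy test below needs in order to observe a state change.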
class Work(LightningWork):
def __init__(self, cache_calls=True, parallel=True):
super().__init__(cache_calls=cache_calls, parallel=parallel)
self.counter = 0
def run(self):
self.counter = 1
return 1
def test_lightning_work_setattr():
    """This test validates that the `LightningWorkSetAttrProxy` pushes a delta to the `caller_queue` every time an
    attribute of the work state changes."""
w = Work()
# prepare
w._name = "root.b"
# create queue
caller_queue = _MockQueue("caller_queue")
def proxy_setattr():
w._setattr_replacement = LightningWorkSetAttrProxy(w._name, w, caller_queue, MagicMock())
proxy_setattr()
w.run()
assert len(caller_queue) == 1
work_proxy_output = caller_queue._queue[0]
assert isinstance(work_proxy_output, ComponentDelta)
assert work_proxy_output.id == w._name
assert work_proxy_output.delta.to_dict() == {"values_changed": {"root['vars']['counter']": {"new_value": 1}}}
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("cache_calls", [False, True])
@mock.patch("lightning_app.utilities.proxies._Copier", MagicMock())
@pytest.mark.skipif(sys.platform == "win32", reason="TODO (@ethanwharris): Fix this on Windows")
def test_work_runner(parallel, cache_calls, *_):
"""This test validates the `WorkRunner` runs the work.run method and properly populates the `delta_queue`,
`error_queue` and `readiness_queue`."""
class Work(LightningWork):
def __init__(self, cache_calls=True, parallel=True):
super().__init__(cache_calls=cache_calls, parallel=parallel)
self.counter = 0
self.dummy_path = "lit://test"
def run(self):
self.counter = 1
class Flow(LightningFlow):
def __init__(self):
super().__init__()
self.w = Work(cache_calls=cache_calls, parallel=parallel)
def run(self):
pass
class BlockingQueue(_MockQueue):
"""A Mock for the file copier queues that keeps blocking until we want to end the thread."""
keep_blocking = True
def get(self, timeout: int = 0):
while BlockingQueue.keep_blocking:
pass
# A dummy request so the Copier gets something to process without an error
return _GetRequest(source="src", name="dummy_path", path="test", hash="123", destination="dst")
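
    # Wire the runner up with one mock queue per channel; the blocking copier queues
    # keep the copier thread alive until it is released at the end of the test.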
app = LightningApp(Flow())
work = app.root.w
caller_queue = _MockQueue("caller_queue")
delta_queue = _MockQueue("delta_queue")
readiness_queue = _MockQueue("readiness_queue")
error_queue = _MockQueue("error_queue")
request_queue = _MockQueue("request_queue")
response_queue = _MockQueue("response_queue")
copy_request_queue = BlockingQueue("copy_request_queue")
copy_response_queue = BlockingQueue("copy_response_queue")
call_hash = "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c"
work._calls[call_hash] = {
"args": (),
"kwargs": {},
"call_hash": call_hash,
"run_started_counter": 1,
"statuses": [],
}
caller_queue.put(
{
"args": (),
"kwargs": {},
"call_hash": call_hash,
"state": work.state,
}
)
work_runner = WorkRunner(
work,
work.name,
caller_queue,
delta_queue,
readiness_queue,
error_queue,
request_queue,
response_queue,
copy_request_queue,
copy_response_queue,
)
try:
work_runner()
except (Empty, Exception):
pass
assert readiness_queue._queue[0]
if parallel:
assert isinstance(error_queue._queue[0], Exception)
else:
assert isinstance(error_queue._queue[0], Empty)
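
    # Expected deltas: the "running" status, the counter update from `run`, and the
    # `ret` entry recorded for the call; the exact number varies, hence the
    # `in [3, 4]` check below.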
assert len(delta_queue._queue) in [3, 4]
res = delta_queue._queue[0].delta.to_dict()["iterable_item_added"]
assert res[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "running"
assert delta_queue._queue[1].delta.to_dict() == {
"values_changed": {"root['vars']['counter']": {"new_value": 1}}
}
index = 3 if len(delta_queue._queue) == 4 else 2
res = delta_queue._queue[index].delta.to_dict()["dictionary_item_added"]
assert res[f"root['calls']['{call_hash}']['ret']"] is None
# Stop blocking and let the thread join
BlockingQueue.keep_blocking = False
work_runner.copier.join()
def test_pathlike_as_argument_to_run_method_warns(tmpdir):
"""Test that Lightning Produces a special warning for strings that look like paths."""
    # none of these values point to an existing file or folder, so no warning is expected
no_warning_expected = (
"looks/like/path",
pathlib.Path("looks/like/path"),
"i am not a path",
1,
Path("lightning/path"),
)
for path in no_warning_expected:
_pass_path_argument_to_work_and_test_warning(path=path, warning_expected=False)
# warn if it looks like a folder and the folder exists
_pass_path_argument_to_work_and_test_warning(path=tmpdir, warning_expected=True)
# warn if it looks like a string or pathlib Path and the file exists
file = pathlib.Path(tmpdir, "file_exists.txt")
file.write_text("test")
assert os.path.exists(file)
_pass_path_argument_to_work_and_test_warning(path=file, warning_expected=True)
_pass_path_argument_to_work_and_test_warning(path=str(file), warning_expected=True)
    # do not warn if the path is wrapped in a Lightning `Path` (and the file exists)
file = Path(tmpdir, "file_exists.txt")
file.write_text("test")
assert os.path.exists(file)
_pass_path_argument_to_work_and_test_warning(path=file, warning_expected=False)
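
# Helper: pass `path` as the only argument to a work's `run` through `ProxyWorkRun`
# and check whether the path-like warning is raised.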
def _pass_path_argument_to_work_and_test_warning(path, warning_expected):
class WarnRunPathWork(LightningWork):
def run(self, *args, **kwargs):
pass
class Flow(EmptyFlow):
def __init__(self):
super().__init__()
self.work = WarnRunPathWork()
flow = Flow()
work = flow.work
proxy_run = ProxyWorkRun(work.run, "some", work, Mock())
warn_ctx = pytest.warns(UserWarning, match="You passed a the value") if warning_expected else pytest.warns(None)
with warn_ctx as record:
with pytest.raises(CacheMissException):
proxy_run(path)
assert warning_expected or all("You passed a the value" not in str(msg.message) for msg in record)
class WorkTimeout(LightningWork):
def __init__(self):
super().__init__(parallel=True, start_with_flow=False)
self.counter = 0
def run(self):
self.counter += 1
class FlowTimeout(LightningFlow):
def __init__(self):
super().__init__()
self.counter = 0
self.work = WorkTimeout()
def run(self):
if not self.work.has_started:
self.work.run()
if self.work.has_timeout:
self.stop()
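
# A patched runner that, instead of executing the work, records a FAILED status with a
# TIMEOUT reason for every call it pulls from the caller queue, letting `FlowTimeout`
# exercise its `has_timeout` branch.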
class WorkRunnerPatch(WorkRunner):
counter = 0
def __call__(self):
call_hash = "fe3fa0f"
while True:
try:
called = self.caller_queue.get()
self.work.set_state(called["state"])
state = deepcopy(self.work.state)
self.work._calls[call_hash]["statuses"].append(
{
"name": self.work.name,
"stage": WorkStageStatus.FAILED,
"reason": WorkFailureReasons.TIMEOUT,
"timestamp": time.time(),
"message": None,
}
)
self.delta_queue.put(
ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state, verbose_level=2)))
)
self.counter += 1
except Exception as e:
logger.error(traceback.format_exc())
self.error_queue.put(e)
raise ExitAppException
@mock.patch("lightning_app.runners.backends.mp_process.WorkRunner", WorkRunnerPatch)
def test_proxy_timeout():
app = LightningApp(FlowTimeout(), log_level="debug")
MultiProcessRuntime(app, start_server=False).dispatch()
call_hash = app.root.work._calls[CacheCallsKeys.LATEST_CALL_HASH]
assert len(app.root.work._calls[call_hash]["statuses"]) == 3
assert app.root.work._calls[call_hash]["statuses"][0]["stage"] == "pending"
assert app.root.work._calls[call_hash]["statuses"][1]["stage"] == "failed"
assert app.root.work._calls[call_hash]["statuses"][2]["stage"] == "stopped"
@mock.patch("lightning_app.utilities.proxies._Copier")
def test_path_argument_to_transfer(*_):
"""Test that any Lightning Path objects passed to the run method get transferred automatically (if they
exist)."""
class TransferPathWork(LightningWork):
def run(self, *args, **kwargs):
raise ExitAppException
work = TransferPathWork()
path1 = Path("exists-locally.txt")
path2 = Path("exists-remotely.txt")
path3 = Path("exists-nowhere.txt")
path1.get = Mock()
path2.get = Mock()
path3.get = Mock()
path1.exists_remote = Mock(return_value=False)
path2.exists_remote = Mock(return_value=True)
path3.exists_remote = Mock(return_value=False)
path1._origin = "origin"
path2._origin = "origin"
path3._origin = "origin"
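
    # Only path2 reports that it exists on the remote origin, so it is the only path
    # expected to be fetched.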
call = {
"args": (path1, path2),
"kwargs": {"path3": path3},
"call_hash": "any",
"state": {
"vars": {"_paths": {}, "_urls": {}},
"calls": {
CacheCallsKeys.LATEST_CALL_HASH: "any",
"any": {
"name": "run",
"call_hash": "any",
"use_args": False,
"statuses": [{"stage": "requesting", "message": None, "reason": None, "timestamp": 1}],
},
},
"changes": {},
},
}
caller_queue = _MockQueue()
caller_queue.put(call)
runner = WorkRunner(
work=work,
work_name="name",
caller_queue=caller_queue,
delta_queue=_MockQueue(),
readiness_queue=_MockQueue(),
error_queue=_MockQueue(),
request_queue=_MockQueue(),
response_queue=_MockQueue(),
copy_request_queue=_MockQueue(),
copy_response_queue=_MockQueue(),
)
try:
runner()
except ExitAppException:
pass
path1.exists_remote.assert_called_once()
path1.get.assert_not_called()
path2.exists_remote.assert_called_once()
path2.get.assert_called_once()
path3.exists_remote.assert_called()
path3.get.assert_not_called()
@pytest.mark.parametrize(
"origin,exists_remote,expected_get",
[
(None, False, False),
("root.work", True, False),
("root.work", False, False),
("origin", True, True),
],
)
@mock.patch("lightning_app.utilities.proxies._Copier")
def test_path_attributes_to_transfer(_, origin, exists_remote, expected_get):
"""Test that any Lightning Path objects passed to the run method get transferred automatically (if they
exist)."""
path_mock = Mock()
path_mock.origin_name = origin
path_mock.exists_remote = Mock(return_value=exists_remote)
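
    # The work exposes its `path` attribute through this mock so the test controls both
    # the reported origin and whether the path exists remotely.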
class TransferPathWork(LightningWork):
def __init__(self):
super().__init__()
self.path = Path("test-path.txt")
def run(self):
raise ExitAppException
def __getattr__(self, item):
if item == "path":
return path_mock
return super().__getattr__(item)
class Flow(LightningFlow):
def __init__(self):
super().__init__()
self.work = TransferPathWork()
def run(self):
self.work.run()
flow = Flow()
_convert_paths_after_init(flow)
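    # Populates `flow.work._paths`, which the fake call state below hands to the runner.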
call = {
"args": (),
"kwargs": {},
"call_hash": "any",
"state": {
"vars": {"_paths": flow.work._paths, "_urls": {}},
"calls": {
CacheCallsKeys.LATEST_CALL_HASH: "any",
"any": {
"name": "run",
"call_hash": "any",
"use_args": False,
"statuses": [{"stage": "requesting", "message": None, "reason": None, "timestamp": 1}],
},
},
"changes": {},
},
}
caller_queue = _MockQueue()
caller_queue.put(call)
runner = WorkRunner(
work=flow.work,
work_name=flow.work.name,
caller_queue=caller_queue,
delta_queue=_MockQueue(),
readiness_queue=_MockQueue(),
error_queue=_MockQueue(),
request_queue=_MockQueue(),
response_queue=_MockQueue(),
copy_request_queue=_MockQueue(),
copy_response_queue=_MockQueue(),
)
try:
runner()
except ExitAppException:
pass
assert path_mock.get.call_count == expected_get
def test_proxy_work_run_paths_replace_origin_lightning_work_by_their_name():
class Work(LightningWork):
def __init__(self):
super().__init__(parallel=True)
self.path = None
def run(self, path):
assert isinstance(path._origin, str)
class Flow(LightningFlow):
def __init__(self):
super().__init__()
self.w1 = Work()
self.w = Work()
def run(self):
pass
app = LightningApp(Flow())
work = app.root.w
caller_queue = _MockQueue("caller_queue")
app.root.w1.path = Path(__file__)
assert app.root.w1.path._origin == app.root.w1
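    # Before dispatch the path's origin is the Work instance itself; the proxy is
    # expected to replace it with the work's name in the queued call.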
ProxyWorkRun(work.run, work.name, work, caller_queue)(path=app.root.w1.path)
assert caller_queue._queue[0]["kwargs"]["path"]._origin == app.root.w1.name
def test_persist_artifacts(tmp_path):
"""Test that the `persist_artifacts` utility copies the artifacts that exist to the persistent storage."""
class ArtifactWork(LightningWork):
def __init__(self):
super().__init__()
self.file = None
self.folder = None
self.not_my_path = None
self.not_exists = None
def run(self):
# single file
self.file = Path(tmp_path, "file.txt")
self.file.write_text("single file")
# folder with files
self.folder = Path(tmp_path, "folder")
self.folder.mkdir()
Path(tmp_path, "folder", "file1.txt").write_text("file 1")
Path(tmp_path, "folder", "file2.txt").write_text("file 2")
# simulate a Path that was synced to this Work from another Work
self.not_my_path = Path(tmp_path, "external.txt")
self.not_my_path.touch()
self.not_my_path._origin = Mock()
self.not_exists = Path(tmp_path, "not-exists")
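
    # `not_exists` is never created on disk, so `persist_artifacts` should warn about
    # exactly one missing artifact; the path with a foreign origin should not be copied.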
work = ArtifactWork()
work._name = "root.work"
rel_tmpdir_path = Path(*tmp_path.parts[1:])
assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "file.txt")
assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "folder")
assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "not-exists")
work.run()
with pytest.warns(UserWarning, match="1 artifacts could not be saved because they don't exist"):
persist_artifacts(work)
assert os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "file.txt")
assert os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "folder")
assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "not-exists")
assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "external.txt")
def test_work_state_observer():
"""Tests that the WorkStateObserver sends deltas to the queue when state residuals remain that haven't been
handled by the setattr."""
class WorkWithoutSetattr(LightningWork):
def __init__(self):
super().__init__()
self.var = 1
self.list = []
self.dict = {"counter": 0}
def run(self, use_setattr=False, use_containers=False):
if use_setattr:
self.var += 1
if use_containers:
self.list.append(1)
self.dict["counter"] += 1
work = WorkWithoutSetattr()
delta_queue = _MockQueue()
observer = WorkStateObserver(work, delta_queue)
setattr_proxy = LightningWorkSetAttrProxy(
work=work,
work_name="work_name",
delta_queue=delta_queue,
state_observer=observer,
)
work._setattr_replacement = setattr_proxy
##############################
# 1. Simulate no state changes
##############################
work.run(use_setattr=False, use_containers=False)
assert len(delta_queue) == 0
############################
# 2. Simulate a setattr call
############################
work.run(use_setattr=True, use_containers=False)
# this is necessary only in this test where we simulate the calls
work._calls.clear()
work._calls.update({CacheCallsKeys.LATEST_CALL_HASH: None})
delta = delta_queue.get().delta.to_dict()
assert delta["values_changed"] == {"root['vars']['var']": {"new_value": 2}}
assert len(observer._delta_memory) == 1
# The observer should not trigger any deltas being sent and only consume the delta memory
assert len(delta_queue) == 0
observer.run_once()
assert len(delta_queue) == 0
assert not observer._delta_memory
################################
# 3. Simulate a container update
################################
work.run(use_setattr=False, use_containers=True)
assert len(delta_queue) == 0
assert not observer._delta_memory
observer.run_once()
observer.run_once() # multiple runs should not affect how many deltas are sent unless there are changes
delta = delta_queue.get().delta.to_dict()
assert delta["values_changed"] == {"root['vars']['dict']['counter']": {"new_value": 1}}
assert delta["iterable_item_added"] == {"root['vars']['list'][0]": 1}
##########################
# 4. Simulate both updates
##########################
work.run(use_setattr=True, use_containers=True)
    # this is necessary only in this test where we simulate the calls
work._calls.clear()
work._calls.update({CacheCallsKeys.LATEST_CALL_HASH: None})
delta = delta_queue.get().delta.to_dict()
assert delta == {"values_changed": {"root['vars']['var']": {"new_value": 3}}}
assert len(delta_queue) == 0
assert len(observer._delta_memory) == 1
observer.run_once()
delta = delta_queue.get().delta.to_dict()
assert delta["values_changed"] == {"root['vars']['dict']['counter']": {"new_value": 2}}
assert delta["iterable_item_added"] == {"root['vars']['list'][1]": 1}
assert len(delta_queue) == 0
assert not observer._delta_memory
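
# End-to-end observer check: the flow resets the work's `vars` and `counter` between
# two runs and expects the reset state to propagate back to the running work.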
class WorkState(LightningWork):
def __init__(self):
super().__init__(parallel=True)
self.vars = []
self.counter = 0
def run(self, *args):
for counter in range(1, 11):
self.vars.append(counter)
self.counter = counter
class FlowState(LightningFlow):
def __init__(self):
super().__init__()
self.w = WorkState()
self.counter = 1
def run(self):
self.w.run()
if self.counter == 1:
if len(self.w.vars) == 10 and self.w.counter == 10:
self.w.vars = []
self.w.counter = 0
self.w.run("")
self.counter = 2
elif self.counter == 2:
if len(self.w.vars) == 10 and self.w.counter == 10:
self.stop()
def test_state_observer():
app = LightningApp(FlowState())
MultiProcessRuntime(app, start_server=False).dispatch()
@pytest.mark.parametrize(
"environment, expected_ip_addr", [({}, "127.0.0.1"), ({"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5")]
)
def test_work_runner_sets_internal_ip(environment, expected_ip_addr):
"""Test that the WorkRunner updates the internal ip address as soon as the Work starts running."""
class Work(LightningWork):
def run(self):
pass
work = Work()
work_runner = WorkRunner(
work,
work.name,
caller_queue=_MockQueue("caller_queue"),
delta_queue=Mock(),
readiness_queue=Mock(),
error_queue=Mock(),
request_queue=Mock(),
response_queue=Mock(),
copy_request_queue=Mock(),
copy_response_queue=Mock(),
)
# Make a fake call
call_hash = "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c"
work._calls[call_hash] = {
"args": (),
"kwargs": {},
"call_hash": call_hash,
"run_started_counter": 1,
"statuses": [],
}
work_runner.caller_queue.put(
{
"args": (),
"kwargs": {},
"call_hash": call_hash,
"state": work.state,
}
)
with mock.patch.dict(os.environ, environment, clear=True):
work_runner.setup()
# The internal ip address only becomes available once the hardware is up / the work is running.
assert work.internal_ip == ""
try:
work_runner.run_once()
except Empty:
pass
assert work.internal_ip == expected_ip_addr
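
# Bi-directional updates: the flow bumps `counter` on the running work until it flips
# `finished`, while the work keeps incrementing `counter_2` and finally reports back
# with `counter = -1`.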
class WorkBi(LightningWork):
def __init__(self):
super().__init__(parallel=True)
self.finished = False
self.counter = 0
self.counter_2 = 0
def run(self):
while not self.finished:
self.counter_2 += 1
time.sleep(0.1)
self.counter = -1
time.sleep(1)
class FlowBi(LightningFlow):
def __init__(self):
super().__init__()
self.w = WorkBi()
def run(self):
self.w.run()
if not self.w.finished:
self.w.counter += 1
if self.w.counter > 3:
self.w.finished = True
if self.w.counter == -1 and self.w.has_succeeded:
self.stop()
def test_bi_directional_proxy():
app = LightningApp(FlowBi())
MultiProcessRuntime(app, start_server=False).dispatch()
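
# Forbidden bi-directional update: the flow writes into the running work's `d` dict
# while the work itself is mutating `counter`, which is expected to fail the app with
# a "forbidden operation" error.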
class WorkBi2(LightningWork):
def __init__(self):
super().__init__(parallel=True)
self.finished = False
self.counter = 0
self.d = {}
def run(self):
self.counter -= 1
while not self.finished:
self.counter -= 1
time.sleep(1)
class FlowBi2(LightningFlow):
def __init__(self):
super().__init__()
self.w = WorkBi2()
def run(self):
self.w.run()
if self.w.counter == 1:
self.w.d["self.w.counter"] = 0
if not self.w.finished:
self.w.counter += 1
def test_bi_directional_proxy_forbidden(monkeypatch):
mock = MagicMock()
monkeypatch.setattr(sys, "exit", mock)
app = LightningApp(FlowBi2())
MultiProcessRuntime(app, start_server=False).dispatch()
assert app.stage == AppStage.FAILED
assert "A forbidden operation to update the work" in str(app.exception)
class WorkDrive(LightningFlow):
def __init__(self, drive):
super().__init__()
self.drive = drive
self.path = Path("data")
def run(self):
pass
class FlowDrive(LightningFlow):
def __init__(self):
super().__init__()
self.data = Drive("lit://data")
self.counter = 0
def run(self):
if not hasattr(self, "w"):
self.w = WorkDrive(self.data)
self.counter += 1
def test_bi_directional_proxy_filtering():
app = LightningApp(FlowDrive())
app.root.run()
assert app._extract_vars_from_component_name(app.root.w.name, app.state) == {}