lightning/tests/tests_app/storage/test_path.py

723 lines
25 KiB
Python

import json
import os
import pathlib
import pickle
from re import escape
from time import sleep
from unittest import mock, TestCase
from unittest.mock import MagicMock, Mock
import pytest
from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage.path import (
_is_s3fs_available,
artifacts_path,
filesystem,
is_lit_path,
Path,
shared_storage_path,
storage_root_dir,
)
from lightning_app.storage.requests import ExistsResponse, GetResponse
from lightning_app.testing.helpers import EmptyWork, MockQueue, RunIf
from lightning_app.utilities.app_helpers import LightningJSONEncoder
from lightning_app.utilities.component import _context
def test_path_instantiation():
    """Lightning Paths compare equal to pathlib Paths and inherit origin/consumer from a Lightning Path part."""
    expected_ab = pathlib.Path("a/b")
    assert Path() == pathlib.Path()
    assert Path("a/b") == expected_ab
    assert Path("a", "b") == pathlib.Path("a", "b")
    assert Path(pathlib.Path("a"), pathlib.Path("b")) == expected_ab
    assert Path(Path(Path("a/b"))) == expected_ab

    empty = Path()
    assert empty._origin is empty._consumer is empty._request_queue is empty._response_queue is None

    parent = Path("x/y/z")
    parent._origin = "origin"
    parent._consumer = "consumer"

    def _assert_inherited(child):
        # the child must carry the parent's origin and consumer
        assert child._origin == "origin"
        assert child._consumer == "consumer"

    # from parts where the first is a Lightning Path and the other(s) are string
    _assert_inherited(Path(parent, "file.txt"))
    # from parts that are instance of Path and have no origin
    _assert_inherited(Path(parent, Path("file.txt")))

    # from parts that are instance of Path and have a different origin than the top folder
    conflicting = Path("file.txt")
    conflicting._origin = "different"
    with pytest.raises(TypeError, match="Tried to instantiate a Lightning Path from multiple other Paths"):
        Path(parent, conflicting)

    # from parts that are instance of Path and have the SAME origin as the top folder
    matching = Path("file.txt")
    matching._origin = "origin"
    _assert_inherited(Path(parent, matching))
def test_path_instantiation_lit():
    """Paths using the ``lit://`` scheme resolve underneath the storage root directory."""
    root = storage_root_dir()
    assert Path("lit://") == root
    assert Path("lit://a/b") == pathlib.Path(root, "a/b")
    assert Path("lit://", "a", "b") == pathlib.Path(root, "a", "b")
    assert Path("lit://", pathlib.Path("a"), pathlib.Path("b")) == pathlib.Path(root, "a/b")
    assert Path(Path(Path("lit://a/b"))) == pathlib.Path(root, "a", "b")
    assert str(Path("lit://lit-path")) == os.path.join(root, "lit-path")
def test_is_lit_path():
    """Check which string and Path inputs ``is_lit_path`` recognizes as Lightning storage paths."""
    for rejected in ("lit", Path("lit")):
        assert not is_lit_path(rejected)
    for accepted in ("lit://", Path("lit://"), "lit://a/b/c", Path("lit://a/b/c"), storage_root_dir()):
        assert is_lit_path(accepted)
def test_path_copy():
    """Passing a Path instance to the constructor yields a copy with the same origin, consumer and queues."""
    source = Path("x/y/z")
    source._origin = "origin"
    source._consumer = "consumer"
    source._request_queue = Mock()
    source._response_queue = Mock()

    duplicate = Path(source)
    for attribute in ("_origin", "_consumer", "_request_queue", "_response_queue"):
        assert getattr(duplicate, attribute) == getattr(source, attribute)
def test_path_inheritance():
    """Test that the Lightning Path is a drop-in replacement for pathlib.Path without compromises."""
    file = Path("file.txt")
    pathlibfile = pathlib.Path("file.txt")
    assert file == pathlibfile
    assert isinstance(file, Path)
    assert isinstance(file, pathlib.Path)

    folder = Path("./x/y")
    file = folder / "file.txt"
    assert isinstance(file, Path)

    # ``with_suffix`` returns a new path object; check the returned value rather than the unchanged input
    file = file.with_suffix(".png")
    assert isinstance(file, Path)
def test_path_concatenation():
    """Test that path concatenations keep the properties of the Lightning Path operand, whether it is on the
    left-hand or the right-hand side of the ``/`` operator."""
    folder = Path("x/y/z")
    folder._origin = "origin"
    folder._consumer = "consumer"
    other = Path("other")

    # test __truediv__ when Path is on the left-hand side
    file = folder / other / "more" / "file.txt"
    assert file._origin == "origin"
    assert file._consumer == "consumer"

    # test __rtruediv__ when Path is on the right-hand side
    switched = pathlib.Path("/") / folder
    assert isinstance(switched, Path)
    # the properties must be checked on the result of the reflected join itself
    assert switched._origin == "origin"
    assert switched._consumer == "consumer"
def test_path_with_replacement():
    """The ``with_name``, ``with_suffix`` and ``relative_to`` modifiers keep the origin and consumer."""
    folder = Path("x", "y", "z")
    folder._origin = "origin"
    folder._consumer = "consumer"

    def check(result, *expected_parts):
        # the string form must match and the Lightning properties must survive the modification
        assert str(result) == os.path.join(*expected_parts)
        assert result._origin == "origin"
        assert result._consumer == "consumer"

    renamed = folder.with_name("file.txt")
    check(renamed, "x", "y", "file.txt")
    check(renamed.with_suffix(".png"), "x", "y", "file.png")
    check(folder.relative_to("x"), "y", "z")
@RunIf(min_python="3.9")
def test_path_with_stem_replacement():
    """``Path.with_stem`` keeps the origin and consumer (``with_stem`` exists only on Python 3.9+)."""
    original = Path("x", "y", "file.txt")
    original._origin = "origin"
    original._consumer = "consumer"

    restemmed = original.with_stem("text")

    assert str(restemmed) == os.path.join("x", "y", "text.txt")
    assert (restemmed._origin, restemmed._consumer) == ("origin", "consumer")
def test_path_parents():
    """Test that the ``Path.parent`` and ``Path.parents`` properties return Paths that inherit the origin and
    consumer attributes."""
    path = Path("a", "b", "c", "d")
    path._origin = "origin"
    path._consumer = "consumer"

    # .parent
    parent = path.parent
    assert isinstance(parent, Path)
    assert str(parent) == os.path.join("a", "b", "c")
    assert parent._origin == "origin"
    assert parent._consumer == "consumer"

    # .parents
    parents = path.parents
    assert parents == [Path("a", "b", "c"), Path("a", "b"), Path("a"), Path(".")]
    assert all(ancestor._origin == "origin" for ancestor in parents)
    assert all(ancestor._consumer == "consumer" for ancestor in parents)
def test_path_hash():
    """Test that the value of the Path hash is a function of the path name and the origin."""
    # a path without origin has no hash
    assert Path("one").hash is Path("two").hash is None

    # identical paths with identical origins have the same hash, even when their consumers differ
    path1 = Path("one")
    path2 = Path("one")
    path1._origin = "origin1"
    path1._consumer = "consumer1"
    path2._origin = "origin1"
    path2._consumer = "consumer2"
    assert path1.hash == path2.hash

    # identical paths with different origins have different hash
    path2._origin = "origin2"
    assert path1.hash != path2.hash

    # different paths but same owner yields a different hash
    path1 = Path("one")
    path2 = Path("other")
    path1._origin = "same"
    path2._origin = "same"
    assert path1.hash != path2.hash
def test_path_pickleable():
    """Pickling round-trips the path, origin and consumer, while the attached queues come back as ``None``."""
    original = Path("a/b/c.txt")
    original._origin = "root.x.y.z"
    original._consumer = "root.p.q.r"
    original._request_queue = Mock()
    original._response_queue = Mock()

    restored = pickle.loads(pickle.dumps(original))

    assert isinstance(restored, Path)
    assert restored == original
    assert (restored._origin, restored._consumer) == (original._origin, original._consumer)
    assert restored._request_queue is None
    assert restored._response_queue is None
def test_path_json_serializable():
    """The Lightning JSON encoder writes the path string plus origin and consumer names."""
    path = Path("a/b/c.txt")
    path._origin = "root.x.y.z"
    path._consumer = "root.p.q.r"
    path._request_queue = Mock()
    path._response_queue = Mock()

    encoded = json.dumps(path, cls=LightningJSONEncoder)

    # JSON escapes backslashes, so Windows separators show up doubled in the dump
    escaped_path = str(path).replace("\\", "\\\\")
    for fragment in ("path", escaped_path, "origin_name", path._origin, "consumer_name", path._consumer):
        assert fragment in encoded
def test_path_to_dict_from_dict():
    """``to_dict``/``from_dict`` round-trip path, origin, consumer and metadata but not the queues."""
    path = Path("a/b/c.txt")
    path._origin = "root.x.y.z"
    path._consumer = "root.p.q.r"
    path._request_queue = Mock()
    path._response_queue = Mock()

    rebuilt = Path.from_dict(path.to_dict())

    assert rebuilt == path
    assert rebuilt._origin == path._origin
    assert rebuilt._consumer == path._consumer
    assert rebuilt._request_queue is None and rebuilt._response_queue is None
    assert rebuilt._metadata == path._metadata
def test_path_attach_work():
    """Test that attaching a path to a LightningWork will make the Work either the origin or a consumer.

    The first attached Work becomes the origin; every later one replaces the consumer.
    """
    path = Path()
    assert path._origin is None
    work1 = EmptyWork()
    work2 = EmptyWork()
    work3 = EmptyWork()

    path._attach_work(work=work1)
    assert path._origin is work1

    # path already has an owner
    path._attach_work(work=work2)
    assert path._origin is work1
    assert path._consumer is work2

    # path gets a new consumer
    path._attach_work(work=work3)
    assert path._origin is work1
    assert path._consumer is work3
def test_path_attach_queues():
    """``_attach_queues`` stores the exact request and response queue objects it receives."""
    path = Path()
    queues = {"request_queue": Mock(), "response_queue": Mock()}
    path._attach_queues(**queues)
    assert path._request_queue is queues["request_queue"]
    assert path._response_queue is queues["response_queue"]
@pytest.mark.parametrize("cls", [LightningFlow, LightningWork])
def test_path_in_flow_and_work(cls, tmpdir):
    """Test that Paths can be stored and re-assigned on both a LightningFlow and a LightningWork, and that only
    the Work sets itself as origin and consumer of its Paths."""

    class PathComponent(cls):
        def __init__(self):
            super().__init__()
            # NOTE(review): the first assignment is immediately overwritten; presumably this exercises
            # re-assigning a Path attribute during __init__ — confirm before removing
            self.path_one = Path("a", "b")
            self.path_one = Path("a", "b", "c")
            self.path_two = Path(tmpdir) / "write.txt"

        def run(self):
            self.path_one = self.path_one / "d.txt"
            assert self.path_one == Path("a", "b", "c", "d.txt")
            with open(self.path_two, "w") as file:
                file.write("Hello")

    class RootFlow(LightningFlow):
        def __init__(self):
            super().__init__()
            self.path_component = PathComponent()

        def run(self):
            self.path_component.run()

    root = RootFlow()
    _ = LightningApp(root)  # create an app to convert all paths that got attached
    root.run()

    assert root.path_component.path_one == Path("a", "b", "c", "d.txt")
    assert root.path_component.path_one == pathlib.Path("a", "b", "c", "d.txt")
    if isinstance(root.path_component, LightningWork):
        assert root.path_component.path_one.origin_name == "root.path_component"
        assert root.path_component.path_one.consumer_name == "root.path_component"
    else:
        assert root.path_component.path_one._origin is None
        assert root.path_component.path_one._consumer is None

    # read through a context manager so the file handle is closed deterministically
    with open(root.path_component.path_two) as file:
        assert file.readlines() == ["Hello"]
class SourceWork(LightningWork):
    """Work that owns ``<tmpdir>/src.txt`` and writes a greeting into it when run."""

    def __init__(self, tmpdir):
        super().__init__(cache_calls=True)
        self.path = Path(tmpdir, "src.txt")
        # presumably empty because the work has no name until it is attached to a flow
        assert self.path.origin_name == ""

    def run(self):
        with open(self.path, "w") as handle:
            handle.write("Hello from SourceWork")
class DestinationWork(LightningWork):
    """Work that receives the Path created by ``SourceWork`` and reads it back."""

    def __init__(self, source_path):
        super().__init__(cache_calls=True)
        source_origin = "root.src_work"
        assert source_path.origin_name == source_origin
        self.path = source_path
        # storing the Path on this work keeps the origin of the source work
        assert self.path.origin_name == source_origin
        self.other = Path("other")
        assert self.other.origin_name == ""

    def run(self):
        assert self.path.origin_name == "root.src_work"
        assert self.other.origin_name == "root.dst_work"
        # we are running locally, the file is already there (no transfer needed)
        self.path.get(overwrite=True)
        assert self.path.is_file()
        assert self.path.read_text() == "Hello from SourceWork"
class SourceToDestFlow(LightningFlow):
    """Flow that runs ``SourceWork`` and then ``DestinationWork`` on its Path, exiting once the latter succeeds."""

    def __init__(self, tmpdir):
        super().__init__()
        self.src_work = SourceWork(tmpdir)
        self.dst_work = DestinationWork(self.src_work.path)

    def run(self):
        source, destination = self.src_work, self.dst_work
        source.run()
        if source.has_succeeded:
            destination.run()
        if destination.has_succeeded:
            self._exit()
def test_multiprocess_path_in_work_and_flow(tmpdir):
    """A Path written by one work can be read by another work under the multiprocess runtime."""
    app = LightningApp(SourceToDestFlow(tmpdir), debug=True)
    MultiProcessRuntime(app, start_server=False).dispatch()
class DynamicSourceToDestFlow(LightningFlow):
    """Variant of ``SourceToDestFlow`` that creates its works lazily inside ``run()``."""

    def __init__(self, tmpdir):
        super().__init__()
        self.tmpdir = str(tmpdir)

    def run(self):
        if not hasattr(self, "src_work"):
            self.src_work = SourceWork(self.tmpdir)
        self.src_work.run()

        if self.src_work.has_succeeded:
            # the destination work is created only after the source work has succeeded
            if not hasattr(self, "dst_work"):
                self.dst_work = DestinationWork(self.src_work.path)
            self.dst_work.run()

        destination_done = hasattr(self, "dst_work") and self.dst_work.has_succeeded
        if destination_done:
            self._exit()
# FIXME(alecmerdler): This test is failing...
def test_multiprocess_path_in_work_and_flow_dynamic(tmpdir):
    """Same scenario as the static source/destination test, but with works created dynamically in ``run()``."""
    app = LightningApp(DynamicSourceToDestFlow(tmpdir))
    MultiProcessRuntime(app).dispatch()
class RunPathFlow(LightningFlow):
    """Flow that passes Paths with and without an origin to ``RunPathWork.run()`` in several argument shapes."""

    def __init__(self):
        super().__init__()
        self.src_work = PathSourceWork()
        self.run_work = RunPathWork(cache_calls=True)

    def run(self):
        source = self.src_work
        source.run()
        assert source.src_path_0.origin_name == "root.src_work"
        assert source.src_path_0.consumer_name == "root.src_work"

        # local paths are not attached to any Work
        local_path_0, local_path_1 = Path("local", "file_0.txt"), Path("local", "file_1.txt")
        assert local_path_0.origin_name is None
        assert local_path_0.consumer_name is None

        # TODO: support returning a path from run()
        self.run_work.run(
            source.src_path_0,
            local_path_0,
            (99, {"nested": local_path_1}),
            kwarg_path=local_path_1,
            nested_kwarg_path=["x", (source.src_path_1,)],
        )
        # NOTE(review): fixed one-second wait before exiting; the work's checks may not have completed yet
        sleep(1)
        self._exit()
class PathSourceWork(EmptyWork):
    """Empty work that owns the Paths ``src/file_0.txt`` and ``src/file_1.txt``."""

    def __init__(self):
        super().__init__()
        # assigned through setattr so the work's attribute hook attaches each Path exactly as a plain assignment would
        for index in range(2):
            setattr(self, f"src_path_{index}", Path("src", f"file_{index}.txt"))
class RunPathWork(LightningWork):
    """Work whose ``run()`` checks how Paths arrive when passed positionally, by keyword and nested in containers."""

    def run(self, src_path_0, local_path_0, nested_local_path, kwarg_path=None, nested_kwarg_path=None):
        # Paths that already had an origin before being passed in; this work is expected to become their consumer
        attached_paths = []

        # src_path_0 has an origin which must be preserved, this work becomes consumer
        assert str(src_path_0) == os.path.join("src", "file_0.txt")
        assert src_path_0.origin_name == "root.src_work"
        attached_paths.append(src_path_0)

        # local_path_0 had no origin: it arrives with neither an origin nor a consumer
        assert str(local_path_0) == os.path.join("local", "file_0.txt")
        assert local_path_0.origin_name is None
        assert local_path_0.consumer_name is None

        # nested_local_path is a nested container that contains a Path without an origin
        nested_local = nested_local_path[1]["nested"]
        assert str(nested_local) == os.path.join("local", "file_1.txt")
        assert nested_local.origin_name is None
        assert nested_local.consumer_name is None

        # keyword arguments can also contain Paths
        assert str(kwarg_path) == os.path.join("local", "file_1.txt")
        assert kwarg_path.origin_name is None
        assert kwarg_path.consumer_name is None

        # a Path with an origin nested inside a keyword argument
        nested_src = nested_kwarg_path[1][0]
        assert str(nested_src) == os.path.join("src", "file_1.txt")
        assert nested_src.origin_name == "root.src_work"
        attached_paths.append(nested_src)

        # These checks used to be bare ``all(...)`` expressions whose result was discarded. Paths without an
        # origin are asserted above to have no consumer, so only origin-bearing Paths are checked here.
        for attached in attached_paths:
            assert attached._request_queue == self._request_queue
            assert attached._response_queue == self._response_queue
            assert attached.consumer_name == self.name == "root.run_work"
def test_path_as_argument_to_run_method():
    """Test that Path objects can be passed to a Work's run() method positionally, by keyword and nested inside
    containers (see ``RunPathWork`` for the per-argument checks)."""
    app = LightningApp(RunPathFlow())
    MultiProcessRuntime(app, start_server=False).dispatch()
def test_path_get_errors(tmpdir):
    """Test that ``Path.get()`` called in the work context raises informative errors for unusable Paths.

    Only the ``get()`` call sits inside ``pytest.raises`` so that a failure during setup cannot satisfy the check.
    """
    with _context("work"):
        # neither queues nor a work attached
        with pytest.raises(
            RuntimeError, match="Trying to get the file .* but the path is not attached to a LightningApp"
        ):
            Path().get()

        # queues attached, but no work
        path = Path()
        path._attach_queues(Mock(), Mock())
        with pytest.raises(
            RuntimeError, match="Trying to get the file .* but the path is not attached to a LightningWork"
        ):
            path.get()

        # fully attached, but the target already exists locally and get() is called without overwrite
        path = Path(tmpdir)
        path._attach_queues(Mock(), Mock())
        path._attach_work(Mock())
        with pytest.raises(FileExistsError, match="The file or folder .* exists locally. Pass `overwrite=True"):
            path.get()
class SourceOverwriteWork(LightningWork):
    """Work that creates ``<tmpdir>/folder`` containing an empty ``file.txt``."""

    def __init__(self, tmpdir):
        super().__init__(raise_exception=True)
        self.path = Path(tmpdir, "folder")

    def run(self):
        folder = self.path
        folder.mkdir(parents=True, exist_ok=True)
        (folder / "file.txt").touch()
        assert folder.exists_local()
class DestinationOverwriteWork(LightningWork):
    """Work that fetches the source folder with ``overwrite=True`` and checks the resulting contents."""

    def __init__(self, source_path):
        super().__init__(raise_exception=True)
        self.path = source_path

    def run(self):
        target = self.path
        assert target.exists()
        with mock.patch("lightning_app.storage.path.shutil") as shutil_mock:
            target.get(overwrite=True)
            # overwrite=True is expected to call shutil.rmtree on the destination path
            shutil_mock.rmtree.assert_called_with(target)
        assert target.exists()
        assert (target / "file.txt").exists()
class OverwriteFolderFlow(LightningFlow):
    """Flow that runs ``SourceOverwriteWork`` and then ``DestinationOverwriteWork`` on the same folder Path."""

    def __init__(self, tmpdir):
        super().__init__()
        self.src_work = SourceOverwriteWork(tmpdir)
        self.dst_work = DestinationOverwriteWork(self.src_work.path)

    def run(self):
        source, destination = self.src_work, self.dst_work
        source.run()
        if source.has_succeeded:
            destination.run()
        if destination.has_succeeded:
            self._exit()
def test_path_get_overwrite(tmpdir):
    """Test that .get(overwrite=True) overwrites the entire directory and replaces all files."""
    app = LightningApp(OverwriteFolderFlow(tmpdir), debug=True)
    MultiProcessRuntime(app, start_server=False).dispatch()
def test_path_get_error_in_flow_context():
    """Calling ``Path.get()`` while in the flow context raises a RuntimeError."""
    expected_message = escape("`Path.get()` can only be called from within the `run()`")
    with pytest.raises(RuntimeError, match=expected_message), _context("flow"):
        Path().get()
def test_path_response_with_exception(tmpdir):
    """A transfer response that carries an exception makes ``Path.get()`` raise a RuntimeError."""
    target = tmpdir / "file.txt"
    request_queue, response_queue = MockQueue(), MockQueue()
    path = Path(target)
    path._attach_queues(request_queue, response_queue)
    path._origin = "origin"
    path._consumer = "consumer"

    # simulate that a response will come with an exception raised
    failed_response = GetResponse(
        source="origin",
        path=str(target),
        hash=path.hash,
        destination="consumer",
        exception=OSError("Something went wrong"),
        name="",
    )
    response_queue.put(failed_response)

    with pytest.raises(RuntimeError, match="An exception was raised while trying to transfer the contents at"):
        with _context("work"):
            path.get()
def test_path_response_not_matching_reqeuest(tmpdir):
    """Test that ``Path.get()`` rejects a response whose owner or whose hash does not match the request it sent."""
    request_queue = MockQueue()
    response_queue = MockQueue()
    path = Path(tmpdir / "file.txt")
    path._attach_queues(request_queue, response_queue)
    path._origin = "origin"
    path._consumer = "consumer"

    # simulate a response that has a different owner than the request had
    response = GetResponse(
        source="other_origin", path=str(tmpdir / "file.txt"), hash=path.hash, destination="consumer", name=""
    )
    response_queue.put(response)
    with pytest.raises(
        RuntimeError, match="Tried to get the file .* but received a response for a request it did not send."
    ):
        path.get()

    # simulate a response that has a different hash than the request had
    assert not response_queue
    # restore the matching owner so that the hash is the only mismatch being tested
    response.source = "origin"
    response.path = str(path)
    response.hash = "other_hash"
    response_queue.put(response)
    with pytest.raises(
        RuntimeError, match="Tried to get the file .* but received a response for a request it did not send."
    ):
        path.get()
def test_path_exists(tmpdir):
    """Test that Path.exists() behaves as expected: first it checks whether the file exists locally and, if not,
    sends a message to the orchestrator to eventually check the existence on the origin Work."""
    # Local Path (no Work queues attached)
    assert not Path("file").exists()
    assert Path(tmpdir).exists()
    with open(tmpdir / "file", "w"):
        assert Path(tmpdir / "file").exists()

    # A local path that exists
    path = Path(tmpdir)
    path.exists_remote = Mock()
    path.exists_local = Mock(return_value=True)
    assert path.exists() is True
    path.exists_local.assert_called_once()
    path.exists_remote.assert_not_called()  # don't check remotely

    # A local path that does not exist, but has no origin Work set
    path = Path("not-exists.txt")
    path.exists_local = Mock(return_value=False)
    path.exists_remote = Mock()
    assert not path.exists()
    path.exists_local.assert_called_once()
    path.exists_remote.assert_not_called()  # don't check remotely

    # A local path that does not exist, but it exists remotely
    path = Path("exists-remotely-only.txt")
    path.exists_local = Mock(return_value=False)
    path.exists_remote = Mock(return_value=True)
    path._origin = "origin"
    assert path.exists()
    path.exists_local.assert_called_once()
    path.exists_remote.assert_called_once()  # check remotely
def test_path_exists_local(tmpdir):
    """``exists_local`` reflects what is present on the local filesystem."""
    assert not Path("file").exists_local()
    assert Path(tmpdir).exists_local()
    created = tmpdir / "file"
    with open(created, "w"):
        assert Path(created).exists_local()
def test_path_exists_remote(tmpdir):
    """``exists_remote`` requires an attached work and otherwise asks the orchestrator through the queues."""
    unattached = Path(tmpdir / "not-attached.txt")
    with pytest.raises(RuntimeError, match="the path is not attached to a LightningWork"):
        unattached.exists_remote()

    # If Path does not exist locally, ask the orchestrator
    request_queue = MockQueue()
    response_queue = MockQueue()
    path = Path(tmpdir / "not-exists.txt")
    path._attach_queues(request_queue, response_queue)
    path._origin = "origin"
    path._consumer = "consumer"

    # Put each response into the queue to simulate the orchestrator responding
    for remote_exists in (False, True):
        response_queue.put(
            ExistsResponse(source=path.origin_name, path=str(path), name="", hash="123", exists=remote_exists)
        )
        assert bool(path.exists_remote()) == remote_exists
        assert request_queue.get()
def test_artifacts_path():
    """A work's artifacts path is ``<shared storage>/artifacts/<work name>``."""
    work = Mock()
    work.name = "root.flow.work"
    expected = shared_storage_path() / "artifacts" / "root.flow.work"
    assert artifacts_path(work) == expected
@pytest.mark.skipif(not _is_s3fs_available(), reason="This test requires s3fs.")
@mock.patch.dict(
    os.environ,
    {
        "LIGHTNING_BUCKET_ENDPOINT_URL": "a",
        "LIGHTNING_BUCKET_NAME": "b",
        "LIGHTNING_AWS_ACCESS_KEY_ID": "c",
        "LIGHTNING_AWS_SECRET_ACCESS_KEY": "d",
        "LIGHTNING_CLOUD_APP_ID": "e",
    },
)
def test_filesystem(monkeypatch):
    """Test that ``filesystem()`` constructs the S3 filesystem from the ``LIGHTNING_*`` environment variables."""
    from lightning_app.storage import path

    # named so it does not shadow the module-level ``unittest.mock`` import
    s3fs_mock = MagicMock()
    monkeypatch.setattr(path, "S3FileSystem", s3fs_mock)
    filesystem()

    # the first recorded call is the constructor call; use the public ``mock_calls`` API instead of mock internals
    constructor_kwargs = s3fs_mock.mock_calls[0].kwargs
    assert constructor_kwargs["key"] == "c"
    assert constructor_kwargs["secret"] == "d"
    assert not constructor_kwargs["use_ssl"]
    assert constructor_kwargs["client_kwargs"] == {"endpoint_url": "a"}
class TestSharedStoragePath(TestCase):
    """Tests for ``shared_storage_path()`` under different environment variable configurations."""

    @mock.patch.dict(os.environ, {"LIGHTNING_STORAGE_PATH": "test-bucket/lightningapps/test-project/test-app"})
    def test_shared_storage_path_storage_path_set(self):
        expected = pathlib.Path("test-bucket", "lightningapps", "test-project", "test-app")
        self.assertEqual(expected, shared_storage_path())

    @mock.patch.dict(os.environ, {"LIGHTNING_CLOUD_APP_ID": "test-app", "LIGHTNING_BUCKET_NAME": "test-bucket"})
    def test_shared_storage_path_bucket_and_app_id_set(self):
        expected = pathlib.Path("test-bucket", "lightningapps", "test-app")
        self.assertEqual(expected, shared_storage_path())

    @mock.patch.dict(os.environ, {"SHARED_MOUNT_DIRECTORY": "test-app/.shared"})
    def test_shared_storage_path_mount_directory_set(self):
        resolved = shared_storage_path()
        self.assertTrue(resolved.match("*/test-app/.shared"))

    def test_shared_storage_path_no_envvars_set(self):
        resolved = shared_storage_path()
        self.assertTrue(resolved.match("*/.shared"))