lightning/tests/tests_app/utilities/packaging/test_cloud_compute.py

83 lines
2.6 KiB
Python
Raw Normal View History

import pytest
from lightning_app import CloudCompute
from lightning_app.storage import Mount
def test_cloud_compute_names():
    """CloudCompute keeps the name it is given and uses "default" when none is passed."""
    # Name validation is the backend's job, so arbitrary names such as "coconut" are accepted here.
    expectations = [
        (CloudCompute(), "default"),
        (CloudCompute("cpu-small"), "cpu-small"),
        (CloudCompute("coconut"), "coconut"),
    ]
    for compute, expected_name in expectations:
        assert compute.name == expected_name
def test_cloud_compute_shared_memory():
    """An explicit shm_size wins; otherwise "gpu" defaults to 1024 and "cpu" to 0."""
    cases = (
        ("gpu", {"shm_size": 1100}, 1100),
        ("gpu", {}, 1024),
        ("cpu", {}, 0),
    )
    for compute_name, extra_kwargs, expected_shm in cases:
        assert CloudCompute(compute_name, **extra_kwargs).shm_size == expected_shm
def test_cloud_compute_with_mounts():
    """Mounts can be a single Mount or a list, and survive a to_dict/from_dict round trip."""
    foo_mount = Mount(source="s3://foo/", mount_path="/foo")
    bar_mount = Mount(source="s3://foo/bar/", mount_path="/bar")

    # A lone Mount is kept as-is (compared directly against the Mount, not a list).
    single = CloudCompute("gpu", mounts=foo_mount)
    assert single.mounts == foo_mount

    multiple = CloudCompute("gpu", mounts=[foo_mount, bar_mount])
    assert multiple.mounts == [foo_mount, bar_mount]

    serialized = multiple.to_dict()
    assert "mounts" in serialized
    expected_mounts = [
        {"mount_path": "/foo", "source": "s3://foo/"},
        {"mount_path": "/bar", "source": "s3://foo/bar/"},
    ]
    assert serialized["mounts"] == expected_mounts

    # Deserializing the dict must give back an equal CloudCompute.
    restored = CloudCompute.from_dict(serialized)
    assert restored == multiple
def test_cloud_compute_with_non_unique_mount_root_dirs():
    """Two mounts that share a mount_path are rejected with a ValueError."""
    first = Mount(source="s3://foo/", mount_path="/foo")
    duplicate_path = Mount(source="s3://foo/bar/", mount_path="/foo")
    conflicting = [first, duplicate_path]

    with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"):
        CloudCompute("gpu", mounts=conflicting)
def test_cloud_compute_clone():
    """clone() returns a CloudCompute equal to the original except for a fresh ``_internal_id``."""
    original = CloudCompute("gpu")
    cloned = original.clone()
    assert isinstance(cloned, CloudCompute)

    original_dict = original.to_dict()
    cloned_dict = cloned.to_dict()
    # Compare key sets directly rather than lengths, so a key mismatch fails with a clear diff
    # instead of surfacing as a KeyError inside the loop below.
    assert original_dict.keys() == cloned_dict.keys()
    # Guard against a vacuous pass: the identity check below only runs if this key is present.
    assert "_internal_id" in original_dict

    for key, value in original_dict.items():
        if key == "_internal_id":
            # The clone must receive its own identity.
            assert cloned_dict[key] != value
        else:
            assert cloned_dict[key] == value
def test_interruptible(monkeypatch):
    """Test interruptible can be enabled with env variables and for GPU only."""
    # The first check expects the opt-in variable to be absent. Clear it explicitly so the
    # test does not depend on the ambient environment (e.g. a CI job that exports it).
    monkeypatch.delenv("LIGHTNING_INTERRUPTIBLE_WORKS", raising=False)
    with pytest.raises(ValueError, match="isn't supported yet"):
        CloudCompute("gpu", interruptible=True)

    monkeypatch.setenv("LIGHTNING_INTERRUPTIBLE_WORKS", "1")
    with pytest.raises(ValueError, match="supported only with GPU"):
        CloudCompute("cpu", interruptible=True)

    cloud_compute = CloudCompute("gpu", interruptible=True)
    # Check the flag's value; hasattr() would also pass if the flag had been reset to False.
    assert cloud_compute.interruptible
    # TODO: To be removed once the platform is updated.
    assert hasattr(cloud_compute, "preemptible")