83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
import pytest
|
|
|
|
from lightning_app import CloudCompute
|
|
from lightning_app.storage import Mount
|
|
|
|
|
|
def test_cloud_compute_names():
|
|
assert CloudCompute().name == "default"
|
|
assert CloudCompute("cpu-small").name == "cpu-small"
|
|
assert CloudCompute("coconut").name == "coconut" # the backend is responsible for validation of names
|
|
|
|
|
|
def test_cloud_compute_shared_memory():
|
|
cloud_compute = CloudCompute("gpu", shm_size=1100)
|
|
assert cloud_compute.shm_size == 1100
|
|
|
|
cloud_compute = CloudCompute("gpu")
|
|
assert cloud_compute.shm_size == 1024
|
|
|
|
cloud_compute = CloudCompute("cpu")
|
|
assert cloud_compute.shm_size == 0
|
|
|
|
|
|
def test_cloud_compute_with_mounts():
|
|
mount_1 = Mount(source="s3://foo/", mount_path="/foo")
|
|
mount_2 = Mount(source="s3://foo/bar/", mount_path="/bar")
|
|
|
|
cloud_compute = CloudCompute("gpu", mounts=mount_1)
|
|
assert cloud_compute.mounts == mount_1
|
|
|
|
cloud_compute = CloudCompute("gpu", mounts=[mount_1, mount_2])
|
|
assert cloud_compute.mounts == [mount_1, mount_2]
|
|
|
|
cc_dict = cloud_compute.to_dict()
|
|
assert "mounts" in cc_dict
|
|
assert cc_dict["mounts"] == [
|
|
{"mount_path": "/foo", "source": "s3://foo/"},
|
|
{"mount_path": "/bar", "source": "s3://foo/bar/"},
|
|
]
|
|
|
|
assert CloudCompute.from_dict(cc_dict) == cloud_compute
|
|
|
|
|
|
def test_cloud_compute_with_non_unique_mount_root_dirs():
|
|
mount_1 = Mount(source="s3://foo/", mount_path="/foo")
|
|
mount_2 = Mount(source="s3://foo/bar/", mount_path="/foo")
|
|
|
|
with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"):
|
|
CloudCompute("gpu", mounts=[mount_1, mount_2])
|
|
|
|
|
|
def test_cloud_compute_clone():
|
|
c1 = CloudCompute("gpu")
|
|
c2 = c1.clone()
|
|
|
|
assert isinstance(c2, CloudCompute)
|
|
|
|
c1_dict = c1.to_dict()
|
|
c2_dict = c2.to_dict()
|
|
|
|
assert len(c1_dict) == len(c2_dict)
|
|
|
|
for k in c1_dict.keys():
|
|
if k == "_internal_id":
|
|
assert c1_dict[k] != c2_dict[k]
|
|
else:
|
|
assert c1_dict[k] == c2_dict[k]
|
|
|
|
|
|
def test_interruptible(monkeypatch):
|
|
"""Test interruptible can be enabled with env variables and for GPU only."""
|
|
with pytest.raises(ValueError, match="isn't supported yet"):
|
|
CloudCompute("gpu", interruptible=True)
|
|
|
|
monkeypatch.setenv("LIGHTNING_INTERRUPTIBLE_WORKS", "1")
|
|
with pytest.raises(ValueError, match="supported only with GPU"):
|
|
CloudCompute("cpu", interruptible=True)
|
|
|
|
cloud_compute = CloudCompute("gpu", interruptible=True)
|
|
assert hasattr(cloud_compute, "interruptible")
|
|
# TODO: To be removed once the platform is updated.
|
|
assert hasattr(cloud_compute, "preemptible")
|