lightning/tests/tests_app/launcher/test_lightning_hydrid.py

15 lines
594 B
Python

from unittest import mock
from lightning.app import CloudCompute
from lightning.app.launcher.lightning_hybrid_backend import CloudHybridBackend
@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
def test_backend_selection(client_mock):
    """Check that ``CloudHybridBackend._get_backend`` picks a backend per work.

    A work whose ``cloud_compute`` is the default ``CloudCompute()`` is expected
    to map to ``backends["multiprocess"]``; one whose ``cloud_compute`` is
    ``CloudCompute("gpu")`` is expected to map to ``backends["cloud"]``.

    Args:
        client_mock: Injected by ``mock.patch``. It is not used directly. The
            patch replaces ``LightningClient`` in
            ``lightning.app.launcher.lightning_backend`` while the backend is
            built, presumably so no real API client is created (confirm).
    """
    cloud_backend = CloudHybridBackend("", queue_id="")

    # A MagicMock stands in for a real work object; the test only sets
    # ``cloud_compute`` on it before each selection.
    work = mock.MagicMock()

    # Default compute -> local multiprocess backend.
    work.cloud_compute = CloudCompute()
    assert cloud_backend._get_backend(work) == cloud_backend.backends["multiprocess"]

    # GPU compute -> cloud backend.
    work.cloud_compute = CloudCompute("gpu")
    assert cloud_backend._get_backend(work) == cloud_backend.backends["cloud"]