15 lines
594 B
Python
15 lines
594 B
Python
|
from unittest import mock
|
||
|
|
||
|
from lightning.app import CloudCompute
|
||
|
from lightning.app.launcher.lightning_hybrid_backend import CloudHybridBackend
|
||
|
|
||
|
|
||
|
@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
def test_backend_selection(client_mock):
    """Check which backend ``CloudHybridBackend._get_backend`` returns for a work.

    The work is a ``MagicMock`` whose ``cloud_compute`` attribute is swapped
    between two ``CloudCompute`` configurations; each selection is compared
    against the matching entry of the backend's ``backends`` mapping.

    ``client_mock`` is the patched ``LightningClient`` injected by
    ``mock.patch``; it is not used directly in the body.
    NOTE(review): presumably the patch keeps the backend from creating a
    real client during construction -- confirm against ``lightning_backend``.
    """
    hybrid = CloudHybridBackend("", queue_id="")
    fake_work = mock.MagicMock()

    # CloudCompute built with no arguments -> expect the "multiprocess" entry.
    fake_work.cloud_compute = CloudCompute()
    selected = hybrid._get_backend(fake_work)
    assert selected == hybrid.backends["multiprocess"]

    # CloudCompute("gpu") -> expect the "cloud" entry.
    fake_work.cloud_compute = CloudCompute("gpu")
    selected = hybrid._get_backend(fake_work)
    assert selected == hybrid.backends["cloud"]
|