79 lines
2.2 KiB
Python
79 lines
2.2 KiB
Python
import base64
|
|
import multiprocessing as mp
|
|
import os
|
|
from unittest.mock import ANY, MagicMock
|
|
|
|
import pytest
|
|
from tests_app import _PROJECT_ROOT
|
|
|
|
from lightning_app.components.serve import serve
|
|
from lightning_app.utilities.imports import _is_numpy_available, _is_torch_available
|
|
from lightning_app.utilities.network import _configure_session, find_free_network_port
|
|
|
|
if _is_numpy_available():
|
|
import numpy as np
|
|
|
|
if _is_torch_available():
|
|
import torch
|
|
|
|
|
|
class ImageServer(serve.ModelInferenceAPI):
|
|
def build_model(self):
|
|
return lambda x: x
|
|
|
|
def predict(self, image):
|
|
image = self.model(image)
|
|
return torch.from_numpy(np.asarray(image))
|
|
|
|
|
|
def target_fn(port, workers):
|
|
image_server = ImageServer(input="image", output="image", port=port, workers=workers)
|
|
image_server.run()
|
|
|
|
|
|
@pytest.mark.skipif(not (_is_torch_available() and _is_numpy_available()), reason="Missing torch and numpy")
|
|
@pytest.mark.parametrize("workers", [0])
|
|
def test_model_inference_api(workers):
|
|
|
|
port = find_free_network_port()
|
|
process = mp.Process(target=target_fn, args=(port, workers))
|
|
process.start()
|
|
|
|
image_path = os.path.join(_PROJECT_ROOT, "docs/source-app/_static/images/logo.png")
|
|
with open(image_path, "rb") as f:
|
|
imgstr = base64.b64encode(f.read()).decode("UTF-8")
|
|
|
|
session = _configure_session()
|
|
res = session.post(f"http://127.0.0.1:{port}/predict", params={"data": imgstr})
|
|
process.terminate()
|
|
# TODO: Investigate why this doesn't match exactly `imgstr`.
|
|
assert res.json()
|
|
|
|
|
|
class EmptyServer(serve.ModelInferenceAPI):
|
|
def build_model(self):
|
|
return lambda x: x
|
|
|
|
def serialize(self, x):
|
|
return super().serialize(x)
|
|
|
|
def deserialize(self, x):
|
|
return super().deserialize(x)
|
|
|
|
def predict(self, x):
|
|
return super().predict(x)
|
|
|
|
|
|
def test_model_inference_api_mock(monkeypatch):
|
|
|
|
monkeypatch.setattr(serve, "uvicorn", MagicMock())
|
|
comp = EmptyServer()
|
|
comp.run()
|
|
serve.uvicorn.run.assert_called_once_with(app=ANY, host=comp.host, port=comp.port, log_level="error")
|
|
|
|
with pytest.raises(Exception, match="Only input in"):
|
|
EmptyServer(input="something")
|
|
|
|
with pytest.raises(Exception, match="Only output in"):
|
|
EmptyServer(output="something")
|