2022-06-30 20:43:04 +00:00
|
|
|
import os
|
|
|
|
from unittest import mock
|
|
|
|
from unittest.mock import ANY
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"LIGHTING_TESTING": "1"})
@mock.patch("lightning_app.components.serve.gradio.gradio")
def test_serve_gradio(gradio_mock):
    """Exercise a ``ServeGradio`` subclass against a mocked ``gradio`` module.

    After ``run()`` the test checks three things: the component's ``model``
    is the value returned by ``build_model``, ``predict`` returns the
    subclass's value, and ``gradio.Interface`` was constructed exactly once
    with the expected keyword arguments.

    NOTE(review): the environment key is spelled ``LIGHTING_TESTING``
    (no "N"); presumably this matches what the component reads -- confirm
    before "fixing" the spelling.
    """
    # Imported inside the test body, i.e. only once both patches are active.
    from lightning_app.components.serve.gradio import ServeGradio

    class MyGradioServe(ServeGradio):
        """Minimal subclass that returns fixed sentinel values."""

        # Class attributes are built from the mock, so no real gradio
        # widgets are created.
        inputs = gradio_mock.inputs.Image(type="pil")
        outputs = gradio_mock.outputs.Image(type="pil")
        examples = [["./examples/app_components/serve/gradio/beyonce.png"]]

        def build_model(self):
            """Call the parent hook, then return a sentinel model."""
            super().build_model()
            return "model"

        def predict(self, *args, **kwargs):
            """Call the parent hook, then return a sentinel prediction."""
            super().predict(*args, **kwargs)
            return "prediction"

    served = MyGradioServe()
    served.run()

    assert served.model == "model"

    prediction = served.predict()
    assert prediction == "prediction"

    # Only title/description are pinned to concrete values; the callable and
    # the widget arguments are matched loosely with ANY.
    expected_interface_kwargs = {
        "fn": ANY,
        "inputs": ANY,
        "outputs": ANY,
        "examples": ANY,
        "title": None,
        "description": None,
    }
    gradio_mock.Interface.assert_called_once_with(**expected_interface_kwargs)
|