# Source: lightning/examples/app_components/serve/gradio/app.py
# (file-viewer metadata: 54 lines, 1.4 KiB, Python)

from functools import partial
import gradio as gr
import requests
import torch
from PIL import Image
import lightning as L
from lightning.app.components import ServeGradio
# Credit to @akhaliq for his inspiring work.
# His original code is available at: https://huggingface.co/spaces/akhaliq/AnimeGANv2/blob/main/app.py
class AnimeGANv2UI(ServeGradio):
    """Gradio component that restyles an uploaded photo with AnimeGANv2.

    ``inputs``/``outputs`` declare a PIL image in and a PIL image out.
    ``examples`` points Gradio at ``elon.jpg``, a sample image downloaded from
    Wikimedia and written to the current working directory when this class
    body executes (i.e. at import time).
    """

    inputs = gr.inputs.Image(type="pil")
    outputs = gr.outputs.Image(type="pil")

    # The sample image is fetched while the class body runs, so every process
    # that imports this module performs the download.
    # - timeout: keeps a stalled connection from hanging the import forever.
    # - raise_for_status(): reports HTTP errors (e.g. a 403/404 page) as a clear
    #   HTTPError instead of a confusing PIL "cannot identify image" error.
    # - `with`: releases the connection once the image has been read.
    elon = "https://upload.wikimedia.org/wikipedia/commons/3/34/Elon_Musk_Royal_Society_%28crop2%29.jpg"
    with requests.get(elon, stream=True, timeout=30) as _response:
        _response.raise_for_status()
        img = Image.open(_response.raw)
        # Image.open is lazy; decode fully while the response is still open so
        # nothing below depends on the (soon closed) stream.
        img.load()
    # Don't leave the closed response object behind as a class attribute.
    del _response
    img.save("elon.jpg")
    examples = [["elon.jpg"]]

    def __init__(self):
        super().__init__()
        # Flipped to True at the end of build_model(); RootFlow.configure_layout
        # checks it before exposing this component as a tab.
        self.ready = False

    def predict(self, img):
        """Restyle a single image.

        ``self.model`` is presumably the callable returned by build_model(),
        stored by the ServeGradio base class — confirm against ServeGradio.

        Args:
            img: the input image (Gradio is configured with ``type="pil"``).

        Returns:
            Whatever the bound ``face2paint`` callable returns for ``img``.
        """
        return self.model(img=img)

    def build_model(self):
        """Load AnimeGANv2 from torch.hub on CPU and return an inference callable.

        Returns:
            The hub ``face2paint`` helper (configured with ``size=512``) with the
            ``generator`` model bound via ``functools.partial``, so callers only
            pass ``img=``.
        """
        repo = "AK391/animegan2-pytorch:main"
        model = torch.hub.load(repo, "generator", device="cpu")
        face2paint = torch.hub.load(repo, "face2paint", size=512, device="cpu")
        # Only reached once both hub loads succeeded.
        self.ready = True
        return partial(face2paint, model=model)
class RootFlow(L.LightningFlow):
    """Top-level flow that owns the AnimeGANv2 Gradio component and exposes it in the UI."""

    def __init__(self):
        super().__init__()
        self.demo = AnimeGANv2UI()

    def run(self):
        # Forward to the Gradio component's own run().
        self.demo.run()

    def configure_layout(self):
        """Return the UI tabs.

        Returns:
            An empty list until ``self.demo.ready`` is truthy; afterwards a
            single "Home" tab whose content is the demo component.
        """
        if not self.demo.ready:
            return []
        return [{"name": "Home", "content": self.demo}]
# Module-level application object wrapping RootFlow; presumably this is what the
# Lightning CLI picks up (e.g. `lightning run app app.py`) — confirm in the docs.
app = L.LightningApp(
    RootFlow(),
)