artelabsuper
get model selection field
1513566
raw
history blame
933 Bytes
import gradio as gr
from PIL import Image
import torchvision
import torch
# Names of the models offered by the selector in the interface below.
# This only lists the names; no model is loaded at this point.
MODELS_TYPE = [
    "ModelA",
    "ModelB",
    "ModelC",
]
def predict(input_image, model_name):
    """Produce the demo output image for an uploaded input image.

    Args:
        input_image: Image data as an array-like object exposing
            ``astype``; it is cast to ``uint8`` and read as RGB.
        model_name: Name of the model selected in the UI. Currently
            unused: the output below is random noise regardless of
            the selection.

    Returns:
        A PIL image built from a random tensor that has the same shape
        as the tensor obtained from the input image.

    Raises:
        ValueError: If ``input_image`` is ``None``.
    """
    if input_image is None:
        # Fail with an explicit message instead of an opaque
        # AttributeError on `.astype` (presumably what happens when the
        # form is submitted without an image -- confirm against Gradio).
        raise ValueError("No input image provided: please upload an image.")

    pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')

    # transform image to torch and do preprocessing
    torch_image = torchvision.transforms.ToTensor()(pil_image)

    # model predict
    # NOTE: placeholder -- no model is invoked; the "prediction" is
    # uniform random noise shaped like the input tensor.
    prediction = torch.rand(torch_image.shape)

    # transform torch to image and return it
    return torchvision.transforms.ToPILImage()(prediction)
# Build the web UI: an image plus a model selector go in, an image comes out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(shape=(512, 512)),
        # Use the top-level component, consistent with gr.Image above;
        # the gr.inputs namespace is deprecated.
        gr.Radio(choices=MODELS_TYPE),
    ],
    outputs=gr.Image(shape=(512, 512)),
    examples=[
        ["demo_imgs/fake.jpg", MODELS_TYPE[0]],  # TODO: use real image
    ],
    title="DTM Estimation",
    description="This demo predict a DTM...",
)

# Start the app as soon as the script is executed.
iface.launch()