import gradio as gr
from PIL import Image
import torch
import torchvision

# Names of the selectable models shown in the UI.
# TODO: load the actual model weights here; `predict` currently returns a
# random placeholder and ignores the selected model.
MODELS_TYPE = ["ModelA", "ModelB", "ModelC"]

# The transforms hold no per-image state, so build them once rather than on
# every request.
_TO_TENSOR = torchvision.transforms.ToTensor()
_TO_PIL = torchvision.transforms.ToPILImage()


def predict(input_image, model_name):
    """Run DTM estimation on an uploaded image.

    Args:
        input_image: Image array passed in by ``gr.Image`` (presumably a numpy
            array, the component's default ``type``), or ``None`` when the
            user submits without uploading an image.
        model_name: One of ``MODELS_TYPE``. Currently unused by the
            placeholder prediction below.

    Returns:
        A ``PIL.Image.Image`` with the same height and width as the input.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a crash.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # This way grayscale (H, W) and RGBA (H, W, 4) inputs are not
    # misinterpreted, which forcing mode 'RGB' would do.
    pil_image = Image.fromarray(input_image.astype("uint8")).convert("RGB")

    # ToTensor yields a float tensor scaled to [0, 1].
    torch_image = _TO_TENSOR(pil_image)

    with torch.no_grad():
        # TODO: replace with the selected model's forward pass.
        prediction = torch.rand(torch_image.shape)

    # Convert the (C, H, W) tensor back to a PIL image for display.
    return _TO_PIL(prediction)


iface = gr.Interface(
    fn=predict,
    inputs=[
        # NOTE: `shape=` is a Gradio 3.x argument; it was removed in Gradio 4.
        gr.Image(shape=(512, 512)),
        # `gr.inputs.*` is the deprecated legacy namespace. A default value
        # avoids calling `predict` with model_name=None.
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0], label="Model"),
    ],
    outputs=gr.Image(shape=(512, 512)),
    examples=[
        ["demo_imgs/fake.jpg", MODELS_TYPE[0]],  # TODO: use a real image
    ],
    title="DTM Estimation",
    description="This demo predicts a DTM...",
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    iface.launch()