DepthGAN / app.py
Harsimran19's picture
Update app.py
a092fbb
raw
history blame
1.5 kB
from transformers import GLPNFeatureExtractor, GLPNForDepthEstimation
import torch
import numpy as np
from PIL import Image
import requests
import gradio as gr
import os
# GLPN image processor and depth-estimation model, loaded once at import time
# so every call to ``predict`` reuses the same weights.
feature_extractor = GLPNFeatureExtractor.from_pretrained("vinvino02/glpn-nyu")
model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-nyu")

# File extensions accepted as example images (compared case-insensitively).
_EXAMPLE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _list_examples(directory="examples"):
    """Return Gradio example rows for the image files in ``directory``.

    Each row is a one-element list holding the file path, matching the
    interface's single image input. Files are sorted so the examples appear in
    a stable order, non-image files (e.g. ``.DS_Store``) are skipped, and a
    missing directory yields no examples instead of crashing at startup.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(_EXAMPLE_SUFFIXES)
    ]


example_list = _list_examples()
def predict(image):
    """Estimate a depth map for ``image`` and return it as a grayscale image.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A ``PIL.Image.Image`` with the same width and height as ``image``.
        Its pixels are the predicted depth scaled so that the largest value
        maps to 255. Returns ``None`` if no image was given.
    """
    if image is None:
        # Nothing to process (Gradio passes None for an empty image input).
        return None
    # The feature extractor expects 3-channel input; uploads may be RGBA,
    # grayscale or palette images.
    image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth

    # Resize the prediction back to the input resolution. PIL's ``size`` is
    # (width, height) while ``interpolate`` expects (height, width).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )

    # Index the batch/channel axes explicitly instead of ``squeeze()`` so a
    # 1-pixel-tall or 1-pixel-wide input still yields a 2-D array.
    output = prediction[0, 0].cpu().numpy()
    peak = float(np.max(output))
    if peak > 0:
        scaled = output * (255.0 / peak)
    else:
        # Avoid dividing by zero (or by a non-positive peak), which would
        # produce NaN / garbage pixels after the uint8 cast.
        scaled = np.zeros_like(output)
    # Bicubic interpolation can overshoot outside the original value range;
    # clip so the uint8 cast cannot wrap around.
    formatted = np.clip(scaled, 0, 255).astype("uint8")
    return Image.fromarray(formatted)
# Gradio App
# The app runs GLPN depth estimation (see ``predict``), not segmentation, so the
# user-facing text describes that.
title = "Depth Estimation with GLPN"
description = (
    "Estimates a depth map for an input image using the GLPN "
    "depth-estimation model."
)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title=title,
    examples=example_list,
    description=description,
)

# Only start the server when run as a script, so importing this module (e.g.
# for testing or reload tooling that looks up ``demo``) has no side effects.
if __name__ == "__main__":
    demo.launch(debug=False)