# Last change by sadjava: "updated model" (commit a741f39)
# AUTOGENERATED! DO NOT EDIT! File to edit: ../app.ipynb.
# %% auto 0
# Public names exported by this module (one per line for easier diffs).
__all__ = [
    'device',
    'model',
    'MEAN',
    'STD',
    'transform',
    'image',
    'label',
    'examples',
    'intf',
    'to_img',
    'draw_image_with_bbox',
    'localize_dog',
]
# %% ../app.ipynb 2
from models import Model
import torch
from torchvision import transforms
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import cv2
# %% ../app.ipynb 3
# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the network and restore the trained weights. The checkpoint is mapped
# onto the CPU first so it loads on CPU-only hosts too; the model is moved to
# `device` afterwards and put into evaluation mode.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints
# (newer torch versions accept weights_only=True to restrict this).
model = Model()
model.load_state_dict(
    torch.load('model.pt', map_location=torch.device('cpu'))
)
model = model.to(device).eval()
# %% ../app.ipynb 4
# Per-channel normalization statistics, in the channel order of the RGB
# tensors produced by the preprocessing pipeline. The values match the widely
# used ImageNet statistics - presumably what the model was trained with.
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# %% ../app.ipynb 5
# Preprocessing applied to every input image before it reaches the model:
#   1. resize to a fixed 224x224 (an explicit (h, w) pair, so aspect ratio is
#      not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with MEAN / STD.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)
# %% ../app.ipynb 6
def to_img(inp, mean=None, std=None):
    """Undo per-channel normalization and rescale an image to [0, 255].

    Args:
        inp: Normalized image as a NumPy array; ``mean``/``std`` are
            broadcast against its last axis, so channels must be last.
        mean: Per-channel means used for normalization. Defaults to ``MEAN``.
        std: Per-channel standard deviations used for normalization.
            Defaults to ``STD``.

    Returns:
        A float ``np.ndarray`` with values clipped to ``[0, 255]``. Casting to
        ``uint8`` is left to the caller.
    """
    # Resolve defaults lazily so explicit arguments never touch the globals.
    mean = np.asarray(MEAN if mean is None else mean)
    std = np.asarray(STD if std is None else std)
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp * 255
def draw_image_with_bbox(im, shape, pred_bbox=None, pred_obj=1):
    """Render a normalized image tensor as a PIL image and outline a box on it.

    Args:
        im: Normalized image tensor in channel-first layout (as produced by
            ``transform``). It is transposed to channel-last, resized and
            de-normalized for display.
        shape: Target ``(width, height)`` in pixels. The image is resized to
            it and the box coordinates are scaled by it.
        pred_bbox: ``(x_center, y_center, width, height)``. Presumably these
            are fractions of the image size, since they are multiplied by
            ``shape``. ``None`` means no box is drawn.
        pred_obj: Score compared against 0.5. The box is drawn only when the
            score is above it.

    Returns:
        The resized ``PIL.Image``, with the box outlined in red when drawn.
    """
    width, height = shape
    # detach/cpu make .numpy() safe even if the tensor tracks grads or lives
    # on a GPU.
    arr = im.detach().cpu().numpy().transpose((1, 2, 0))
    arr = cv2.resize(arr, dsize=(width, height))
    image_with_bbox = Image.fromarray(to_img(arr).astype(np.uint8))

    # Explicit parentheses keep NaN scores in the "don't draw" branch.
    if pred_bbox is None or not (pred_obj > 0.5):
        return image_with_bbox

    xc, yc, w, h = (float(v) for v in pred_bbox)
    xmin = int((xc - w / 2) * width)
    ymin = int((yc - h / 2) * height)
    w = int(w * width)
    h = int(h * height)

    # Keep the outline one pixel inside the image border.
    x0, y0 = max(xmin, 1), max(ymin, 1)
    x1, y1 = min(xmin + w, width - 1), min(ymin + h, height - 1)
    # A box lying outside the image, or with negative size, clips to an
    # inverted rectangle, which Pillow rejects with ValueError; skip it.
    if x1 >= x0 and y1 >= y0:
        ImageDraw.Draw(image_with_bbox).rectangle((x0, y0, x1, y1), outline='red')
    return image_with_bbox
# %% ../app.ipynb 7
def localize_dog(im):
    """Predict a bounding box for a PIL image and draw it on the image.

    Args:
        im: ``PIL.Image`` from the Gradio input, or ``None`` when the input is
            empty.

    Returns:
        A ``PIL.Image`` at the input's original size, with the predicted box
        drawn when ``pred_label[0]`` exceeds 0.5 (see
        ``draw_image_with_bbox``). Returns ``None`` when no image was given.
    """
    if im is None:
        return None
    shape = im.size[:2]  # PIL reports (width, height)
    im = transform(im.convert('RGB'))
    # Pure inference: don't record an autograd graph for the forward pass.
    with torch.no_grad():
        pred_label, pred_bbox = model(im.unsqueeze(0).to(device))
    return draw_image_with_bbox(im, shape, pred_bbox[0], pred_label[0])
# %% ../app.ipynb 9
# Gradio >= 3 exposes components at the top level (gr.Image). The
# gr.inputs / gr.outputs namespaces were deprecated in 3.x and removed in 4.0,
# so use them only on old releases that lack gr.Image.
if hasattr(gr, "Image"):
    image = gr.Image(type="pil")
    label = gr.Image(type="pil")
else:
    image = gr.inputs.Image(type="pil")
    label = gr.outputs.Image(type="pil")
examples = ['1.jpg', '2.jpg', '3.jpg']
intf = gr.Interface(fn=localize_dog,
                    inputs=image,
                    outputs=label,
                    title='Dog localization',
                    examples=examples)
intf.launch()