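# Gradio demo for Open-Vocabulary SAM: segment an image from point or box
# prompts and label the predicted mask with an LVIS class name.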
import gradio as gr
import numpy as np
from PIL import ImageDraw, Image
import torch
import torch.nn.functional as F
# mm libs
from mmdet.registry import MODELS
from mmengine import Config, print_log
from mmengine.structures import InstanceData
from ext.class_names.lvis_list import LVIS_CLASSES
LVIS_NAMES = LVIS_CLASSES
# Description
title = "<center><strong><font size='8'>Open-Vocabulary SAM</font></strong></center>"
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
model_cfg = Config.fromfile('app/configs/sam_r50x16_fpn.py')
examples = [
["app/assets/sa_01.jpg"],
["app/assets/sa_224028.jpg"],
["app/assets/sa_227490.jpg"],
["app/assets/sa_228025.jpg"],
["app/assets/sa_234958.jpg"],
["app/assets/sa_235005.jpg"],
["app/assets/sa_235032.jpg"],
["app/assets/sa_235036.jpg"],
["app/assets/sa_235086.jpg"],
["app/assets/sa_235094.jpg"],
["app/assets/sa_235113.jpg"],
["app/assets/sa_235130.jpg"],
]
model = MODELS.build(model_cfg.model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
model = model.eval()
model.init_weights()
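# ImageNet pixel mean/std on the 0-255 scale, shaped (C, 1, 1) so they
# broadcast over (C, H, W) image tensors.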
mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]
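# Per-session state: the resized image, its cached backbone features, and the
# prompts (points / box corners) clicked so far. `available` is True while no
# image has been set yet.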
class IMGState:
def __init__(self):
self.img = None
self.img_feat = None
self.selected_points = []
self.selected_points_labels = []
self.selected_bboxes = []
self.available_to_set = True
def set_img(self, img, img_feat):
self.img = img
self.img_feat = img_feat
self.available_to_set = False
def clear(self):
self.img = None
self.img_feat = None
self.selected_points = []
self.selected_points_labels = []
self.selected_bboxes = []
self.available_to_set = True
def clean(self):
self.selected_points = []
self.selected_points_labels = []
self.selected_bboxes = []
def to_device(self, device=device):
if self.img_feat is not None:
for k in self.img_feat:
if isinstance(self.img_feat[k], torch.Tensor):
self.img_feat[k] = self.img_feat[k].to(device)
elif isinstance(self.img_feat[k], tuple):
self.img_feat[k] = tuple(v.to(device) for v in self.img_feat[k])
@property
def available(self):
return self.available_to_set
IMG_SIZE = 1024
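# Record a clicked point as a positive prompt ('Add Mask' is the only label
# this demo uses) and draw it on the displayed image.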
def get_points_with_draw(image, img_state, evt: gr.SelectData):
label = 'Add Mask'
x, y = evt.index[0], evt.index[1]
print_log(f"Point: {x}_{y}", logger='current')
point_radius, point_color = 10, (97, 217, 54) if label == "Add Mask" else (237, 34, 13)
img_state.selected_points.append([x, y])
img_state.selected_points_labels.append(1 if label == "Add Mask" else 0)
draw = ImageDraw.Draw(image)
draw.ellipse(
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
fill=point_color,
)
return img_state, image
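# Collect box corners two clicks at a time; a third click starts a new box on
# a fresh copy of the image. Each corner is drawn as a dot, and the rectangle
# is drawn once both corners are set.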
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
x, y = evt.index[0], evt.index[1]
point_radius, point_color, box_outline = 5, (237, 34, 13), 2
box_color = (237, 34, 13)
if len(img_state.selected_bboxes) in [0, 1]:
img_state.selected_bboxes.append([x, y])
elif len(img_state.selected_bboxes) == 2:
img_state.selected_bboxes = [[x, y]]
image = Image.fromarray(img_state.img)
else:
raise ValueError(f"Cannot be {len(img_state.selected_bboxes)}")
print_log(f"box_list: {img_state.selected_bboxes}", logger='current')
draw = ImageDraw.Draw(image)
draw.ellipse(
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
fill=point_color,
)
if len(img_state.selected_bboxes) == 2:
box_points = img_state.selected_bboxes
bbox = (min(box_points[0][0], box_points[1][0]),
min(box_points[0][1], box_points[1][1]),
max(box_points[0][0], box_points[1][0]),
max(box_points[0][1], box_points[1][1]),
)
draw.rectangle(
bbox,
outline=box_color,
width=box_outline
)
return img_state, image
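# Decode a mask from the accumulated point prompts and overlay it on the
# image, returning the top-1 LVIS label with its score. The cached features
# are moved to the GPU for the forward pass and back to CPU afterwards.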
def segment_with_points(
image,
img_state,
):
if img_state.available:
return None, None, "State Error, please try again."
output_img = img_state.img
h, w = output_img.shape[:2]
input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
prompts = InstanceData(
point_coords=input_points[None],
)
try:
img_state.to_device()
masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
img_state.to_device('cpu')
masks = masks[0, 0, :h, :w]
masks = masks > 0.5
cls_pred = cls_pred[0][0]
scores, indices = torch.topk(cls_pred, 1)
scores, indices = scores.tolist(), indices.tolist()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
img_state.clear()
print_log(f"CUDA OOM! please try again later", logger='current')
return None, None, "CUDA OOM, please try again later."
else:
raise
names = []
for ind in indices:
names.append(LVIS_NAMES[ind].replace('_', ' '))
cls_info = ""
for name, score in zip(names, scores):
cls_info += "{} ({:.2f})".format(name, score)
rgb_shape = tuple(list(masks.shape) + [3])
color = np.zeros(rgb_shape, dtype=np.uint8)
color[masks] = np.array([97, 217, 54])
# color[masks] = np.array([217, 90, 54])
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
output_img = Image.fromarray(output_img)
return image, output_img, cls_info
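# Same as segment_with_points, but prompts the mask decoder with the box
# spanned by the two clicked corners.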
def segment_with_bbox(
image,
img_state
):
if img_state.available:
return None, None, "State Error, please try again."
if len(img_state.selected_bboxes) != 2:
return image, None, ""
output_img = img_state.img
h, w = output_img.shape[:2]
box_points = img_state.selected_bboxes
bbox = (
min(box_points[0][0], box_points[1][0]),
min(box_points[0][1], box_points[1][1]),
max(box_points[0][0], box_points[1][0]),
max(box_points[0][1], box_points[1][1]),
)
input_bbox = torch.tensor(bbox, dtype=torch.float32, device=device)
prompts = InstanceData(
bboxes=input_bbox[None],
)
try:
img_state.to_device()
masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
img_state.to_device('cpu')
masks = masks[0, 0, :h, :w]
masks = masks > 0.5
cls_pred = cls_pred[0][0]
scores, indices = torch.topk(cls_pred, 1)
scores, indices = scores.tolist(), indices.tolist()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
img_state.clear()
print_log(f"CUDA OOM! please try again later", logger='current')
return None, None, "CUDA OOM, please try again later."
else:
raise
names = []
for ind in indices:
names.append(LVIS_NAMES[ind].replace('_', ' '))
cls_info = ""
for name, score in zip(names, scores):
cls_info += "{} ({:.2f})\n".format(name, score)
rgb_shape = tuple(list(masks.shape) + [3])
color = np.zeros(rgb_shape, dtype=np.uint8)
color[masks] = np.array([97, 217, 54])
# color[masks] = np.array([217, 90, 54])
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
output_img = Image.fromarray(output_img)
return image, output_img, cls_info
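# Resize the uploaded image so its longer side is IMG_SIZE, normalize it with
# the ImageNet statistics, zero-pad it to an IMG_SIZE x IMG_SIZE square, and
# cache the backbone features so later prompts only run the mask decoder.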
def extract_img_feat(img, img_state):
w, h = img.size
scale = IMG_SIZE / max(w, h)
new_w = int(w * scale)
new_h = int(h * scale)
img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
img_numpy = np.array(img)
print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current')
try:
img_tensor = torch.tensor(img_numpy, device=device, dtype=torch.float32).permute((2, 0, 1))[None]
img_tensor = (img_tensor - mean) / std
img_tensor = F.pad(img_tensor, (0, IMG_SIZE - new_w, 0, IMG_SIZE - new_h), 'constant', 0)
feat_dict = model.extract_feat(img_tensor)
img_state.set_img(img_numpy, feat_dict)
img_state.to_device('cpu')
print_log(f"Successfully generated the image feats.", logger='current')
except RuntimeError as e:
if "CUDA out of memory" in str(e):
img_state.clear()
print_log(f"CUDA OOM! please try again later", logger='current')
return None, None, "CUDA OOM, please try again later."
else:
raise
    return img, None, "Please click on the image to add a prompt."
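# Reset the session entirely: drop the image, its features, and all prompts.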
def clear_everything(img_state):
img_state.clear()
    return img_state, None, None, "Please click on the image to add a prompt."
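# Keep the image and its cached features but forget all clicked prompts.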
def clean_prompts(img_state):
img_state.clean()
if img_state.img is None:
img_state.clear()
        return img_state, None, None, "Please click on the image to add a prompt."
    return img_state, Image.fromarray(img_state.img), None, "Please click on the image to add a prompt."
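# Build the UI: two tabs (point mode / box mode), each with an input image, a
# segmentation view, example images, and clean/restart buttons, then wire the
# upload, click, and clear events to the handlers above.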
def register_point_mode():
img_state_points = gr.State(value=IMGState())
img_state_bbox = gr.State(value=IMGState())
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(title)
# Point mode tab
with gr.Tab("Point mode"):
with gr.Row(variant="panel"):
with gr.Column(scale=1):
cond_img_p = gr.Image(label="Input Image", height=512, type="pil")
with gr.Column(scale=1):
segm_img_p = gr.Image(label="Segment", interactive=False, height=512, type="pil")
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
clean_btn_p = gr.Button("Clean Prompts", variant="secondary")
clear_btn_p = gr.Button("Restart", variant="secondary")
with gr.Column():
cls_info = gr.Textbox("", label='Labels')
with gr.Row():
with gr.Column():
gr.Markdown("Try some of the examples below ⬇️")
gr.Examples(
examples=examples,
inputs=[cond_img_p, img_state_points],
outputs=[cond_img_p, segm_img_p, cls_info],
examples_per_page=12,
fn=extract_img_feat,
run_on_click=True,
cache_examples=False,
)
# box mode tab
with gr.Tab("Box mode"):
with gr.Row(variant="panel"):
with gr.Column(scale=1):
cond_img_bbox = gr.Image(label="Input Image", height=512, type="pil")
with gr.Column(scale=1):
segm_img_bbox = gr.Image(label="Segment", interactive=False, height=512, type="pil")
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
clean_btn_bbox = gr.Button("Clean Prompts", variant="secondary")
clear_btn_bbox = gr.Button("Restart", variant="secondary")
with gr.Column():
cls_info_bbox = gr.Textbox("", label='Labels')
with gr.Row():
with gr.Column():
gr.Markdown("Try some of the examples below ⬇️")
gr.Examples(
examples=examples,
inputs=[cond_img_bbox, img_state_bbox],
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox],
examples_per_page=12,
fn=extract_img_feat,
run_on_click=True,
cache_examples=False,
)
# extract image feature
cond_img_p.upload(
extract_img_feat,
[cond_img_p, img_state_points],
outputs=[cond_img_p, segm_img_p, cls_info]
)
cond_img_bbox.upload(
extract_img_feat,
[cond_img_bbox, img_state_bbox],
        outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
# get user added points
cond_img_p.select(
get_points_with_draw,
[cond_img_p, img_state_points],
outputs=[img_state_points, cond_img_p]
).then(
segment_with_points,
inputs=[cond_img_p, img_state_points],
outputs=[cond_img_p, segm_img_p, cls_info]
)
cond_img_bbox.select(
get_bbox_with_draw,
[cond_img_bbox, img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox]
).then(
segment_with_bbox,
inputs=[cond_img_bbox, img_state_bbox],
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
# clean prompts
clean_btn_p.click(
clean_prompts,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
clean_btn_bbox.click(
clean_prompts,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
# clear
clear_btn_p.click(
clear_everything,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
cond_img_p.clear(
clear_everything,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
segm_img_p.clear(
clear_everything,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
clear_btn_bbox.click(
clear_everything,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
cond_img_bbox.clear(
clear_everything,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
segm_img_bbox.clear(
clear_everything,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
if __name__ == '__main__':
with gr.Blocks(css=css, title="Open-Vocabulary SAM") as demo:
register_point_mode()
demo.queue()
demo.launch(show_api=False)