# SEMat / app.py: Gradio demo for interactive image matting with SEMat.
import gradio as gr
from gradio_image_prompter import ImagePrompter
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
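
# Download the pretrained SEMat checkpoints (one per backbone variant) from the
# Hugging Face Hub and record each local weight path in model_choice.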
model_choice = {
    'SAM': None,
    'HQ-SAM': None,
    'SAM2': None
}
for model_type in model_choice.keys():
    model_choice[model_type] = hf_hub_download(repo_id="XiaRho/SEMat", filename=f"SEMat_{model_type}.pth", repo_type="model")
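

# Build a SEMat model from its LazyConfig and load the matching checkpoint
# downloaded above.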
def load_model(model_type='SAM2'):
    assert model_type in model_choice.keys()
    config_path = './configs/SEMat_{}.py'.format(model_type)
    cfg = LazyConfig.load(config_path)
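    # Clear the pretrained-SAM weight path in the config (the attribute name
    # differs between backbone configs); the full weights are loaded further
    # down from the SEMat checkpoint via DetectionCheckpointer.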
    if hasattr(cfg.model.sam_model, 'ckpt_path'):
        cfg.model.sam_model.ckpt_path = None
    else:
        cfg.model.sam_model.checkpoint = None

    model = instantiate(cfg.model)
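
    # If the config enables LoRA fine-tuning, attach the adapter layers before
    # loading the checkpoint, which presumably contains the trained LoRA weights.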
    if model.lora_rank is not None:
        model.init_lora()
    model.to(DEVICE)

    DetectionCheckpointer(model).load(model_choice[model_type])
    model.eval()
    return model, model_type
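

# Convert an ImagePrompter prompt into model inputs: resize the image so its
# longer side is 1024, pad to a 1024x1024 canvas, scale to [0, 1] CHW, and map
# the drawn box into the same coordinate frame.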
def transform_image_bbox(prompts):
    if len(prompts["points"]) != 1:
        raise gr.Error("Please input only one BBox.", duration=5)
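    # gradio_image_prompter encodes each prompt as [x1, y1, label1, x2, y2, label2];
    # a drawn box carries corner labels 2 and 3, which is checked here to reject
    # plain point prompts.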
    [[x1, y1, idx_3, x2, y2, idx_6]] = prompts["points"]
    if idx_3 != 2 or idx_6 != 3:
        raise gr.Error("Please input a BBox instead of a point.", duration=5)
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

    img = prompts["image"]
    ori_H, ori_W, _ = img.shape
    scale = 1024 * 1.0 / max(ori_H, ori_W)
    new_H, new_W = ori_H * scale, ori_W * scale
    new_W = int(new_W + 0.5)
    new_H = int(new_H + 0.5)

    img = cv2.resize(img, (new_W, new_H), interpolation=cv2.INTER_LINEAR)
    padding = np.zeros([1024, 1024, 3], dtype=img.dtype)
    padding[: new_H, : new_W, :] = img
    img = padding

    # img = img[:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) / 255.0
    img = img.transpose((2, 0, 1)).astype(np.float32) / 255.0

    [[x1, y1, _, x2, y2, _]] = prompts["points"]
    x1, y1, x2, y2 = int(x1 * scale + 0.5), int(y1 * scale + 0.5), int(x2 * scale + 0.5), int(y2 * scale + 0.5)
    bbox = np.clip(np.array([[x1, y1, x2, y2]]) * 1.0, 0, 1023.0)
    return img, bbox, (ori_H, ori_W), (new_H, new_W)

if __name__ == '__main__':
    model, model_type = load_model()
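
    # Gradio callback: swap in the requested SEMat variant if the dropdown
    # selection changed, then run one matting forward pass on the prompted image.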
    def inference_image(prompts, input_model_type):
        global model_type
        global model
        if input_model_type != model_type:
            gr.Info('Loading the {} version of SEMat.'.format(input_model_type), duration=5)
            _model, _ = load_model(input_model_type)
            model_type = input_model_type
            model = _model

        image, bbox, ori_H_W, pad_H_W = transform_image_bbox(prompts)
        input_data = {
            'image': torch.from_numpy(image)[None].to(model.device),
            'bbox': torch.from_numpy(bbox)[None].to(model.device),
        }
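
        # Inference only: preprocess_inputs normalizes the batch, then the
        # matting decoder predicts an alpha matte from the image and box prompt.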
        with torch.no_grad():
            inputs = model.preprocess_inputs(input_data)
            images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']
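
            # Choose the backbone conditioning input according to how this SEMat
            # variant was configured: a learned condition embedding, the raw bbox
            # prompt, or nothing.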
            if model.backbone_condition:
                condition_proj = model.condition_embedding(condition)
            elif model.backbone_bbox_prompt is not None or model.bbox_prompt_all_block is not None:
                condition_proj = bbox
            else:
                condition_proj = None

            low_res_masks, pred_alphas, pred_trimap, sam_hq_matting_token = model.forward_samhq_and_matting_decoder(images, bbox, condition_proj)
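
        # Crop the prediction back to the un-padded (resized) size, replicate the
        # single alpha channel to three channels, and convert to uint8 for display.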
        output_alpha = np.uint8(pred_alphas[0, 0][:pad_H_W[0], :pad_H_W[1], None].repeat(1, 1, 3).cpu().numpy() * 255)
        return output_alpha
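
    # Two-column UI: box-promptable input image on the left, predicted alpha
    # matte on the right, plus a dropdown to pick the trained SEMat variant.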
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=45):
                img_in = ImagePrompter(type='numpy', show_label=False, label="query image")
            with gr.Column(scale=45):
                img_out = gr.Image(type='pil', label="output")
        with gr.Row():
            with gr.Column(scale=45):
                input_model_type = gr.Dropdown(list(model_choice.keys()), value='SAM2', label="Trained SEMat Version")
            with gr.Column(scale=45):
                bt = gr.Button()
        bt.click(inference_image, inputs=[img_in, input_model_type], outputs=[img_out])

    demo.launch()