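"""Gradio demo: zero-shot image segmentation with CLIPSeg.

Positive prompts select regions to keep; negative prompts carve regions
back out. Runs on a Hugging Face Space with the `spaces` GPU decorator.
"""
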
import gradio as gr
from PIL import Image
import torch
import numpy as np
import spaces
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

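# Load the CLIPSeg processor and model once at startup. Moving the model to
# CUDA here relies on the ZeroGPU `spaces` runtime intercepting the call at
# import time; on a plain GPU instance the .cuda() call simply runs eagerly.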
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").cuda()

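# Run CLIPSeg on a single (image, prompt) pair and return a soft mask in
# [0, 1] at the original image resolution.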
@spaces.GPU
def process_image(image, prompt):
    inputs = processor(
        text=prompt, images=image, return_tensors="pt"
    )
    inputs = {k: v.cuda() for k, v in inputs.items()}

    # predict
    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits

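    # sigmoid maps raw logits to per-pixel probabilities in [0, 1]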
    pred = torch.sigmoid(preds)
    mat = pred.squeeze().cpu().numpy()  # Squeeze to remove extra dimensions
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.resize(image.size)
    mask = np.array(mask)

    # normalize the mask to [0, 1], guarding against division by zero
    # when the mask is constant
    mask_min = mask.min()
    mask_max = mask.max()
    mask = (mask - mask_min) / max(mask_max - mask_min, 1e-8)
    return mask

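# Turn a comma-separated prompt string into a list of boolean masks by
# thresholding each prompt's soft mask.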
@spaces.GPU
def get_masks(prompts, img, threshold):
    # split on commas and drop empty entries so a blank prompt box yields no masks
    prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    masks = []
    for prompt in prompts:
        mask = process_image(img, prompt)
        mask = mask > threshold
        masks.append(mask)
    return masks

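# Compose all prompt masks into a single cutout: a pixel survives if any
# positive prompt selects it and no negative prompt does.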
@spaces.GPU
def extract_image(pos_prompts, neg_prompts, img, threshold):
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # combine each set of masks with a logical OR; fall back to an all-False
    # mask when a prompt box is left empty
    empty = np.zeros(img.size[::-1], dtype=bool)
    pos_mask = np.any(np.stack(positive_masks), axis=0) if positive_masks else empty
    neg_mask = np.any(np.stack(negative_masks), axis=0) if negative_masks else empty
    # keep pixels picked by a positive prompt and not excluded by a negative one
    final_mask = pos_mask & ~neg_mask

    # extract the final image
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask

title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo for CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. Upload an image, describe what to keep (and optionally what to ignore) as comma-separated prompts, then click 'Process'. Results appear in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"

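# Build the Gradio UI: image, prompts, and threshold on the left; the
# extracted image and its mask on the right.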
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )

            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            btn_process = gr.Button("Process")

        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")

    btn_process.click(
        extract_image,
        inputs=[
            positive_prompts,
            negative_prompts,
            input_image,
            input_slider_T,
        ],
        outputs=[output_image, output_mask],
    )

demo.launch()  # share=True is unnecessary (and ignored) when running on Spaces