taesiri committed
Commit 2033a9a
1 Parent(s): e0a5369

Update app.py

Files changed (1)
  1. app.py +11 -13
app.py CHANGED
@@ -1,19 +1,20 @@
-from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
 import gradio as gr
 from PIL import Image
 import torch
 import matplotlib.pyplot as plt
-import torch
 import numpy as np
+import spaces
+from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
 
 processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
-model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-
+model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").cuda()
 
+@spaces.GPU
 def process_image(image, prompt):
     inputs = processor(
         text=prompt, images=image, padding="max_length", return_tensors="pt"
     )
+    inputs = {k: v.cuda() for k, v in inputs.items()}
 
     # predict
     with torch.no_grad():
@@ -33,22 +34,22 @@ def process_image(image, prompt):
     mask = (mask - mask_min) / (mask_max - mask_min)
     return mask
 
-
-def get_masks(prompts, img, threhsold):
+@spaces.GPU
+def get_masks(prompts, img, threshold):
     prompts = prompts.split(",")
     masks = []
     for prompt in prompts:
         mask = process_image(img, prompt)
-        mask = mask > threhsold
+        mask = mask > threshold
         masks.append(mask)
     return masks
 
-
-def extract_image(pos_prompts, neg_prompts, img, threhsold):
+@spaces.GPU
+def extract_image(pos_prompts, neg_prompts, img, threshold):
     positive_masks = get_masks(pos_prompts, img, 0.5)
     negative_masks = get_masks(neg_prompts, img, 0.5)
 
-    # combine masks into one masks, logic OR
+    # combine masks into one mask, logic OR
     pos_mask = np.any(np.stack(positive_masks), axis=0)
     neg_mask = np.any(np.stack(negative_masks), axis=0)
     final_mask = pos_mask & ~neg_mask
@@ -59,12 +60,10 @@ def extract_image(pos_prompts, neg_prompts, img, threhsold):
     output_image.paste(img, mask=final_mask)
     return output_image, final_mask
 
-
 title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
 description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'submit'. Results will show up in a few seconds."
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
 
-
 with gr.Blocks() as demo:
     gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
     gr.Markdown(article)
@@ -100,5 +99,4 @@ with gr.Blocks() as demo:
         outputs=[output_image, output_mask],
     )
 
-
 demo.launch()
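
The change follows the usual ZeroGPU pattern for Spaces: load the model onto CUDA once at startup, move the processor's tensors to the same device before the forward pass, and wrap each GPU-touching function with @spaces.GPU. A minimal sketch of that pattern, separate from the app itself (the segment helper name, the example call, and the sigmoid post-processing are illustrative assumptions, not part of this commit):

# Minimal sketch of the ZeroGPU pattern applied in this commit (assumptions noted above).
import spaces
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").cuda()

@spaces.GPU  # on ZeroGPU hardware, a device is attached only while this call runs
def segment(image, prompt):
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    inputs = {k: v.cuda() for k, v in inputs.items()}  # move tensors to the model's device
    with torch.no_grad():
        outputs = model(**inputs)
    # map the low-resolution logits to [0, 1] so a fixed threshold can be applied
    return torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

Outside ZeroGPU (a regular GPU Space or a local CUDA machine) the @spaces.GPU decorator should behave as a no-op, so the same code path works in both environments.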