taesiri commited on
Commit
9492db9
1 Parent(s): d01a481

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
4
- import matplotlib.pyplot as plt
5
  import numpy as np
6
  import spaces
7
  from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
@@ -12,7 +11,7 @@ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined"
12
  @spaces.GPU
13
  def process_image(image, prompt):
14
  inputs = processor(
15
- text=prompt, images=image, padding="max_length", return_tensors="pt"
16
  )
17
  inputs = {k: v.cuda() for k, v in inputs.items()}
18
 
@@ -22,11 +21,10 @@ def process_image(image, prompt):
22
  preds = outputs.logits
23
 
24
  pred = torch.sigmoid(preds)
25
- mat = pred.cpu().numpy()
26
  mask = Image.fromarray(np.uint8(mat * 255), "L")
27
- mask = mask.convert("RGB")
28
  mask = mask.resize(image.size)
29
- mask = np.array(mask)[:, :, 0]
30
 
31
  # normalize the mask
32
  mask_min = mask.min()
@@ -39,19 +37,19 @@ def get_masks(prompts, img, threshold):
39
  prompts = prompts.split(",")
40
  masks = []
41
  for prompt in prompts:
42
- mask = process_image(img, prompt)
43
  mask = mask > threshold
44
  masks.append(mask)
45
  return masks
46
 
47
  @spaces.GPU
48
  def extract_image(pos_prompts, neg_prompts, img, threshold):
49
- positive_masks = get_masks(pos_prompts, img, 0.5)
50
- negative_masks = get_masks(neg_prompts, img, 0.5)
51
 
52
  # combine masks into one mask, logic OR
53
- pos_mask = np.any(np.stack(positive_masks), axis=0)
54
- neg_mask = np.any(np.stack(negative_masks), axis=0)
55
  final_mask = pos_mask & ~neg_mask
56
 
57
  # extract the final image
@@ -99,4 +97,4 @@ with gr.Blocks() as demo:
99
  outputs=[output_image, output_mask],
100
  )
101
 
102
- demo.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
 
4
  import numpy as np
5
  import spaces
6
  from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
 
11
  @spaces.GPU
12
  def process_image(image, prompt):
13
  inputs = processor(
14
+ text=prompt, images=image, return_tensors="pt"
15
  )
16
  inputs = {k: v.cuda() for k, v in inputs.items()}
17
 
 
21
  preds = outputs.logits
22
 
23
  pred = torch.sigmoid(preds)
24
+ mat = pred.squeeze().cpu().numpy() # Squeeze to remove extra dimensions
25
  mask = Image.fromarray(np.uint8(mat * 255), "L")
 
26
  mask = mask.resize(image.size)
27
+ mask = np.array(mask)
28
 
29
  # normalize the mask
30
  mask_min = mask.min()
 
37
  prompts = prompts.split(",")
38
  masks = []
39
  for prompt in prompts:
40
+ mask = process_image(img, prompt.strip()) # Strip whitespace from prompts
41
  mask = mask > threshold
42
  masks.append(mask)
43
  return masks
44
 
45
  @spaces.GPU
46
  def extract_image(pos_prompts, neg_prompts, img, threshold):
47
+ positive_masks = get_masks(pos_prompts, img, threshold)
48
+ negative_masks = get_masks(neg_prompts, img, threshold)
49
 
50
  # combine masks into one mask, logic OR
51
+ pos_mask = np.any(np.stack(positive_masks), axis=0) if positive_masks else np.zeros_like(img)[:,:,0].astype(bool)
52
+ neg_mask = np.any(np.stack(negative_masks), axis=0) if negative_masks else np.zeros_like(img)[:,:,0].astype(bool)
53
  final_mask = pos_mask & ~neg_mask
54
 
55
  # extract the final image
 
97
  outputs=[output_image, output_mask],
98
  )
99
 
100
+ demo.launch(share=True)