bthndmn12 committed on
Commit
6fe5ae4
1 Parent(s): 08eeae0

fixed some bugs

Browse files
Files changed (1) hide show
  1. app.py +28 -22
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  from transformers import AutoModel
5
  from transformers import SamModel, SamConfig, SamProcessor
6
  from PIL import Image
 
7
 
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -37,37 +38,42 @@ def get_bbox(gt_map):
37
  return bbox
38
 
39
 
40
-
41
- def greet(image):
42
- image = Image.fromarray(image)
43
  image = image.resize((256, 256))
44
 
45
- gt_mask = np.array(image)
46
- prompt = get_bbox(gt_mask)
 
47
 
48
- inputs = processor(images=image, input_boxes=[[prompt]], return_tensors="pt")
 
49
  inputs = {k: v.to(device) for k, v in inputs.items()}
50
 
 
51
  model.eval()
52
  with torch.no_grad():
53
- outputs = model(**inputs, multimask_outputs=False)
54
-
55
- seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0))
56
- seg_prob = seg_prob.cpu().numpy().squeeze()
57
- seg_prob = (seg_prob > 0.5).astype(np.uint8)
58
 
59
- # Ensure the array is 2D (height, width) for grayscale image
60
- if seg_prob.ndim > 2:
61
- seg_prob = seg_prob.squeeze() # Remove extra dimensions if any
62
- elif seg_prob.ndim < 2:
63
- raise ValueError("Output mask has less than 2 dimensions")
64
-
65
- # Convert the processed mask back to a PIL image
66
- seg_prob_image = Image.fromarray(seg_prob)
67
 
68
- return seg_prob_image
 
 
69
 
 
70
 
 
 
 
 
 
 
 
71
 
72
- iface = gr.Interface(fn= greet, inputs="image", outputs="image", title="Greeter")
73
- iface.launch()
 
4
  from transformers import AutoModel
5
  from transformers import SamModel, SamConfig, SamProcessor
6
  from PIL import Image
7
+ import matplotlib.pyplot as plt
8
 
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
38
  return bbox
39
 
40
 
41
def process_image(image_input):
    """Segment one image with the module-level model using a full-frame box prompt.

    Relies on ``processor``, ``model`` and ``device`` defined elsewhere in
    this file.

    Args:
        image_input: Image data accepted by ``Image.fromarray`` (a NumPy
            array). NOTE(review): Gradio is expected to pass ``None`` when
            the user submits without an image -- confirm for the pinned
            Gradio version.

    Returns:
        A 3-tuple of PIL images: the RGB input resized to 256x256, the
        thresholded mask scaled to 0/255, and the sigmoid probabilities
        scaled to 0-255.

    Raises:
        ValueError: If ``image_input`` is ``None``.
    """
    # Fail with a readable message instead of an opaque error from PIL.
    if image_input is None:
        raise ValueError("No image provided; upload an image before submitting.")

    # Convert the input to a PIL Image and resize
    image = Image.fromarray(image_input).convert('RGB')
    image = image.resize((256, 256))

    # Box prompt covering the whole resized image, nested the way the
    # processor call below expects (list of boxes per image).
    prompt = [[[0, 0, image.width, image.height]]]

    # Process the image and bounding box, then move every tensor to `device`
    inputs = processor(image, input_boxes=prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward pass without gradient calculation
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Logits -> probabilities; drop dim 1, then every remaining singleton dim
    seg_prob = torch.sigmoid(outputs['pred_masks'].squeeze(1))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg = (seg_prob > 0.5).astype(np.uint8)

    # Convert numpy arrays back to PIL Images for Gradio output
    seg_image = Image.fromarray(seg * 255)  # 0/1 mask -> 0/255 uint8 image
    prob_map = Image.fromarray((seg_prob * 255).astype(np.uint8))  # Scale probabilities to 0-255

    return image, seg_image, prob_map
69
 
70
# Build the Gradio UI: one image input, three labelled image outputs.
input_component = gr.inputs.Image(shape=(256, 256))
output_components = [
    gr.outputs.Image(label="Original Image"),
    gr.outputs.Image(label="Segmentation Mask"),
    gr.outputs.Image(label="Probability Map"),
]
iface = gr.Interface(
    fn=process_image,
    inputs=input_component,
    outputs=output_components,
    title="Image Segmentation",
)

# Start serving the interface.
iface.launch()