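"""Gradio demo that serves a SAM (Segment Anything Model) checkpoint stored
locally in ./checkpoint, segmenting an uploaded image using the full image
extent as the bounding-box prompt."""
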
import gradio as gr
import torch
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image


# Load the locally saved SAM processor and model weights from ./checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = SamProcessor.from_pretrained("./checkpoint", local_files_only=True)
model = SamModel.from_pretrained("./checkpoint", local_files_only=True)
model.to(device)  # move the weights to the same device the inputs are sent to
model.eval()  # inference only, so disable dropout and gradient-dependent behavior



def get_bbox(gt_map):
    """Derive a bounding-box prompt [x_min, y_min, x_max, y_max] from a
    ground-truth mask, jittering each edge slightly. (Not called by the
    Gradio app below, which prompts with the full image extent instead.)"""
    if gt_map.ndim > 2:
        gt_map = gt_map[:, :, 0]

    # If the mask is empty, fall back to a box covering the whole image
    if np.sum(gt_map) == 0:
        return [0, 0, gt_map.shape[1], gt_map.shape[0]]

    # Tight box around all foreground pixels
    y_indices, x_indices = np.where(gt_map > 0)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    # Perturb each edge by up to 20 px, clamped to the image bounds
    H, W = gt_map.shape
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))

    return [x_min, y_min, x_max, y_max]


def process_image(image_input):
    # Convert the input array to an RGB PIL image and resize to the model's input size
    image = Image.fromarray(image_input).convert('RGB')
    image = image.resize((256, 256))

    # Use the full image extent as the bounding-box prompt,
    # nested as [[box]] in the format the processor expects
    prompt = [[[0, 0, image.width, image.height]]]

    # Preprocess the image and box prompt, then move the tensors to the device
    inputs = processor(image, input_boxes=prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward pass without gradient tracking
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Turn the mask logits into probabilities, then threshold into a binary mask
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg = (seg_prob > 0.5).astype(np.uint8)

    # Scale the 0/1 mask to 0/255 so it renders as a black-and-white image
    seg_image = Image.fromarray(seg * 255)
    # prob_map = Image.fromarray((seg_prob * 255).astype(np.uint8))  # optional: probability map scaled to 0-255

    return seg_image


# Minimal image-in / image-out Gradio interface
iface = gr.Interface(fn=process_image, inputs="image", outputs="image", title="zerovision")

iface.launch()
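# Running the script starts a local Gradio server (by default at
# http://127.0.0.1:7860); upload an image to get its predicted mask back.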