robgonsalves committed
Commit 3b5fa9b
1 Parent(s): 263a219

add code example

Files changed (1)
  1. README.md +64 -3

README.md (new contents):

---
license: mit
---
# Segment Anything 8-Bit ONNX
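
The example below loads `sam_encoder.onnx` and `sam_decoder.onnx` from the working directory. If those files live in this repository, one way to fetch them is with `huggingface_hub`; the repo id below is a placeholder, so substitute the actual repository name:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id; replace with the repository that hosts the ONNX files.
repo_id = "your-username/segment-anything-8bit-onnx"

encoder_path = hf_hub_download(repo_id=repo_id, filename="sam_encoder.onnx")
decoder_path = hf_hub_download(repo_id=repo_id, filename="sam_decoder.onnx")
```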
6
+ How to run:
7
+
8
```python
import onnxruntime as ort
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Path to the image file
image_path = "example.png"

# Load the image and record its original size
image = Image.open(image_path).convert("RGB")
orig_width, orig_height = image.size

# Standard SAM preprocessing: resize so the longest side is 1024 (the encoder's
# input size) while keeping the aspect ratio, then pad to 1024x1024 below
scale = 1024 / max(orig_width, orig_height)
resized_width = int(round(orig_width * scale))
resized_height = int(round(orig_height * scale))
resized_image = image.resize((resized_width, resized_height), Image.BILINEAR)

# Normalize with the SAM pixel mean/std and convert to NCHW float32
input_tensor = np.array(resized_image).astype(np.float32)
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
input_tensor = (input_tensor - mean) / std
input_tensor = input_tensor.transpose(2, 0, 1)[None, :, :, :]

# Pad the bottom/right of the input tensor to 1024x1024
pad_height = 1024 - input_tensor.shape[2]
pad_width = 1024 - input_tensor.shape[3]
input_tensor = np.pad(input_tensor, ((0, 0), (0, 0), (0, pad_height), (0, pad_width)))

# Load the encoder model and compute the image embeddings
encoder = ort.InferenceSession("sam_encoder.onnx")
embeddings = encoder.run(None, {"images": input_tensor})[0]

# Choose a prompt point (e.g., x=150, y=100) in the original image
point = [150, 100]

# Scale the point into the resized/padded frame and append the dummy point
# [0, 0] with label -1 that the decoder expects when no box prompt is given
onnx_coord = np.array([[point, [0.0, 0.0]]], dtype=np.float32) * scale
onnx_label = np.array([[1, -1]], dtype=np.float32)

# Prepare the remaining decoder inputs (no previous low-resolution mask)
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

# Load the decoder model and run inference
decoder = ort.InferenceSession("sam_decoder.onnx")
masks_output, _, _ = decoder.run(None, {
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32)
})

# Threshold the predicted mask into a binary uint8 image at the original resolution
mask = masks_output[0][0]
mask = (mask > 0).astype("uint8") * 255
```
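
Continuing from the example above (reusing its imports and the `image` and `mask` variables), here is a minimal sketch of how the result might be inspected, assuming the decoder returns the mask at the original image resolution:

```python
# Overlay the mask on the photo: paint masked pixels red and blend at 50%.
overlay = np.array(image).copy()
overlay[mask > 0] = [255, 0, 0]
blended = Image.blend(image, Image.fromarray(overlay), alpha=0.5)

# Show the input image, the binary mask, and the overlay side by side.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, img, title in zip(axes, [image, mask, blended], ["input", "mask", "overlay"]):
    ax.imshow(img, cmap="gray" if title == "mask" else None)
    ax.set_title(title)
    ax.axis("off")
plt.show()

# Or simply write the binary mask to disk.
Image.fromarray(mask).save("mask.png")
```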