---
license: mit
---

# Segment Anything 8-Bit ONNX

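How to run: place `sam_encoder.onnx`, `sam_decoder.onnx`, and an input image named `example.png` in the working directory, install `onnxruntime`, `numpy`, `Pillow`, and `matplotlib`, then run:

```python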
import numpy as np

# Longest side of the image fed to the SAM image encoder. The encoder input is
# always a TARGET_LENGTH x TARGET_LENGTH square; the resized image is
# zero-padded on the bottom/right to fill it.
TARGET_LENGTH = 1024

# Per-channel RGB normalization statistics (ImageNet mean/std on a 0-255 scale).
PIXEL_MEAN = np.array([123.675, 116.28, 103.53])
PIXEL_STD = np.array([58.395, 57.12, 57.375])


def resized_shape(orig_height, orig_width, target_length=TARGET_LENGTH):
    """Return the (height, width) after scaling the longest side to ``target_length``.

    Mirrors SAM's ``ResizeLongestSide.get_preprocess_shape``: the aspect ratio
    is preserved and each side is rounded to the nearest integer.

    Raises:
        ValueError: if either original dimension is not positive.
    """
    if orig_height <= 0 or orig_width <= 0:
        raise ValueError(f"invalid image size {orig_height}x{orig_width}")
    scale = target_length / max(orig_height, orig_width)
    return int(orig_height * scale + 0.5), int(orig_width * scale + 0.5)


def normalize_and_pad(pixels, target_length=TARGET_LENGTH):
    """Turn an (H, W, 3) RGB array into the (1, 3, L, L) float32 encoder input.

    ``pixels`` must already be resized so neither side exceeds
    ``target_length`` (see ``resized_shape``). Values are normalized with
    ``PIXEL_MEAN`` / ``PIXEL_STD`` and then zero-padded on the bottom and right,
    so the original content stays anchored at the top-left corner.

    Raises:
        ValueError: if ``pixels`` is not (H, W, 3) or is larger than
            ``target_length`` on either side (np.pad would otherwise fail
            with a negative pad width).
    """
    pixels = np.asarray(pixels)
    if pixels.ndim != 3 or pixels.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB array, got shape {pixels.shape}")
    height, width = pixels.shape[:2]
    if height > target_length or width > target_length:
        raise ValueError(
            f"image of size {height}x{width} exceeds {target_length}; "
            "resize it to resized_shape(...) first"
        )
    normalized = (pixels - PIXEL_MEAN) / PIXEL_STD
    chw = normalized.transpose(2, 0, 1)[None, :, :, :]
    padded = np.pad(
        chw,
        ((0, 0), (0, 0), (0, target_length - height), (0, target_length - width)),
    )
    return padded.astype(np.float32)


def build_point_prompt(points, labels, orig_height, orig_width, target_length=TARGET_LENGTH):
    """Build the ``point_coords`` / ``point_labels`` inputs for the SAM decoder.

    Args:
        points: (x, y) positions in original-image pixel coordinates.
        labels: one label per point (SAM convention: 1 = foreground,
            0 = background).
        orig_height, orig_width: size of the original image.
        target_length: longest side used when resizing for the encoder.

    Returns:
        ``(coords, labels)`` as float32 arrays of shape (1, N + 1, 2) and
        (1, N + 1). Coordinates are scaled exactly like the image was
        resized, and a (0, 0) padding point with label -1 is appended, as in
        SAM's ONNX example when no box prompt is given.

    Raises:
        ValueError: if the number of points and labels differ.
    """
    coords = np.array(points, dtype=np.float64).reshape(-1, 2)
    point_labels = np.array(labels, dtype=np.float64).reshape(-1)
    if len(coords) != len(point_labels):
        raise ValueError(f"got {len(coords)} points but {len(point_labels)} labels")

    # Scale x and y by the same factors used to resize the image, so a click
    # keeps pointing at the same pixel after preprocessing.
    new_height, new_width = resized_shape(orig_height, orig_width, target_length)
    coords[:, 0] *= new_width / orig_width
    coords[:, 1] *= new_height / orig_height

    # The padding point needs a coordinate as well as its -1 label; otherwise
    # point_coords and point_labels disagree in length.
    coords = np.concatenate([coords, [[0.0, 0.0]]], axis=0)[None, :, :]
    point_labels = np.concatenate([point_labels, [-1.0]])[None, :]
    return coords.astype(np.float32), point_labels.astype(np.float32)


if __name__ == "__main__":
    # Runtime-only dependencies are imported here so the helpers above can be
    # reused (and tested) with NumPy alone.
    import matplotlib.pyplot as plt
    import onnxruntime as ort
    from PIL import Image

    # Path to the image file
    image_path = "example.png"

    # Load the image, scale its longest side to 1024 (bilinear, like SAM's own
    # preprocessing), then normalize and pad it to 1024x1024.
    image = Image.open(image_path).convert("RGB")
    orig_width, orig_height = image.size
    new_height, new_width = resized_shape(orig_height, orig_width)
    resized = image.resize((new_width, new_height), Image.BILINEAR)
    input_tensor = normalize_and_pad(np.array(resized))

    # Load the encoder model and run inference
    encoder = ort.InferenceSession("sam_encoder.onnx")
    embeddings = encoder.run(None, {"images": input_tensor})[0]

    # Choose a foreground point (x=150, y=100) in the original image
    point = (150, 100)
    onnx_coord, onnx_label = build_point_prompt([point], [1], orig_height, orig_width)

    # No previous mask is supplied: empty mask_input and has_mask_input = 0.
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    # Load the decoder model and run inference
    decoder = ort.InferenceSession("sam_decoder.onnx")
    masks_output, _, _ = decoder.run(None, {
        "image_embeddings": embeddings,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32),
    })

    # Scores above 0 are foreground (SAM's default mask threshold is 0.0);
    # store the mask as a 0/255 uint8 image.
    mask = (masks_output[0][0] > 0).astype("uint8") * 255

    # Overlay the mask and the chosen point on the original image. The decoder
    # was given orig_im_size, so the mask is expected to match that size.
    plt.imshow(image)
    plt.imshow(mask, alpha=0.5)
    plt.scatter([point[0]], [point[1]], c="red", marker="*", s=200)
    plt.axis("off")
    plt.show()
```