# sam-vit-base / handler.py
# Source page metadata: author aradootle, commit 73cb701 ("stuff"), 1.37 kB.
from typing import Dict, List, Any
import os
import requests
from flask import Flask, Response, request, jsonify
from segment_anything import SamPredictor, sam_model_registry
class EndpointHandler():
    """Custom inference handler that returns a SAM image embedding for an image URL.

    The SAM model is built once in ``__init__``. Each ``__call__`` downloads
    one image and runs it through ``SamPredictor`` to obtain its embedding.
    """

    # Seconds to wait on the image server before giving up. Without a timeout,
    # ``requests.get`` can block the worker indefinitely.
    DOWNLOAD_TIMEOUT_S = 30

    def __init__(self, path=""):
        """Build the "vit_b" SAM model and wrap it in a ``SamPredictor``.

        Args:
            path: Directory the checkpoint path is resolved against. The
                default ``""`` keeps the previous behavior of resolving it
                relative to the current working directory.
        """
        # Preload all the elements you are going to need at inference.
        model_type = "vit_b"
        print('current working directory', os.getcwd())
        # NOTE(review): sam_model_registry loads a PyTorch checkpoint (the
        # official "vit_b" one is sam_vit_b_01ec64.pth). A ".h5" name suggests
        # a TensorFlow/Keras file. Confirm this file is loadable by torch.
        model_path = os.path.join(path, "models/tf_model.h5")
        sam = sam_model_registry[model_type](checkpoint=model_path)
        self.predictor = SamPredictor(sam)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Download the image at ``imageUrl`` and return its SAM embedding.

        Args:
            data: Request payload. The URL is read from
                ``data["inputs"]["imageUrl"]``. If there is no ``"inputs"`` key,
                it is read from ``data["imageUrl"]`` instead. ``data["inputs"]``
                may also be the URL string itself. ``data`` is not modified.

        Returns:
            On success, the result of ``SamPredictor.get_image_embedding()``
            converted to nested Python lists. On failure, a dict of the form
            ``{"error": <message>}``.
        """
        # Read with .get() instead of .pop() so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            image_url = inputs
        elif isinstance(inputs, dict):
            image_url = inputs.get("imageUrl")
        else:
            image_url = None

        # Plain dicts rather than flask.jsonify: jsonify needs an active Flask
        # application context, and this handler never sets one up.
        # NOTE(review): a returned dict presumably goes out with a 2xx status;
        # confirm whether clients need a non-2xx status for these errors.
        if not image_url:
            return {"error": "imageUrl not provided"}

        try:
            response = requests.get(image_url, timeout=self.DOWNLOAD_TIMEOUT_S)
            response.raise_for_status()
        except requests.RequestException as e:
            return {"error": f"Error downloading image: {e}"}
        image = response.content

        # NOTE(review): SamPredictor.set_image is documented to take a decoded
        # HxWx3 uint8 RGB array. `image` here is the raw, still-encoded HTTP
        # body (bytes), so it likely needs decoding first (e.g. with Pillow).
        # TODO: decode before this call.
        self.predictor.set_image(image)
        embedding = self.predictor.get_image_embedding()
        # Move to CPU and convert to plain lists so the result is JSON-serializable.
        return embedding.cpu().numpy().tolist()