from typing import Any, Dict, List, Union

from sentence_transformers import SentenceTransformer
class EndpointHandler: | |
def __init__(self, model_path="bge-large-en/"): | |
# Preload all the elements you are going to need at inference | |
self.model = SentenceTransformer(model_path) | |
def __call__(self, data: Dict[str, Any]) -> Union[List[List[float]], List[float]]: | |
""" | |
data args: | |
inputs (:obj: `str` | `PIL.Image` | `np.array`) | |
kwargs | |
Return: | |
A :obj:`list` | `dict`: will be serialized and returned | |
""" | |
# Extracting the inputs and kwargs | |
inputs = data["inputs"] | |
kwargs = data.get("kwargs", {}) | |
normalize_embeddings = kwargs.get('normalize_embeddings', True) | |
# Determine if the input is a query or a passage | |
is_query = kwargs.get("is_query", False) | |
if is_query: | |
instruction = kwargs.get("query_instruction", "") | |
if isinstance(inputs, list): | |
inputs = [instruction + q for q in inputs] | |
else: | |
inputs = instruction + inputs | |
# Encoding the inputs using the model | |
embeddings = self.model.encode(inputs, normalize_embeddings=normalize_embeddings) | |
# Return the serialized embeddings | |
return embeddings.tolist() if isinstance(embeddings, list) else embeddings | |