quesbook_search / embedding.py
import torch
import tiktoken
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

# Load the E5 embedding model and its tokenizer once at import time.
tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-large-v2")
model = AutoModel.from_pretrained("intfloat/e5-large-v2")

# Maximum input length in tokens (not characters, despite the name);
# e5-large-v2 truncates inputs beyond 512 tokens.
EMBEDDING_CHAR_LIMIT = 512

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings, zeroing out padding via the attention mask."""
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
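
# A toy check of average_pool (hypothetical values, for illustration only):
# with the second position masked out as padding, only position 0 contributes.
#   hidden = torch.tensor([[[2.0, 4.0], [9.0, 9.0]]])  # (batch=1, seq=2, dim=2)
#   mask = torch.tensor([[1, 0]])                      # (batch=1, seq=2)
#   average_pool(hidden, mask)                         # tensor([[2., 4.]])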

def strings_to_vectors(strings):
    """Embed a list of strings with e5-large-v2.

    Note: the E5 models expect each input to be prefixed with "query: " or
    "passage: "; callers are responsible for adding the prefix.
    """
    passage_batch = tokenizer(
        strings,
        max_length=EMBEDDING_CHAR_LIMIT,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only; skip gradient bookkeeping
        passage_outputs = model(**passage_batch)
    return average_pool(
        passage_outputs.last_hidden_state, passage_batch["attention_mask"]
    )
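
# Usage sketch (an assumption about downstream use, following the E5 model
# card's recommendation of cosine similarity): L2-normalizing the embeddings
# makes plain dot products equal cosine similarity.
#   import torch.nn.functional as F
#   embs = F.normalize(strings_to_vectors(["passage: ...", "passage: ..."]), dim=1)
#   scores = embs @ embs.T  # pairwise cosine similarities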

def num_tokens_from_str(string, model="gpt-3.5-turbo"):
    """Returns the number of tokens a string uses when sent as a chat message."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo":  # note: future models may deviate from this
        num_tokens = 0
        num_tokens += (
            4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
        )
        num_tokens += len(encoding.encode(string))
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens
    else:
        raise NotImplementedError(
            f"""num_tokens_from_str() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
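

if __name__ == "__main__":
    # Minimal smoke test with illustrative inputs (the strings are hypothetical,
    # not from the original project). e5-large-v2 embeds each input into a
    # 1024-dimensional vector and expects a "query: "/"passage: " prefix.
    demo = [
        "passage: Photosynthesis converts light energy into chemical energy.",
        "passage: Mitochondria produce most of the cell's ATP.",
    ]
    print(strings_to_vectors(demo).shape)  # torch.Size([2, 1024])
    print(num_tokens_from_str("How do plants make food?"))  # tokens + chat overhead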