suanan's picture
Update app.py
be77e62 verified
raw
history blame contribute delete
No virus
5.9 kB
import time
import gradio as gr
from datasets import load_dataset
import pandas as pd
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
import datetime
# Corpus rows (url / title / text). search() maps result row ids back to these.
title_text_dataset = load_dataset(
    "suanan/BP_CBG_POC",
    split="train",
    num_proc=4,
).select_columns(["url", "title", "text"])

# Second-stage int8 embeddings, used only for rescoring. Opened as a view
# (view=True) instead of being fully loaded, to save memory, because no search
# is ever run against this index.
int8_view = Index.restore("index/BP_CBG_int8_usearch_1m_v2.index", view=True)

# First-stage index of packed binary (ubinary) embeddings, annotated as a flat
# binary FAISS index. No approximate (IVF) binary index is loaded.
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "index/BP_CBG_ubinary_faiss_1m_v2.index"
)

# Query encoder. encode() prepends the "retrieval" prompt by default.
# NOTE(review): BAAI documents a Chinese query instruction for the
# bge-*-zh-v1.5 models; confirm this English prompt is intended.
model = SentenceTransformer(
    "BAAI/bge-large-zh-v1.5",
    default_prompt_name="retrieval",
    prompts={"retrieval": "Represent this sentence for searching relevant passages: "},
)
def _format_timings(embed_time, quantize_time, search_time, load_time, rescore_time, sort_time):
    """Format per-stage durations (in seconds) for the JSON timing panel.

    "Total search Time" sums quantize + search + load + rescore + sort; like the
    original report, it excludes the embedding step.
    """
    return {
        "Embed Time": f"{embed_time:.4f} s",
        "Quantize Time": f"{quantize_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Load Time": f"{load_time:.4f} s",
        "Rescore Time": f"{rescore_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total search Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
    }


def search(query, top_k: int = 100, rescore_multiplier: int = 1, use_approx: bool = False):
    """Run two-stage quantized retrieval for *query*.

    Stage 1 searches the binary FAISS index with a ubinary-quantized query and
    collects ``top_k * rescore_multiplier`` candidates. Stage 2 rescores those
    candidates with the float32 query embedding against their int8 document
    embeddings and keeps the best ``top_k``.

    Args:
        query: Search text entered by the user.
        top_k: Number of results to return.
        rescore_multiplier: Multiplier on ``top_k`` for how many binary-search
            candidates to rescore.
        use_approx: Accepted for UI compatibility but ignored; only the exact
            flat binary index is loaded.

    Returns:
        ``(DataFrame, dict)``: the results table (columns Score / Url / Title /
        Text, highest score first) and a dict of per-stage timing strings.
    """
    # Log the request time and the query.
    now = datetime.datetime.now()
    print(f"當前時間: {now}, 問題: {query}")

    # Gradio sliders may deliver floats; FAISS needs an integer k.
    top_k = int(top_k)
    rescore_multiplier = int(rescore_multiplier)

    # 1. Embed the query as float32
    start_time = time.perf_counter()
    query_embedding = model.encode(query)
    embed_time = time.perf_counter() - start_time

    # 2. Quantize the query to ubinary for the binary index
    start_time = time.perf_counter()
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    quantize_time = time.perf_counter() - start_time

    # 3. Search the binary index (always the exact flat index; use_approx has no effect)
    start_time = time.perf_counter()
    _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rescore_multiplier)
    binary_ids = binary_ids[0]
    # FAISS pads with id -1 when fewer than k results exist; a -1 would otherwise
    # be looked up as a real id (and wrap to the last dataset row).
    binary_ids = binary_ids[binary_ids >= 0]
    search_time = time.perf_counter() - start_time

    if binary_ids.size == 0:
        # No candidates: nothing to load, rescore or sort.
        empty = pd.DataFrame({"Score": [], "Url": [], "Title": [], "Text": []})
        return empty, _format_timings(embed_time, quantize_time, search_time, 0.0, 0.0, 0.0)

    # 4. Load the candidates' int8 embeddings (cast to a wider int dtype before the matmul)
    start_time = time.perf_counter()
    int8_embeddings = int8_view[binary_ids].astype(int)
    load_time = time.perf_counter() - start_time

    # 5. Rescore the candidates: float32 query against int8 document embeddings
    start_time = time.perf_counter()
    scores = query_embedding @ int8_embeddings.T
    rescore_time = time.perf_counter() - start_time

    # 6. Keep the best top_k (highest score first) and fetch their rows
    start_time = time.perf_counter()
    order = scores.argsort()[::-1][:top_k]
    top_k_ids = binary_ids[order].tolist()
    top_k_scores = scores[order].tolist()
    # One batched lookup (returns a dict of column lists) instead of three
    # single-row reads per hit.
    rows = title_text_dataset[top_k_ids] if top_k_ids else {"url": [], "title": [], "text": []}
    df = pd.DataFrame(
        {
            "Score": [round(value, 2) for value in top_k_scores],
            "Url": rows["url"],
            "Title": rows["title"],
            "Text": rows["text"],
        }
    )
    sort_time = time.perf_counter() - start_time

    return df, _format_timings(embed_time, quantize_time, search_time, load_time, rescore_time, sort_time)
def update_info(value):
    """Build the caption under the top-k slider ("<value> results will be shown")."""
    return "{}筆顯示出來".format(value)
# Gradio UI: query input, search-mode radio, top-k / rescore sliders,
# a results table and a JSON panel with per-stage timings.
with gr.Blocks(title="") as demo:
    gr.Markdown(
        """
## 官網 Dataset & opensource model BAAI/bge-large-zh-v1.5
### v1 測試POC
Details:
1. 中文搜尋ok,英文像是:iphone 15,embedding的時候沒有轉成小寫,需要 寫成iPhone才可以準確搜尋到
2. 環境資源: python 3.10, linux: ubuntu 22.04, only cpu, ram max:7.7GB min:4.5GB 使用以上資源
3.
建立步驟:
1. excel 轉成 [dataset](https://huggingface.co/datasets/suanan/BP_POC) [CBG_dataset](https://huggingface.co/datasets/suanan/BP_CBG_POC), 花費約10秒內
2. dataset 內 轉成 title & text 做 embedding,以後可以新增keyword來加強搜尋出來的結果排序往前
3. 之後透過 Quantized Retrieval - Binary Search solution進行搜尋
"""
    )
    with gr.Row():
        with gr.Column(scale=75):
            query = gr.Textbox(
                label="官網 Dataset & opensource model BAAI/bge-large-zh-v1.5, v1 測試POC",
                placeholder="輸入搜尋關鍵字或問句",
            )
        with gr.Column(scale=25):
            # NOTE: search() always uses the exact flat binary index, so the
            # approximate option currently behaves like exact search.
            use_approx = gr.Radio(
                choices=[("精確搜尋", False), ("相關搜尋", True)],
                value=False,
                label="搜尋方法",
            )
    with gr.Row():
        with gr.Column(scale=2):
            top_k = gr.Slider(
                minimum=10,
                maximum=1000,
                step=5,
                value=100,
                label="顯示搜尋前幾筆",
            )
            # Caption echoing the slider value; kept in sync by top_k.change below.
            info_text = gr.Textbox(value=update_info(top_k.value), interactive=False)
        with gr.Column(scale=2):
            rescore_multiplier = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=1,
                label="Rescore multiplier",
                info="Search for `rescore_multiplier` as many documents to rescore",
            )
    search_button = gr.Button(value="Search")
    with gr.Row():
        with gr.Column(scale=4):
            # Headers match the columns of the DataFrame returned by search().
            output = gr.Dataframe(headers=["Score", "Url", "Title", "Text"])
        with gr.Column(scale=1):
            # Named to avoid shadowing the stdlib `json` module.
            timing_json = gr.JSON()

    # Pressing Enter in the query box and clicking the button run the same search.
    search_inputs = [query, top_k, rescore_multiplier, use_approx]
    search_outputs = [output, timing_json]
    top_k.change(fn=update_info, inputs=top_k, outputs=info_text)
    query.submit(search, inputs=search_inputs, outputs=search_outputs)
    search_button.click(search, inputs=search_inputs, outputs=search_outputs)

demo.queue()
demo.launch(share=True)