Spaces:
Running
Running
File size: 3,912 Bytes
1e57a2c 4cbcb94 1e57a2c 3af85d8 1e57a2c e7ded97 1e57a2c 7a3f7ed 84b4358 1e57a2c 8fc12bf d45163a 3ad472b 3af85d8 3ad472b c5ea378 988f448 3af85d8 c5ea378 3af85d8 c5ea378 3ad472b 1e57a2c d45163a 3af85d8 1e57a2c 3af85d8 1e57a2c fdde008 1e57a2c c5ea378 1e57a2c 3af85d8 94bb2b9 1e57a2c d51ec0e 1e57a2c d51ec0e 94bb2b9 1e57a2c d45163a 94bb2b9 7b93622 94bb2b9 3af85d8 94bb2b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from huggingface_hub import hf_hub_download
# Same string that load_embedding() passes as `repo_id` when downloading the
# embedding file.
# NOTE(review): this variable is not referenced anywhere in the visible code;
# load_embedding() hard-codes the value instead -- confirm before removing.
embedding_path = "abokbot/wikipedia-embedding"
# Page header shown at the top of the app.
st.header("Wikipedia Search Engine app")
# Status text shown while the resources below are loaded. The returned element
# handle is kept so the message can be blanked later with st_model_load.text("").
st_model_load = st.text('Loading encoders, embeddings and dataset (takes about 5min)')
@st.cache_resource
def load_embedding():
    """Download the pre-computed Wikipedia embedding file and load it on CPU.

    The file ``wikipedia_en_embedding.pt`` is fetched from the
    ``abokbot/wikipedia-embedding`` Hugging Face repo and deserialized with
    ``torch.load``. The function is wrapped in ``st.cache_resource``.

    Returns:
        The object stored in the ``.pt`` file, mapped to the CPU device
        (presumably a tensor of passage embeddings -- TODO confirm).
    """
    print("Loading embedding...")
    local_file = hf_hub_download(
        filename="wikipedia_en_embedding.pt",
        repo_id="abokbot/wikipedia-embedding",
    )
    # NOTE(review): torch.load deserializes via pickle; the repo above is
    # assumed to be trusted.
    cpu_device = torch.device('cpu')
    embedding = torch.load(local_file, map_location=cpu_device)
    print("Embedding loaded!")
    return embedding
# Corpus embeddings used by search(); load_embedding() is decorated with
# st.cache_resource.
wikipedia_embedding = load_embedding()
# Blank out the "Loading ..." status text created earlier.
st_model_load.text("")
@st.cache_resource
def load_encoders():
    """Instantiate the two models used by the search pipeline.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(bi_encoder, cross_encoder)`` tuple:

        * ``bi_encoder`` -- ``SentenceTransformer('msmarco-MiniLM-L-6-v3')``
          with ``max_seq_length`` set to 256.
        * ``cross_encoder`` --
          ``CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')``.
    """
    print("Loading encoders...")
    bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
    bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
    # The retrieval depth (top_k) is chosen inside search(); the unused local
    # copy that used to live here has been removed.
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')
    print("Encoders loaded!")
    return bi_encoder, cross_encoder
# Module-level model handles read by search(): the bi-encoder embeds the query,
# the cross-encoder re-scores the retrieved passages.
bi_encoder, cross_encoder = load_encoders()
@st.cache_resource
def load_wikipedia_dataset():
    """Load the ``train`` split of ``abokbot/wikipedia-first-paragraph``.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        The ``"train"`` entry of the object returned by ``load_dataset``.
    """
    print("Loading wikipedia dataset...")
    all_splits = load_dataset("abokbot/wikipedia-first-paragraph")
    train_split = all_splits["train"]
    print("Dataset loaded!")
    return train_split
# Passage dataset indexed by search(); rows are looked up by `corpus_id`.
dataset = load_wikipedia_dataset()

# Everything is loaded: show a success banner and blank the status text.
st.success('Loading done')
st_model_load.text("")

# Make sure the remembered query exists before it is used as the widget value.
if 'text' not in st.session_state:
    st.session_state['text'] = ""

# Query input; pre-filled with the last query remembered in session state.
st_text_area = st.text_area(
    label='Enter query (e.g. What is the capital city of Kenya? or Number of deputees in French parliement)',
    height=100,
    value=st.session_state['text'],
)
def search():
    """Run retrieve-and-re-rank for the current query and store the top 3 hits.

    Intended as the ``on_click`` callback of the Search button. The query is
    read from the module-level ``st_text_area`` value and remembered in
    ``st.session_state.text``.

    Pipeline:
        1. Encode the query with ``bi_encoder`` and retrieve the 32 nearest
           passages from ``wikipedia_embedding``.
        2. Score every ``[query, passage text]`` pair with ``cross_encoder``.
        3. Sort by cross-encoder score (descending) and keep the best three.

    Side effects:
        Sets ``st.session_state.results`` to a dict keyed by rank (0, 1, 2);
        each value is a dict with the keys ``"score"``, ``"title"``,
        ``"abstract"`` and ``"link"``.

    NOTE(review): ``st_text_area`` is a module global captured when this
    function was defined; confirm it always reflects the latest input when the
    callback fires (a ``key=`` on the widget would make this explicit).
    """
    query = st_text_area
    st.session_state.text = query
    print("Input question:", query)

    ##### Semantic Search #####
    print("Semantic Search")
    # Encode the query using the bi-encoder and find potentially relevant passages
    top_k = 32
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; only one query is sent.
    hits = util.semantic_search(question_embedding, wikipedia_embedding, top_k=top_k)[0]

    # Fetch each candidate row once and reuse it for re-ranking and for the
    # final result payload instead of indexing the dataset repeatedly.
    rows = [dataset[hit['corpus_id']] for hit in hits]

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    print("Re-Ranking")
    cross_inp = [[query, row["text"]] for row in rows]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    # Sort results by the cross-encoder scores, keeping each hit with its row
    ranked = sorted(
        zip(hits, rows),
        key=lambda pair: pair[0]['cross-score'],
        reverse=True,
    )

    # Output of top-3 hits from re-ranker
    print("\n-------------------------\n")
    print("Top-3 Cross-Encoder Re-ranker hits")
    st.subheader("Top-3 Search results")
    results: dict[int, dict] = {}
    for rank, (hit, row) in enumerate(ranked[:3]):
        results[rank] = {
            # Convert to a built-in float so the stored value is a plain number.
            "score": round(float(hit['cross-score']), 3),
            "title": row["title"],
            "abstract": row["text"].replace("\n", " "),
            "link": row["url"],
        }
    st.session_state.results = results
# search button: search() runs as the click callback and writes its output to
# st.session_state.results.
st_search_button = st.button('Search', on_click=search)

# First run of a session: no search has happened yet.
if 'results' not in st.session_state:
    st.session_state.results = {}
print(st.session_state.results)

# Render the stored results. st.session_state.results is a dict keyed by rank
# whose values are dicts with "score", "title", "abstract" and "link", so
# iterate over the values and print each field once per result.
if st.session_state.results:
    with st.container():
        st.subheader("Search results")
        for result in st.session_state.results.values():
            # f-strings are used because the score is numeric and cannot be
            # concatenated to a str with "+".
            st.markdown(f"score: {result['score']}")
            st.markdown(f"title: {result['title']}")
            st.markdown(f"abstract: {result['abstract']}")
            st.markdown(f"link: {result['link']}")
            # Blank line between consecutive results.
            st.text("")