abokbot committed on
Commit
1e57a2c
1 Parent(s): c1b147b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
3
+ import torch
4
+ from huggingface_hub import hf_hub_download
5
+
6
# Hugging Face Hub repository id that hosts the precomputed embedding file.
embedding_path = "abokbot/wikipedia-embedding"

st.header("Wikipedia Search Engine app")

# Status message written to the page before the embedding is loaded.
st_model_load = st.text('Loading wikipedia embedding...')


@st.cache_resource
def load_model():
    """Download the precomputed Wikipedia embedding from the Hub and load it.

    The result is cached by ``st.cache_resource``, so the download and
    deserialization are not repeated on every Streamlit rerun.

    Returns:
        The object stored in ``simple_wikipedia_embedding.pt``, as returned
        by ``torch.load``.
    """
    print("Loading embedding...")
    # hf_hub_download stores the file in the local Hub cache and returns its
    # path. Load from that path; a hard-coded relative path such as
    # "wikipedia-embedding/..." is never created by the download.
    local_file = hf_hub_download(
        repo_id=embedding_path,
        filename="simple_wikipedia_embedding.pt",
    )
    # NOTE(review): torch.load unpickles the file, so only point
    # `embedding_path` at a repository you trust. If the tensor was saved
    # from a GPU, a CPU-only host may also need map_location -- confirm.
    wikipedia_embedding = torch.load(local_file)
    print("Embedding loaded!")
    return wikipedia_embedding
19
+
20
# NOTE: the search pipeline below is disabled. It is kept as `#` comments
# instead of a bare triple-quoted string because Streamlit "magic" writes any
# standalone string literal in the script to the page.
# TODO: re-enable once `corpus_embeddings` and `dataset` are defined.
#
# # We use the Bi-Encoder to encode all passages, so that we can use it with semantic search
# # cf https://www.sbert.net/docs/pretrained-models/msmarco-v3.html
# bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
# bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
# top_k = 32  # Number of passages we want to retrieve with the bi-encoder
#
# # The bi-encoder retrieves top_k passages. We use a cross-encoder to re-rank the results list to improve the quality
# cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')
#
# def search(query):
#     print("Input question:", query)
#     ##### Semantic Search #####
#     # Encode the query using the bi-encoder and find potentially relevant passages
#     question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
#     question_embedding = question_embedding.cuda()
#     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
#     hits = hits[0]  # Get the hits for the first query
#
#     ##### Re-Ranking #####
#     # Now, score all retrieved passages with the cross_encoder
#     cross_inp = [[query, dataset["text"][hit['corpus_id']]] for hit in hits]
#     cross_scores = cross_encoder.predict(cross_inp)
#
#     # Attach the cross-encoder score to each hit
#     for idx in range(len(cross_scores)):
#         hits[idx]['cross-score'] = cross_scores[idx]
#
#     # Output of top-3 hits from re-ranker
#     print("\n-------------------------\n")
#     print("Top-3 Cross-Encoder Re-ranker hits")
#     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
#     for hit in hits[0:3]:
#         print("score: ", round(hit['cross-score'], 3),"\n",
#               "title: ", dataset["title"][hit['corpus_id']], "\n",
#               "substract: ", dataset["text"][hit['corpus_id']].replace("\n", " "), "\n",
#               "link: ", dataset["url"][hit['corpus_id']],"\n")