abokbot commited on
Commit
3af85d8
1 Parent(s): 94bb2b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -8,7 +8,7 @@ embedding_path = "abokbot/wikipedia-embedding"
8
 
9
  st.header("Wikipedia Search Engine app")
10
 
11
- st_model_load = st.text('Loading wikipedia embedding...')
12
 
13
  @st.cache_resource
14
  def load_embedding():
@@ -19,7 +19,6 @@ def load_embedding():
19
  return wikipedia_embedding
20
 
21
  wikipedia_embedding = load_embedding()
22
- st.success('Embedding loaded!')
23
  st_model_load.text("")
24
 
25
  @st.cache_resource
@@ -29,6 +28,7 @@ def load_encoders():
29
  bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens
30
  top_k = 32
31
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')
 
32
  return bi_encoder, cross_encoder
33
 
34
  bi_encoder, cross_encoder = load_encoders()
@@ -39,10 +39,11 @@ st_model_load.text("")
39
  def load_wikipedia_dataset():
40
  print("Loading wikipedia dataset...")
41
  dataset = load_dataset("abokbot/wikipedia-first-paragraph")["train"]
 
42
  return dataset
43
 
44
  dataset = load_wikipedia_dataset()
45
- st.success('Datset loaded!')
46
  st_model_load.text("")
47
 
48
  if 'text' not in st.session_state:
@@ -57,7 +58,10 @@ st_text_area = st.text_area(
57
  def search():
58
  st.session_state.text = st_text_area
59
  query = st_text_area
 
 
60
  ##### Sematic Search #####
 
61
  # Encode the query using the bi-encoder and find potentially relevant passages
62
  top_k = 32
63
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
@@ -66,6 +70,7 @@ def search():
66
 
67
  ##### Re-Ranking #####
68
  # Now, score all retrieved passages with the cross_encoder
 
69
  cross_inp = [[query, dataset[hit['corpus_id']]["text"]] for hit in hits]
70
  cross_scores = cross_encoder.predict(cross_inp)
71
 
@@ -99,7 +104,7 @@ if 'results' not in st.session_state:
99
  if len(st.session_state.results) > 0:
100
  with st.container():
101
  st.subheader("Search results")
102
- for result in st.session_state.questions:
103
  for k,v in result.items():
104
  st.markdown("score: " + results["score"])
105
  st.markdown("title: " + results["title"])
 
8
 
9
  st.header("Wikipedia Search Engine app")
10
 
11
+ st_model_load = st.text('Loading encoders, embeddings and dataset (takes about 5min)')
12
 
13
  @st.cache_resource
14
  def load_embedding():
 
19
  return wikipedia_embedding
20
 
21
  wikipedia_embedding = load_embedding()
 
22
  st_model_load.text("")
23
 
24
  @st.cache_resource
 
28
  bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens
29
  top_k = 32
30
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')
31
+ print("Encoders loaded!")
32
  return bi_encoder, cross_encoder
33
 
34
  bi_encoder, cross_encoder = load_encoders()
 
39
  def load_wikipedia_dataset():
40
  print("Loading wikipedia dataset...")
41
  dataset = load_dataset("abokbot/wikipedia-first-paragraph")["train"]
42
+ print("Dataset loaded!")
43
  return dataset
44
 
45
  dataset = load_wikipedia_dataset()
46
+ st.success('Loading done')
47
  st_model_load.text("")
48
 
49
  if 'text' not in st.session_state:
 
58
  def search():
59
  st.session_state.text = st_text_area
60
  query = st_text_area
61
+ print("Input question:", query)
62
+
63
  ##### Sematic Search #####
64
+ print("Semantic Search")
65
  # Encode the query using the bi-encoder and find potentially relevant passages
66
  top_k = 32
67
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
 
70
 
71
  ##### Re-Ranking #####
72
  # Now, score all retrieved passages with the cross_encoder
73
+ print("Re-Ranking")
74
  cross_inp = [[query, dataset[hit['corpus_id']]["text"]] for hit in hits]
75
  cross_scores = cross_encoder.predict(cross_inp)
76
 
 
104
  if len(st.session_state.results) > 0:
105
  with st.container():
106
  st.subheader("Search results")
107
+ for result in st.session_state.results:
108
  for k,v in result.items():
109
  st.markdown("score: " + results["score"])
110
  st.markdown("title: " + results["title"])