wanderer2k1 committed
Commit a7b7647
1 Parent(s): d3a77a4
Files changed (2)
  1. SessionState.py +0 -117
  2. app.py +49 -41
SessionState.py DELETED
@@ -1,117 +0,0 @@
-"""Hack to add per-session state to Streamlit.
-
-Usage
------
-
->>> import SessionState
->>>
->>> session_state = SessionState.get(user_name='', favorite_color='black')
->>> session_state.user_name
-''
->>> session_state.user_name = 'Mary'
->>> session_state.favorite_color
-'black'
-
-Since you set user_name above, next time your script runs this will be the
-result:
->>> session_state = get(user_name='', favorite_color='black')
->>> session_state.user_name
-'Mary'
-
-"""
-try:
-    import streamlit.ReportThread as ReportThread
-    from streamlit.server.Server import Server
-except Exception:
-    # Streamlit >= 0.65.0
-    import streamlit.report_thread as ReportThread
-    from streamlit.server.server import Server
-
-
-class SessionState(object):
-    def __init__(self, **kwargs):
-        """A new SessionState object.
-
-        Parameters
-        ----------
-        **kwargs : any
-            Default values for the session state.
-
-        Example
-        -------
-        >>> session_state = SessionState(user_name='', favorite_color='black')
-        >>> session_state.user_name = 'Mary'
-        ''
-        >>> session_state.favorite_color
-        'black'
-
-        """
-        for key, val in kwargs.items():
-            setattr(self, key, val)
-
-
-def get(**kwargs):
-    """Gets a SessionState object for the current session.
-
-    Creates a new object if necessary.
-
-    Parameters
-    ----------
-    **kwargs : any
-        Default values you want to add to the session state, if we're creating a
-        new one.
-
-    Example
-    -------
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    ''
-    >>> session_state.user_name = 'Mary'
-    >>> session_state.favorite_color
-    'black'
-
-    Since you set user_name above, next time your script runs this will be the
-    result:
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    'Mary'
-
-    """
-    # Hack to get the session object from Streamlit.
-
-    ctx = ReportThread.get_report_ctx()
-
-    this_session = None
-
-    current_server = Server.get_current()
-    if hasattr(current_server, '_session_infos'):
-        # Streamlit < 0.56
-        session_infos = Server.get_current()._session_infos.values()
-    else:
-        session_infos = Server.get_current()._session_info_by_id.values()
-
-    for session_info in session_infos:
-        s = session_info.session
-        if (
-            # Streamlit < 0.54.0
-            (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
-            or
-            # Streamlit >= 0.54.0
-            (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
-            or
-            # Streamlit >= 0.65.2
-            (not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
-        ):
-            this_session = s
-
-    if this_session is None:
-        raise RuntimeError(
-            "Oh noes. Couldn't get your Streamlit Session object. "
-            'Are you doing something fancy with threads?')
-
-    # Got the session object! Now let's attach some state into it.
-
-    if not hasattr(this_session, '_custom_session_state'):
-        this_session._custom_session_state = SessionState(**kwargs)
-
-    return this_session._custom_session_state
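Note: the per-session hack deleted above is what newer Streamlit versions replace with the built-in st.session_state mapping, and that is what app.py switches to in this commit. A minimal sketch of the equivalent usage, assuming a Streamlit release that ships st.session_state (0.84 or later):

    import streamlit as st

    # Initialize defaults once per session; stored values survive script reruns.
    if 'user_name' not in st.session_state:
        st.session_state['user_name'] = ''
    if 'favorite_color' not in st.session_state:
        st.session_state['favorite_color'] = 'black'

    st.session_state['user_name'] = 'Mary'
    st.write(st.session_state['user_name'])       # 'Mary', also on later reruns
    st.write(st.session_state['favorite_color'])  # 'black'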
app.py CHANGED
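The hunks below make two related changes: module-level globals (corpus, passage lists, embeddings, BM25 index, tokenizer/model/bi-encoder) move into st.session_state behind "if key not in st.session_state" guards so they are loaded once per session, and neuralqa() gains the @st.cache(allow_output_mutation=True) decorator so the model load is cached across reruns. A minimal sketch of that decorator pattern, assuming the same legacy st.cache API (the loader name is illustrative; on current Streamlit one would use st.cache_resource instead):

    import streamlit as st
    from sentence_transformers import SentenceTransformer

    @st.cache(allow_output_mutation=True)  # legacy caching API, as used in this commit
    def load_bi_encoder():
        # Runs once; subsequent reruns reuse the cached object.
        return SentenceTransformer('wanderer2k1/BertCondenser_LawsQA')

    bi_encoder = load_bi_encoder()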
@@ -1,4 +1,5 @@
 #basics
+from http import server
 import time
 import pandas as pd
 import numpy as np
@@ -13,7 +14,8 @@ from sentence_transformers.util import cos_sim
 
 #streamlit
 import streamlit as st
-import SessionState
+# from streamlit_server_state import server_state, server_state_lock
+# import SessionState
 from load_css import local_css
 local_css("./style.css")
 
@@ -28,9 +30,8 @@ import os.path as path, sys
 from pathlib import Path
 current_dir = path.dirname(path.abspath(getsourcefile(lambda:0)))
 sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)])
-import src.clean_dataset as clean
+# import src.clean_dataset as clean
 
-@st.cache(allow_output_mutation=True)
 
 def preprocess(sentence):
     sentence=str(sentence)
@@ -49,14 +50,15 @@ def selectbox_with_default(text, values, default=DEFAULT, sidebar=False):
     func = st.sidebar.selectbox if sidebar else st.selectbox
     return func(text, np.insert(np.array(values, object), 0, default))
 
+@st.cache(allow_output_mutation=True)
 def neuralqa():
-
     model = T5ForConditionalGeneration.from_pretrained("wanderer2k1/T5-LawsQA")
     tokenizer = T5TokenizerFast.from_pretrained("wanderer2k1/T5-LawsQA")
 
     bi_encoder = SentenceTransformer('wanderer2k1/BertCondenser_LawsQA')
     return tokenizer, model, bi_encoder
 
+
 def hf_run_model(tokenizer, model, input_string, **generator_args):
     generator_args = {
         "max_length": 256,
@@ -73,55 +75,52 @@ def hf_run_model(tokenizer, model, input_string, **generator_args):
     output = [item.split("<sep>") for item in output]
     return output
 
-
 #%%
 sys.path.pop(0)
 
 #1. load in complete transformed and processed dataset
+if 'df' not in st.session_state:
+    st.session_state['df'] = pd.read_csv('./data/corpus.pkl', sep = '\t')
+    st.session_state['passages'] = st.session_state['df']['text'].values.tolist()
+    st.session_state['passage_id'] = st.session_state['df']['title'].values.tolist()
 
-df = pd.read_csv('./data/corpus.pkl', sep = '\t')
-passages = df['text'].values.tolist()
-passage_id = df['title'].values.tolist()
 
 #2 load corpus embeddings for neural QA:
-with open("./data/embedded_corpus_BertCondenser_tuples.pkl", 'rb') as inp:
-    embedded_passages = pickle.load(inp)
-embedded_passages = torch.Tensor(embedded_passages)
+if 'embedded_passages' not in st.session_state:
+    with open("./data/embedded_corpus_BertCondenser_tuples.pkl", 'rb') as inp:
+        embedded_passages = pickle.load(inp)
+    st.session_state['embedded_passages'] = torch.Tensor(embedded_passages)
 
 #3 load BM25:
-with open("models/BM25_pyvi_segmented_splitted.pkl", 'rb') as inp:
-    bm25 = pickle.load(inp)
+if 'bm25' not in st.session_state:
+    with open("models/BM25_pyvi_segmented_splitted.pkl", 'rb') as inp:
+        st.session_state['bm25'] = pickle.load(inp)
 
-#%%
-session = SessionState.get(run_id=0)
+#4: model
+if 'model' not in st.session_state:
+    st.session_state['tokenizer'], st.session_state['model'], st.session_state['bi_encoder'] = neuralqa()
 
 #%%
-#title start page
-st.title('Closed Domain (Vietnamese Laws) QA System')
 
-sdg = Image.open('./logo.jpg')
-st.sidebar.image(sdg, width=300)
-st.sidebar.title('Settings')
 
+#%%
 
-st.caption("by HoangNV - on custom laws QA data set")
-returns = st.sidebar.slider('Maximal number of answer suggestions:', 1, 3, 2)
 
 def deploy(question):
-    tokenizer, model, bi_encoder = neuralqa()
+    # tokenizer, model, bi_encoder = neuralqa()
    top_k = returns # Number of passages we want to retrieve with the bi-encoder
 
     tokenized_query = preprocess(question).split()
     query = ' '.join(tokenized_query)
-    emb_query = bi_encoder.encode(query)
+    emb_query = st.session_state['bi_encoder'].encode(query)
 
-    scores = bm25.get_scores(tokenized_query)
+    scores = st.session_state['bm25'].get_scores(tokenized_query)
     top_score_ids = np.argpartition(scores, -50)[-50:]
 
     emb_candidates = torch.Tensor()
 
     for i in top_score_ids:
-        emb_candidates = torch.cat([emb_candidates,embedded_passages[i:i+1]], axis = 0)
+        emb_candidates = torch.cat([emb_candidates,st.session_state['embedded_passages'][i:i+1]], axis = 0)
 
 
     cosine_sim = cos_sim(emb_query, emb_candidates)
@@ -135,14 +134,14 @@ def deploy(question):
     answers = []
 
     for doc_ind in top_score_ids:
-        doc = passages[doc_ind].replace('_',' ')
+        doc = st.session_state['passages'][doc_ind].replace('_',' ')
 
         matches.append(doc)#' '.join(doc).replace('_',' '))
-        ids.append(passage_id[doc_ind].replace('_',' '))#' '.join(doc[:30].split()[:3]))
+        ids.append(st.session_state['passage_id'][doc_ind].replace('_',' '))#' '.join(doc[:30].split()[:3]))
     # i=0
     for context in matches:
         q = "Trả lời câu hỏi: "+query + " Trong ngữ cảnh: "+context#tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(context))
-        a = hf_run_model(tokenizer, model, q)[0][0]
+        a = hf_run_model(st.session_state['tokenizer'], st.session_state['model'], q)[0][0]
         answers.append(a)
 
     # generate result df
@@ -157,25 +156,34 @@ def deploy(question):
     st.header("Results:")
     st.table(df_results)
 
-    del tokenizer, model, bi_encoder#, question_embedding
+    # del tokenizer, model, bi_encoder, emb_candidates
+
+
+
 
 #%%
-question = st.text_input('Type in your legal question (be as specific as possible):')
+#title start page
+st.title('Closed Domain (Vietnamese Laws) QA System')
+
+sdg = Image.open('./logo.jpg')
+st.sidebar.image(sdg, width=300)
+st.sidebar.title('Settings')
+
+
+st.caption("by HoangNV - on custom laws QA data set")
+returns = st.sidebar.slider('Number of answer suggestions:', 1, 3, 2)
+
+
+question = st.text_input('Type in your legal question:')
 
 if len(question) != 0:
     t0 = time.time()
     with st.spinner('Finding best answers...'):
         deploy(question)
-    st.write(str(time.time()-t0))
-
-    st.write(' ')
-    st.write(' ')
-    st.write(' ')
-    st.write(' ')
-    st.write(' ')
-    st.write(' ')
-    if st.button("Run again!"):
-        session.run_id += 1
+    st.write("Runtime: "+str(time.time()-t0))
+
+
+
 
 #%%
 p = Path('.')
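For context on the retrieval code that deploy() keeps using after this refactor: BM25 first prefilters the corpus to the 50 highest-scoring passages, the BertCondenser bi-encoder then re-ranks those candidates by cosine similarity, and the top passages become contexts for the T5 reader. A small self-contained sketch of that two-stage ranking, assuming the pickled index behaves like a rank_bm25 BM25Okapi object with get_scores (the library is not shown in this diff) and using a placeholder encoder and toy corpus:

    import numpy as np
    import torch
    from rank_bm25 import BM25Okapi
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import cos_sim

    passages = ["the land law regulates land use rights",
                "the labour code sets working hours",
                "the penal code defines criminal offences"]   # toy corpus

    bm25 = BM25Okapi([p.split() for p in passages])            # stage 1: lexical prefilter
    bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')       # placeholder, not the app's model
    emb_passages = torch.Tensor(bi_encoder.encode(passages))

    query = "working hours in the labour code"
    scores = bm25.get_scores(query.split())
    top_ids = np.argsort(scores)[-2:]                          # keep the 2 best lexical hits

    emb_query = bi_encoder.encode(query)                       # stage 2: dense re-ranking
    sims = cos_sim(emb_query, emb_passages[top_ids])[0]
    best = top_ids[int(torch.argmax(sims))]
    print(passages[best])                                      # best candidate for the reader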