Rajat.bans committed on
Commit
ee3f636
1 Parent(s): bf7f11b

Formatted the code

Browse files
Files changed (1) hide show
  1. rag.py +81 -31
rag.py CHANGED
@@ -107,39 +107,67 @@ The ADS_DATA provided to you is as follows:
107
 
108
  embeddings_hf = HuggingFaceEmbeddings(model_name=embedding_model_hf)
109
 
110
- def getBestQuestionOnTheBasisOfPageInformationAndAdsData(page_information, adsData, relationSystemPrompt, questionSystemPrompt, bestRetreivedAdValue):
 
 
 
 
 
 
 
111
  if adsData == "":
112
- return ({"reasoning": "No ads data present", "classification": 0}, 0), ({"reasoning": "", "question": "", "options": []}, 0)
113
-
 
 
 
114
  relation_answer = {"reasoning": "", "classification": 1}
115
  question_answer = {"reasoning": "", "question": "", "options": []}
116
  tokens_used_relation = 0
117
  tokens_used_question = 0
118
  while True:
119
  try:
120
- if (bestRetreivedAdValue > relation_check_best_value_thresh):
121
- system_message = {"role": "system", "content": relationSystemPrompt + adsData}
 
 
 
122
  response = client.chat.completions.create(
123
  model=qa_model_name,
124
- messages=[system_message] + [{"role": "user", "content": page_information + "\nThe JSON response: "}],
 
 
 
 
 
 
125
  temperature=0,
126
  seed=42,
127
  max_tokens=1000,
128
- response_format={"type": "json_object" }
129
  )
130
  tokens_used_relation = response.usage.total_tokens
131
  relation_answer = json.loads(response.choices[0].message.content)
132
  tokens_used_question = 0
133
-
134
- if(relation_answer['classification'] != 0):
135
- system_message = {"role": "system", "content": questionSystemPrompt + adsData}
 
 
 
136
  response = client.chat.completions.create(
137
  model=qa_model_name,
138
- messages=[system_message] + [{"role": "user", "content": page_information + "\nThe JSON response: "}],
 
 
 
 
 
 
139
  temperature=0,
140
  seed=42,
141
  max_tokens=1000,
142
- response_format={"type": "json_object" }
143
  )
144
  tokens_used_question = response.usage.total_tokens
145
  question_answer = json.loads(response.choices[0].message.content)
@@ -147,15 +175,19 @@ def getBestQuestionOnTheBasisOfPageInformationAndAdsData(page_information, adsDa
147
  except Exception as e:
148
  print("Error-: ", e.message)
149
  print("Trying Again")
150
- return (relation_answer, tokens_used_relation), (question_answer, tokens_used_question)
 
 
 
 
151
 
152
  def changeResponseToPrintableString(response, task):
153
  if task == "relation":
154
  return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
155
  res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
156
- for option in response['options']:
157
  res += f"{option}\n"
158
- for ad in response['options'][option]:
159
  res += f"{ad}\n"
160
  res += "\n"
161
  return res
@@ -165,25 +197,34 @@ def getRagResponse(RelationPrompt, QuestionPrompt, threshold, page_information):
165
  curr_relation_prompt = bestRelationSystemPrompt
166
  if RelationPrompt != None or len(RelationPrompt):
167
  curr_relation_prompt = RelationPrompt
168
-
169
  curr_question_prompt = bestQuestionSystemPrompt
170
  if QuestionPrompt != None or len(QuestionPrompt):
171
  curr_question_prompt = QuestionPrompt
172
 
173
  retreived_documents = [
174
  doc
175
- for doc in db.similarity_search_with_score(page_information, k = number_of_ads_to_fetch_from_db)
 
 
176
  if doc[1] < threshold
177
  ]
178
  best_value = 1
179
  if len(retreived_documents):
180
  best_value = retreived_documents[0][1]
181
- relation_answer, question_answer = getBestQuestionOnTheBasisOfPageInformationAndAdsData(
182
- page_information,
183
- ".\n".join(["Ad " + str(i+1) + ". " + doc[0].page_content for i, doc in enumerate(retreived_documents)]),
184
- curr_relation_prompt,
185
- curr_question_prompt,
186
- best_value
 
 
 
 
 
 
 
187
  )
188
  print("QUERY:", page_information, relation_answer, question_answer)
189
  docs_info = "\n\n".join(
@@ -194,8 +235,12 @@ def getRagResponse(RelationPrompt, QuestionPrompt, threshold, page_information):
194
  ]
195
  )
196
  try:
197
- relation_answer_string = changeResponseToPrintableString(relation_answer[0], "relation")
198
- question_answer_string = changeResponseToPrintableString(question_answer[0], "question")
 
 
 
 
199
  full_response = f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n Question answer:\n {question_answer_string}\n\n**RETREIVED DOCUMENTS**:\n{docs_info}\n\n**TOKENS USED**:\nQuestion api call: {question_answer[1]}\nRelation api call: {relation_answer[1]}"
200
  except:
201
  full_response = f"Invalid response received"
@@ -205,9 +250,9 @@ def getRagResponse(RelationPrompt, QuestionPrompt, threshold, page_information):
205
  db = FAISS.load_local(
206
  DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True
207
  )
208
- data = pd.read_csv(data_file_path, sep='\t')
209
- data.dropna(axis=0, how='any', inplace=True)
210
- data.drop_duplicates(subset = ['ad_title', 'ad_desc'], inplace=True)
211
  ad_title_content = list(data["ad_title"].values)
212
  with gr.Blocks() as demo:
213
  gr.Markdown("# RAG on ads data")
@@ -227,13 +272,15 @@ with gr.Blocks() as demo:
227
  page_information = gr.Textbox(
228
  lines=1, placeholder="Enter the page information", label="Page Information"
229
  )
230
- threshold = gr.Number(value = default_threshold, label="Threshold", interactive=True)
 
 
231
  output = gr.Textbox(label="Output")
232
  submit_btn = gr.Button("Submit")
233
 
234
  submit_btn.click(
235
  getRagResponse,
236
- inputs= [RelationPrompt, QuestionPrompt, threshold, page_information],
237
  outputs=[output],
238
  )
239
  page_information.submit(
@@ -246,7 +293,10 @@ with gr.Blocks() as demo:
246
 
247
  demo.load(
248
  lambda: "<br>".join(
249
- random.sample([str(ad_title) for ad_title in ad_title_content], min(100, len(ad_title_content)))
 
 
 
250
  ),
251
  None,
252
  ad_titles,
 
107
 
108
  embeddings_hf = HuggingFaceEmbeddings(model_name=embedding_model_hf)
109
 
110
+
111
def _request_json_completion(system_prompt, ads_data, page_information):
    """Run one JSON-mode chat completion and return ``(parsed_json, total_tokens)``.

    The system message is ``system_prompt + ads_data`` and the user message is
    ``page_information`` followed by the literal suffix "\\nThe JSON response: ".
    Uses the module-level ``client`` and ``qa_model_name``.
    """
    response = client.chat.completions.create(
        model=qa_model_name,
        messages=[
            {"role": "system", "content": system_prompt + ads_data},
            {
                "role": "user",
                "content": page_information + "\nThe JSON response: ",
            },
        ],
        temperature=0,
        seed=42,
        max_tokens=1000,
        response_format={"type": "json_object"},
    )
    tokens_used = response.usage.total_tokens
    return json.loads(response.choices[0].message.content), tokens_used


def getBestQuestionOnTheBasisOfPageInformationAndAdsData(
    page_information,
    adsData,
    relationSystemPrompt,
    questionSystemPrompt,
    bestRetreivedAdValue,
    max_attempts=None,
):
    """Ask the model for a relation classification and, if relevant, a question.

    Args:
        page_information: Text sent as the user message of both API calls.
        adsData: Ads text appended to each system prompt. An empty string
            short-circuits: no API call is made.
        relationSystemPrompt: System prompt prefix for the relation call.
        questionSystemPrompt: System prompt prefix for the question call.
        bestRetreivedAdValue: Score of the best retrieved ad; the relation
            call is only made when it exceeds the module-level
            ``relation_check_best_value_thresh``.
        max_attempts: Optional cap on the number of tries. ``None`` (the
            default) retries without limit, as before. When the cap is hit
            the last exception is re-raised.

    Returns:
        ``((relation_answer, relation_tokens), (question_answer, question_tokens))``
        where each answer is a dict and each token count is an int.
    """
    if adsData == "":
        return ({"reasoning": "No ads data present", "classification": 0}, 0), (
            {"reasoning": "", "question": "", "options": []},
            0,
        )

    # Placeholders used when a step is skipped. Presumably classification
    # defaults to 1 so the question step still runs when the relation check
    # is skipped -- NOTE(review): confirm against the original indentation.
    relation_answer = {"reasoning": "", "classification": 1}
    question_answer = {"reasoning": "", "question": "", "options": []}
    tokens_used_relation = 0
    tokens_used_question = 0

    attempt = 0
    while True:
        attempt += 1
        try:
            if bestRetreivedAdValue > relation_check_best_value_thresh:
                relation_answer, tokens_used_relation = _request_json_completion(
                    relationSystemPrompt, adsData, page_information
                )
                tokens_used_question = 0

            if relation_answer["classification"] != 0:
                question_answer, tokens_used_question = _request_json_completion(
                    questionSystemPrompt, adsData, page_information
                )
            break
        except Exception as e:
            # Not every exception has ``.message`` (e.g. json.JSONDecodeError,
            # KeyError); reading it unconditionally raised AttributeError from
            # inside this handler and aborted the retry.
            print("Error-: ", getattr(e, "message", e))
            if max_attempts is not None and attempt >= max_attempts:
                raise
            print("Trying Again")

    return (relation_answer, tokens_used_relation), (
        question_answer,
        tokens_used_question,
    )
182
+
183
 
184
  def changeResponseToPrintableString(response, task):
185
  if task == "relation":
186
  return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
187
  res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
188
+ for option in response["options"]:
189
  res += f"{option}\n"
190
+ for ad in response["options"][option]:
191
  res += f"{ad}\n"
192
  res += "\n"
193
  return res
 
197
  curr_relation_prompt = bestRelationSystemPrompt
198
  if RelationPrompt != None or len(RelationPrompt):
199
  curr_relation_prompt = RelationPrompt
200
+
201
  curr_question_prompt = bestQuestionSystemPrompt
202
  if QuestionPrompt != None or len(QuestionPrompt):
203
  curr_question_prompt = QuestionPrompt
204
 
205
  retreived_documents = [
206
  doc
207
+ for doc in db.similarity_search_with_score(
208
+ page_information, k=number_of_ads_to_fetch_from_db
209
+ )
210
  if doc[1] < threshold
211
  ]
212
  best_value = 1
213
  if len(retreived_documents):
214
  best_value = retreived_documents[0][1]
215
+ relation_answer, question_answer = (
216
+ getBestQuestionOnTheBasisOfPageInformationAndAdsData(
217
+ page_information,
218
+ ".\n".join(
219
+ [
220
+ "Ad " + str(i + 1) + ". " + doc[0].page_content
221
+ for i, doc in enumerate(retreived_documents)
222
+ ]
223
+ ),
224
+ curr_relation_prompt,
225
+ curr_question_prompt,
226
+ best_value,
227
+ )
228
  )
229
  print("QUERY:", page_information, relation_answer, question_answer)
230
  docs_info = "\n\n".join(
 
235
  ]
236
  )
237
  try:
238
+ relation_answer_string = changeResponseToPrintableString(
239
+ relation_answer[0], "relation"
240
+ )
241
+ question_answer_string = changeResponseToPrintableString(
242
+ question_answer[0], "question"
243
+ )
244
  full_response = f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n Question answer:\n {question_answer_string}\n\n**RETREIVED DOCUMENTS**:\n{docs_info}\n\n**TOKENS USED**:\nQuestion api call: {question_answer[1]}\nRelation api call: {relation_answer[1]}"
245
  except:
246
  full_response = f"Invalid response received"
 
250
  db = FAISS.load_local(
251
  DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True
252
  )
253
+ data = pd.read_csv(data_file_path, sep="\t")
254
+ data.dropna(axis=0, how="any", inplace=True)
255
+ data.drop_duplicates(subset=["ad_title", "ad_desc"], inplace=True)
256
  ad_title_content = list(data["ad_title"].values)
257
  with gr.Blocks() as demo:
258
  gr.Markdown("# RAG on ads data")
 
272
  page_information = gr.Textbox(
273
  lines=1, placeholder="Enter the page information", label="Page Information"
274
  )
275
+ threshold = gr.Number(
276
+ value=default_threshold, label="Threshold", interactive=True
277
+ )
278
  output = gr.Textbox(label="Output")
279
  submit_btn = gr.Button("Submit")
280
 
281
  submit_btn.click(
282
  getRagResponse,
283
+ inputs=[RelationPrompt, QuestionPrompt, threshold, page_information],
284
  outputs=[output],
285
  )
286
  page_information.submit(
 
293
 
294
  demo.load(
295
  lambda: "<br>".join(
296
+ random.sample(
297
+ [str(ad_title) for ad_title in ad_title_content],
298
+ min(100, len(ad_title_content)),
299
+ )
300
  ),
301
  None,
302
  ad_titles,