Spaces:
Sleeping
Sleeping
Rajat.bans
committed on
Commit
•
ee3f636
1
Parent(s):
bf7f11b
Formatted the code
Browse files
rag.py
CHANGED
@@ -107,39 +107,67 @@ The ADS_DATA provided to you is as follows:
|
|
107 |
|
108 |
embeddings_hf = HuggingFaceEmbeddings(model_name=embedding_model_hf)
|
109 |
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
if adsData == "":
|
112 |
-
return ({"reasoning": "No ads data present", "classification": 0}, 0), (
|
113 |
-
|
|
|
|
|
|
|
114 |
relation_answer = {"reasoning": "", "classification": 1}
|
115 |
question_answer = {"reasoning": "", "question": "", "options": []}
|
116 |
tokens_used_relation = 0
|
117 |
tokens_used_question = 0
|
118 |
while True:
|
119 |
try:
|
120 |
-
if
|
121 |
-
system_message = {
|
|
|
|
|
|
|
122 |
response = client.chat.completions.create(
|
123 |
model=qa_model_name,
|
124 |
-
messages=[system_message]
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
temperature=0,
|
126 |
seed=42,
|
127 |
max_tokens=1000,
|
128 |
-
response_format={"type": "json_object"
|
129 |
)
|
130 |
tokens_used_relation = response.usage.total_tokens
|
131 |
relation_answer = json.loads(response.choices[0].message.content)
|
132 |
tokens_used_question = 0
|
133 |
-
|
134 |
-
if
|
135 |
-
system_message = {
|
|
|
|
|
|
|
136 |
response = client.chat.completions.create(
|
137 |
model=qa_model_name,
|
138 |
-
messages=[system_message]
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
temperature=0,
|
140 |
seed=42,
|
141 |
max_tokens=1000,
|
142 |
-
response_format={"type": "json_object"
|
143 |
)
|
144 |
tokens_used_question = response.usage.total_tokens
|
145 |
question_answer = json.loads(response.choices[0].message.content)
|
@@ -147,15 +175,19 @@ def getBestQuestionOnTheBasisOfPageInformationAndAdsData(page_information, adsDa
|
|
147 |
except Exception as e:
|
148 |
print("Error-: ", e.message)
|
149 |
print("Trying Again")
|
150 |
-
return (relation_answer, tokens_used_relation), (
|
|
|
|
|
|
|
|
|
151 |
|
152 |
def changeResponseToPrintableString(response, task):
|
153 |
if task == "relation":
|
154 |
return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
|
155 |
res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
|
156 |
-
for option in response[
|
157 |
res += f"{option}\n"
|
158 |
-
for ad in response[
|
159 |
res += f"{ad}\n"
|
160 |
res += "\n"
|
161 |
return res
|
@@ -165,25 +197,34 @@ def getRagResponse(RelationPrompt, QuestionPrompt, threshold, page_information):
|
|
165 |
curr_relation_prompt = bestRelationSystemPrompt
|
166 |
if RelationPrompt != None or len(RelationPrompt):
|
167 |
curr_relation_prompt = RelationPrompt
|
168 |
-
|
169 |
curr_question_prompt = bestQuestionSystemPrompt
|
170 |
if QuestionPrompt != None or len(QuestionPrompt):
|
171 |
curr_question_prompt = QuestionPrompt
|
172 |
|
173 |
retreived_documents = [
|
174 |
doc
|
175 |
-
for doc in db.similarity_search_with_score(
|
|
|
|
|
176 |
if doc[1] < threshold
|
177 |
]
|
178 |
best_value = 1
|
179 |
if len(retreived_documents):
|
180 |
best_value = retreived_documents[0][1]
|
181 |
-
relation_answer, question_answer =
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
)
|
188 |
print("QUERY:", page_information, relation_answer, question_answer)
|
189 |
docs_info = "\n\n".join(
|
@@ -194,8 +235,12 @@ def getRagResponse(RelationPrompt, QuestionPrompt, threshold, page_information):
|
|
194 |
]
|
195 |
)
|
196 |
try:
|
197 |
-
relation_answer_string = changeResponseToPrintableString(
|
198 |
-
|
|
|
|
|
|
|
|
|
199 |
full_response = f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n Question answer:\n {question_answer_string}\n\n**RETREIVED DOCUMENTS**:\n{docs_info}\n\n**TOKENS USED**:\nQuestion api call: {question_answer[1]}\nRelation api call: {relation_answer[1]}"
|
200 |
except:
|
201 |
full_response = f"Invalid response received"
|
@@ -205,9 +250,9 @@ def getRagResponse(RelationPrompt, QuestionPrompt, threshold, page_information):
|
|
205 |
db = FAISS.load_local(
|
206 |
DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True
|
207 |
)
|
208 |
-
data = pd.read_csv(data_file_path, sep=
|
209 |
-
data.dropna(axis=0, how=
|
210 |
-
data.drop_duplicates(subset
|
211 |
ad_title_content = list(data["ad_title"].values)
|
212 |
with gr.Blocks() as demo:
|
213 |
gr.Markdown("# RAG on ads data")
|
@@ -227,13 +272,15 @@ with gr.Blocks() as demo:
|
|
227 |
page_information = gr.Textbox(
|
228 |
lines=1, placeholder="Enter the page information", label="Page Information"
|
229 |
)
|
230 |
-
threshold = gr.Number(
|
|
|
|
|
231 |
output = gr.Textbox(label="Output")
|
232 |
submit_btn = gr.Button("Submit")
|
233 |
|
234 |
submit_btn.click(
|
235 |
getRagResponse,
|
236 |
-
inputs=
|
237 |
outputs=[output],
|
238 |
)
|
239 |
page_information.submit(
|
@@ -246,7 +293,10 @@ with gr.Blocks() as demo:
|
|
246 |
|
247 |
demo.load(
|
248 |
lambda: "<br>".join(
|
249 |
-
random.sample(
|
|
|
|
|
|
|
250 |
),
|
251 |
None,
|
252 |
ad_titles,
|
|
|
107 |
|
108 |
embeddings_hf = HuggingFaceEmbeddings(model_name=embedding_model_hf)
|
109 |
|
110 |
+
|
111 |
+
def getBestQuestionOnTheBasisOfPageInformationAndAdsData(
|
112 |
+
page_information,
|
113 |
+
adsData,
|
114 |
+
relationSystemPrompt,
|
115 |
+
questionSystemPrompt,
|
116 |
+
bestRetreivedAdValue,
|
117 |
+
):
|
118 |
if adsData == "":
|
119 |
+
return ({"reasoning": "No ads data present", "classification": 0}, 0), (
|
120 |
+
{"reasoning": "", "question": "", "options": []},
|
121 |
+
0,
|
122 |
+
)
|
123 |
+
|
124 |
relation_answer = {"reasoning": "", "classification": 1}
|
125 |
question_answer = {"reasoning": "", "question": "", "options": []}
|
126 |
tokens_used_relation = 0
|
127 |
tokens_used_question = 0
|
128 |
while True:
|
129 |
try:
|
130 |
+
if bestRetreivedAdValue > relation_check_best_value_thresh:
|
131 |
+
system_message = {
|
132 |
+
"role": "system",
|
133 |
+
"content": relationSystemPrompt + adsData,
|
134 |
+
}
|
135 |
response = client.chat.completions.create(
|
136 |
model=qa_model_name,
|
137 |
+
messages=[system_message]
|
138 |
+
+ [
|
139 |
+
{
|
140 |
+
"role": "user",
|
141 |
+
"content": page_information + "\nThe JSON response: ",
|
142 |
+
}
|
143 |
+
],
|
144 |
temperature=0,
|
145 |
seed=42,
|
146 |
max_tokens=1000,
|
147 |
+
response_format={"type": "json_object"},
|
148 |
)
|
149 |
tokens_used_relation = response.usage.total_tokens
|
150 |
relation_answer = json.loads(response.choices[0].message.content)
|
151 |
tokens_used_question = 0
|
152 |
+
|
153 |
+
if relation_answer["classification"] != 0:
|
154 |
+
system_message = {
|
155 |
+
"role": "system",
|
156 |
+
"content": questionSystemPrompt + adsData,
|
157 |
+
}
|
158 |
response = client.chat.completions.create(
|
159 |
model=qa_model_name,
|
160 |
+
messages=[system_message]
|
161 |
+
+ [
|
162 |
+
{
|
163 |
+
"role": "user",
|
164 |
+
"content": page_information + "\nThe JSON response: ",
|
165 |
+
}
|
166 |
+
],
|
167 |
temperature=0,
|
168 |
seed=42,
|
169 |
max_tokens=1000,
|
170 |
+
response_format={"type": "json_object"},
|
171 |
)
|
172 |
tokens_used_question = response.usage.total_tokens
|
173 |
question_answer = json.loads(response.choices[0].message.content)
|
|
|
175 |
except Exception as e:
|
176 |
print("Error-: ", e.message)
|
177 |
print("Trying Again")
|
178 |
+
return (relation_answer, tokens_used_relation), (
|
179 |
+
question_answer,
|
180 |
+
tokens_used_question,
|
181 |
+
)
|
182 |
+
|
183 |
|
184 |
def changeResponseToPrintableString(response, task):
    """Render a parsed model JSON answer as a human-readable string.

    Args:
        response: Parsed JSON answer. For task "relation" it must carry
            "reasoning" and "classification"; for any other task it must
            carry "reasoning", "question" and "options" (assumed to map
            each option's text to a list of matching ads -- TODO confirm
            against the question-prompt schema; the empty default is a
            list, which also iterates cleanly here).
        task: "relation" selects the classification rendering; any other
            value selects the question rendering.

    Returns:
        A formatted multi-line string.
    """
    if task == "relation":
        return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
    res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
    for option in response["options"]:
        res += f"{option}\n"
        # List every ad grouped under this option, then a blank separator line.
        for ad in response["options"][option]:
            res += f"{ad}\n"
        res += "\n"
    return res
|
|
|
197 |
curr_relation_prompt = bestRelationSystemPrompt
|
198 |
if RelationPrompt != None or len(RelationPrompt):
|
199 |
curr_relation_prompt = RelationPrompt
|
200 |
+
|
201 |
curr_question_prompt = bestQuestionSystemPrompt
|
202 |
if QuestionPrompt != None or len(QuestionPrompt):
|
203 |
curr_question_prompt = QuestionPrompt
|
204 |
|
205 |
retreived_documents = [
|
206 |
doc
|
207 |
+
for doc in db.similarity_search_with_score(
|
208 |
+
page_information, k=number_of_ads_to_fetch_from_db
|
209 |
+
)
|
210 |
if doc[1] < threshold
|
211 |
]
|
212 |
best_value = 1
|
213 |
if len(retreived_documents):
|
214 |
best_value = retreived_documents[0][1]
|
215 |
+
relation_answer, question_answer = (
|
216 |
+
getBestQuestionOnTheBasisOfPageInformationAndAdsData(
|
217 |
+
page_information,
|
218 |
+
".\n".join(
|
219 |
+
[
|
220 |
+
"Ad " + str(i + 1) + ". " + doc[0].page_content
|
221 |
+
for i, doc in enumerate(retreived_documents)
|
222 |
+
]
|
223 |
+
),
|
224 |
+
curr_relation_prompt,
|
225 |
+
curr_question_prompt,
|
226 |
+
best_value,
|
227 |
+
)
|
228 |
)
|
229 |
print("QUERY:", page_information, relation_answer, question_answer)
|
230 |
docs_info = "\n\n".join(
|
|
|
235 |
]
|
236 |
)
|
237 |
try:
|
238 |
+
relation_answer_string = changeResponseToPrintableString(
|
239 |
+
relation_answer[0], "relation"
|
240 |
+
)
|
241 |
+
question_answer_string = changeResponseToPrintableString(
|
242 |
+
question_answer[0], "question"
|
243 |
+
)
|
244 |
full_response = f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n Question answer:\n {question_answer_string}\n\n**RETREIVED DOCUMENTS**:\n{docs_info}\n\n**TOKENS USED**:\nQuestion api call: {question_answer[1]}\nRelation api call: {relation_answer[1]}"
|
245 |
except:
|
246 |
full_response = f"Invalid response received"
|
|
|
250 |
db = FAISS.load_local(
|
251 |
DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True
|
252 |
)
|
253 |
+
data = pd.read_csv(data_file_path, sep="\t")
|
254 |
+
data.dropna(axis=0, how="any", inplace=True)
|
255 |
+
data.drop_duplicates(subset=["ad_title", "ad_desc"], inplace=True)
|
256 |
ad_title_content = list(data["ad_title"].values)
|
257 |
with gr.Blocks() as demo:
|
258 |
gr.Markdown("# RAG on ads data")
|
|
|
272 |
page_information = gr.Textbox(
|
273 |
lines=1, placeholder="Enter the page information", label="Page Information"
|
274 |
)
|
275 |
+
threshold = gr.Number(
|
276 |
+
value=default_threshold, label="Threshold", interactive=True
|
277 |
+
)
|
278 |
output = gr.Textbox(label="Output")
|
279 |
submit_btn = gr.Button("Submit")
|
280 |
|
281 |
submit_btn.click(
|
282 |
getRagResponse,
|
283 |
+
inputs=[RelationPrompt, QuestionPrompt, threshold, page_information],
|
284 |
outputs=[output],
|
285 |
)
|
286 |
page_information.submit(
|
|
|
293 |
|
294 |
demo.load(
|
295 |
lambda: "<br>".join(
|
296 |
+
random.sample(
|
297 |
+
[str(ad_title) for ad_title in ad_title_content],
|
298 |
+
min(100, len(ad_title_content)),
|
299 |
+
)
|
300 |
),
|
301 |
None,
|
302 |
ad_titles,
|