Spaces:
Sleeping
Sleeping
Rajat.bans
commited on
Commit
•
1c88355
1
Parent(s):
537373b
Upgraded rag code
Browse files- .gitignore +2 -1
- rag.ipynb +225 -0
- rag.py +420 -217
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
**/.DS_Store
|
|
|
|
1 |
+
**/.DS_Store
|
2 |
+
**/__pycache__
|
rag.ipynb
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {
|
7 |
+
"colab": {
|
8 |
+
"base_uri": "https://localhost:8080/"
|
9 |
+
},
|
10 |
+
"id": "SCq5lAKuZxYx",
|
11 |
+
"outputId": "6c44cd5b-efe4-4364-d19c-4650be91f9c6"
|
12 |
+
},
|
13 |
+
"outputs": [
|
14 |
+
{
|
15 |
+
"name": "stdout",
|
16 |
+
"output_type": "stream",
|
17 |
+
"text": [
|
18 |
+
"Mounted at /content/gdrive\n"
|
19 |
+
]
|
20 |
+
}
|
21 |
+
],
|
22 |
+
"source": [
|
23 |
+
"from google.colab import drive\n",
|
24 |
+
"import os\n",
|
25 |
+
"\n",
|
26 |
+
"drive.mount('/content/gdrive')\n",
|
27 |
+
"\n",
|
28 |
+
"!ls\n",
|
29 |
+
"%cd /content/gdrive/MyDrive/rajat.bans/RAG/\n",
|
30 |
+
"!pip install -r requirements.txt"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 1,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [
|
38 |
+
{
|
39 |
+
"name": "stderr",
|
40 |
+
"output_type": "stream",
|
41 |
+
"text": [
|
42 |
+
"/Users/lazyghost/VirtualEnvironments/langchain-rag-venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
43 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
44 |
+
"/Users/lazyghost/VirtualEnvironments/langchain-rag-venv/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
45 |
+
" warnings.warn(\n"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"data": {
|
50 |
+
"text/plain": [
|
51 |
+
"({'relation_answer': {'reasoning': \"No relevant ads found for the user's input 'Hola'.\",\n",
|
52 |
+
" 'classification': 0},\n",
|
53 |
+
" 'tokens_used_relation': 460,\n",
|
54 |
+
" 'question_answer': {'reasoning': '', 'question': '', 'options': []},\n",
|
55 |
+
" 'tokens_used_question': 0},\n",
|
56 |
+
" [])"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
"execution_count": 1,
|
60 |
+
"metadata": {},
|
61 |
+
"output_type": "execute_result"
|
62 |
+
}
|
63 |
+
],
|
64 |
+
"source": [
|
65 |
+
"from rag import VARIABLE_MANAGER\n",
|
66 |
+
"vm = VARIABLE_MANAGER()\n",
|
67 |
+
"rag = vm.getRag()\n",
|
68 |
+
"# data = vm.QnAAdsSampleGenerationPreProcessing()\n",
|
69 |
+
"tot_cost = 0\n",
|
70 |
+
"rag.getRagResponse(\"Hola\")"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": 59,
|
76 |
+
"metadata": {
|
77 |
+
"colab": {
|
78 |
+
"base_uri": "https://localhost:8080/"
|
79 |
+
},
|
80 |
+
"id": "zzJNW1fCcP3v",
|
81 |
+
"outputId": "6bcb20f5-6596-4e42-ffb7-d3a70566d2e8"
|
82 |
+
},
|
83 |
+
"outputs": [
|
84 |
+
{
|
85 |
+
"name": "stdout",
|
86 |
+
"output_type": "stream",
|
87 |
+
"text": [
|
88 |
+
"20,"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"name": "stderr",
|
93 |
+
"output_type": "stream",
|
94 |
+
"text": [
|
95 |
+
"/usr/local/lib/python3.10/dist-packages/sklearn/cluster/_kmeans.py:870: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning\n",
|
96 |
+
" warnings.warn(\n"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"name": "stdout",
|
101 |
+
"output_type": "stream",
|
102 |
+
"text": [
|
103 |
+
" Total cost is up to now is 0.014402400000000001\n"
|
104 |
+
]
|
105 |
+
}
|
106 |
+
],
|
107 |
+
"source": [
|
108 |
+
"import pandas as pd\n",
|
109 |
+
"responses_file_name = './data/147_results_webmd_healthline_12Jun-18Jun_1000each_145BIGQSPRCR_QuestionSystemPromptImprovedClusteringAdded_.tsv'\n",
|
110 |
+
"try:\n",
|
111 |
+
" responses = pd.read_csv(responses_file_name, sep='\\t')\n",
|
112 |
+
"except FileNotFoundError:\n",
|
113 |
+
" responses = pd.DataFrame()\n",
|
114 |
+
"\n",
|
115 |
+
"new_rows = []\n",
|
116 |
+
"for i in range(len(responses), len(data)):\n",
|
117 |
+
" print(i, end = ',')\n",
|
118 |
+
" row = data.iloc[i, :]\n",
|
119 |
+
" try:\n",
|
120 |
+
" answer = {\n",
|
121 |
+
" 'domain_name': row['domain_name'],\n",
|
122 |
+
" 'url': row['url'],\n",
|
123 |
+
" # 'input': '. '.join(row['stripped_url'].split('/')[3:]),\n",
|
124 |
+
" 'kwd_imp': row['kwd_imp'],\n",
|
125 |
+
" 'kwd_click': row['kwd_click'],\n",
|
126 |
+
" 'ad_click': row['ad_click'],\n",
|
127 |
+
" 'revenue': row['revenue'],\n",
|
128 |
+
" 'rank': row['rank'],\n",
|
129 |
+
" 'url_title': row['url_title'],\n",
|
130 |
+
" 'url_content': row['url_content'],\n",
|
131 |
+
" 'input': row['core_content'],\n",
|
132 |
+
" }\n",
|
133 |
+
"\n",
|
134 |
+
" reply, clustered_docs = rag.getRagResponse(row['core_content'])\n",
|
135 |
+
" answer[\"relation_reasoning\"] = reply['relation_answer']['reasoning']\n",
|
136 |
+
" answer[\"relation_classification\"] = reply['relation_answer']['classification']\n",
|
137 |
+
" answer[\"relation_tokens_used\"] = reply['tokens_used_relation']\n",
|
138 |
+
"\n",
|
139 |
+
" answer[\"reasoning\"] = reply['question_answer']['reasoning']\n",
|
140 |
+
" answer[\"question\"] = reply['question_answer']['question']\n",
|
141 |
+
" options = reply['question_answer']['options']\n",
|
142 |
+
" options_res = \"\"\n",
|
143 |
+
" for option in options:\n",
|
144 |
+
" options_res += option + \"\\n\"\n",
|
145 |
+
" for ad in options[option]:\n",
|
146 |
+
" options_res += ad + \"\\n\"\n",
|
147 |
+
" options_res += \"\\n\"\n",
|
148 |
+
" answer[\"options\"] = options_res\n",
|
149 |
+
" answer[\"options_count\"] = str(len(options))\n",
|
150 |
+
" answer[\"question_tokens_used\"] = reply['tokens_used_question']\n",
|
151 |
+
"\n",
|
152 |
+
" ads_data = \"\"\n",
|
153 |
+
" for ind, cluster in enumerate(clustered_docs):\n",
|
154 |
+
" ads_data += f\"*************** Cluster-:{ind+1} **************\\n\"\n",
|
155 |
+
" for doc in cluster:\n",
|
156 |
+
" ad = doc[0]\n",
|
157 |
+
" ads_data += ad.page_content + \"\\n\"\n",
|
158 |
+
" ads_data += \"publisher_url: \" + ad.metadata['publisher_url'] + \" | \"\n",
|
159 |
+
" ads_data += \"keyword_term: \" + ad.metadata['keyword_term'] + \" | \"\n",
|
160 |
+
" ads_data += \"ad_display_url: \" + ad.metadata['ad_display_url'] + \" | \"\n",
|
161 |
+
" ads_data += \"revenue: \" + str(ad.metadata['revenue']) + \" | \"\n",
|
162 |
+
" ads_data += \"ad_click_count: \" + str(ad.metadata['ad_click_count']) + \" | \"\n",
|
163 |
+
" ads_data += \"RPC: \" + str(ad.metadata['RPC']) + \" | \"\n",
|
164 |
+
" ads_data += \"Type: \" + ad.metadata['Type'] + \"\\n\"\n",
|
165 |
+
" ads_data += \"Value: \" + str(doc[1]) + \"\\n\"\n",
|
166 |
+
" ads_data += \"\\n\"\n",
|
167 |
+
" ads_data += \"\\n\"\n",
|
168 |
+
" answer[\"ads_data\"] = ads_data\n",
|
169 |
+
"\n",
|
170 |
+
" cost = (answer[\"relation_tokens_used\"] + answer[\"question_tokens_used\"]) * 0.6/1000000\n",
|
171 |
+
" tot_cost += cost\n",
|
172 |
+
" answer['cost'] = float(cost)\n",
|
173 |
+
" except Exception as e:\n",
|
174 |
+
" print(e)\n",
|
175 |
+
" new_rows.append(answer)\n",
|
176 |
+
"\n",
|
177 |
+
" if i % 10 == 0:\n",
|
178 |
+
" print(\" Total cost is up to now is\", tot_cost)\n",
|
179 |
+
" responses = pd.concat([responses, pd.DataFrame(new_rows)], ignore_index=True)\n",
|
180 |
+
" responses.to_csv(responses_file_name, sep='\\t', index=False)\n",
|
181 |
+
" new_rows = []\n",
|
182 |
+
"\n",
|
183 |
+
"responses = pd.concat([responses, pd.DataFrame(new_rows)], ignore_index=True)\n",
|
184 |
+
"responses.to_csv(responses_file_name, sep='\\t', index=False)\n",
|
185 |
+
"responses\n"
|
186 |
+
]
|
187 |
+
}
|
188 |
+
],
|
189 |
+
"metadata": {
|
190 |
+
"accelerator": "GPU",
|
191 |
+
"colab": {
|
192 |
+
"collapsed_sections": [
|
193 |
+
"5gRHp_nCJHlf",
|
194 |
+
"ScYo9Q38IbGr",
|
195 |
+
"Yd1qWPjlxCTd",
|
196 |
+
"P2soYnTaxE5c",
|
197 |
+
"DOLcMgW6IWX8"
|
198 |
+
],
|
199 |
+
"gpuType": "T4",
|
200 |
+
"machine_shape": "hm",
|
201 |
+
"provenance": []
|
202 |
+
},
|
203 |
+
"kernelspec": {
|
204 |
+
"display_name": "Python 3",
|
205 |
+
"name": "python3"
|
206 |
+
},
|
207 |
+
"language_info": {
|
208 |
+
"codemirror_mode": {
|
209 |
+
"name": "ipython",
|
210 |
+
"version": 3
|
211 |
+
},
|
212 |
+
"file_extension": ".py",
|
213 |
+
"mimetype": "text/x-python",
|
214 |
+
"name": "python",
|
215 |
+
"nbconvert_exporter": "python",
|
216 |
+
"pygments_lexer": "ipython3",
|
217 |
+
"version": "3.12.4"
|
218 |
+
},
|
219 |
+
"widgets": {
|
220 |
+
"application/vnd.jupyter.widget-state+json": {}
|
221 |
+
}
|
222 |
+
},
|
223 |
+
"nbformat": 4,
|
224 |
+
"nbformat_minor": 0
|
225 |
+
}
|
rag.py
CHANGED
@@ -1,39 +1,27 @@
|
|
1 |
-
from dotenv import load_dotenv
|
2 |
-
from langchain_community.vectorstores import FAISS
|
3 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings
|
4 |
-
import gradio as gr
|
5 |
-
from openai import OpenAI
|
6 |
-
import random
|
7 |
-
import pandas as pd
|
8 |
-
import os
|
9 |
-
import json
|
10 |
from sklearn.cluster import KMeans, SpectralClustering
|
11 |
from scipy.spatial.distance import euclidean
|
|
|
|
|
12 |
import re
|
13 |
import numpy as np
|
|
|
|
|
14 |
from itertools import count
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
embeddings_hf = HuggingFaceEmbeddings(model_name=embedding_model_hf)
|
21 |
-
|
22 |
|
23 |
class CLUSTERING:
|
24 |
-
def
|
25 |
-
|
26 |
-
|
27 |
-
)
|
28 |
-
|
29 |
-
def cluster_embeddings(self, embeddings, no_of_clusters, no_of_points):
|
30 |
-
if self.clustering_algo in {"kmeans-cc", "kmeans-sp"}:
|
31 |
-
kmeans = KMeans(n_clusters=no_of_clusters, random_state=42)
|
32 |
kmeans.fit(embeddings)
|
33 |
cluster_centers = kmeans.cluster_centers_
|
34 |
labels = kmeans.labels_
|
35 |
|
36 |
-
if
|
37 |
clusters_indices = [[] for _ in range(no_of_clusters)]
|
38 |
for i, embedding in enumerate(embeddings):
|
39 |
cluster_idx = labels[i]
|
@@ -51,7 +39,7 @@ class CLUSTERING:
|
|
51 |
len(cluster) == no_of_points for cluster in clusters_indices
|
52 |
):
|
53 |
break
|
54 |
-
elif
|
55 |
spectral_clustering = SpectralClustering(
|
56 |
n_clusters=no_of_clusters, affinity="nearest_neighbors", random_state=42
|
57 |
)
|
@@ -67,24 +55,28 @@ class CLUSTERING:
|
|
67 |
[cluster_point[0] for cluster_point in clusters_indices[i][:no_of_points]]
|
68 |
for i in range(no_of_clusters)
|
69 |
]
|
70 |
-
|
71 |
-
|
72 |
class VECTOR_DB:
|
73 |
-
def __init__(self):
|
74 |
-
self.
|
75 |
-
self.
|
76 |
-
self.
|
77 |
-
self.no_of_clusters =
|
78 |
-
self.no_of_ads_in_each_cluster =
|
|
|
79 |
self.db = FAISS.load_local(
|
80 |
-
|
81 |
)
|
82 |
|
83 |
-
|
|
|
84 |
def remove_html_tags(text):
|
85 |
clean = re.compile("<.*?>")
|
86 |
return re.sub(clean, "", text)
|
87 |
|
|
|
|
|
|
|
88 |
retreived_documents = [
|
89 |
doc
|
90 |
for doc in self.db.similarity_search_with_score(
|
@@ -98,13 +90,13 @@ class VECTOR_DB:
|
|
98 |
)
|
99 |
if len(retreived_documents):
|
100 |
embeddings = np.array(
|
101 |
-
embeddings_hf.embed_documents(
|
102 |
[doc[0].page_content for doc in retreived_documents]
|
103 |
)
|
104 |
)
|
105 |
|
106 |
clustered_indices = CLUSTERING().cluster_embeddings(
|
107 |
-
embeddings, self.no_of_clusters, self.no_of_ads_in_each_cluster
|
108 |
)
|
109 |
documents_clusters = [
|
110 |
[retreived_documents[ind] for ind in cluster_indices]
|
@@ -115,94 +107,93 @@ class VECTOR_DB:
|
|
115 |
return documents_clusters, best_value
|
116 |
return [], 1
|
117 |
|
118 |
-
|
119 |
-
class ADS_RAG:
|
120 |
def __init__(self):
|
121 |
-
|
122 |
-
self.db = VECTOR_DB()
|
123 |
-
self.qa_model_name = "gpt-3.5-turbo"
|
124 |
-
self.relation_check_best_value_thresh = 0.6
|
125 |
-
self.bestRelationSystemPrompt = """You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to determine whether there are some relevant ADS to INPUT are present in ADS_DATA. ADS WHICH DON'T MATCH USER'S INTENT SHOULD BE CONSIDERED IRRELEVANT
|
126 |
-
|
127 |
-
---------------------------------------
|
128 |
-
|
129 |
-
**Sample INPUT***: What Causes Bright-Yellow Urine and Other Changes in Color?
|
130 |
-
|
131 |
-
Expected json output :
|
132 |
-
{
|
133 |
-
"reasoning" : "Given the user's search for 'What Causes Bright-Yellow Urine and Other Changes in Color?', it is evident that they are looking for information related to the causes and implications of changes in urine color. Therefore, ads that are related to urine color. However, none of the provided ads are related to urine. Ads 1, 2, and 3 are focused on chronic lymphocytic leukemia (CLL) treatment, Ad 4 addresses high blood pressure, and Ad 5 is about migraine treatment. Since none of these ads are not at all relevant to the urine, they can all be considered irrelevant to the user's intent.",
|
134 |
-
"classification": 0
|
135 |
-
}
|
136 |
-
------------------------------------------------
|
137 |
-
|
138 |
-
**Sample INPUT**: The Effects of Aging on Skin
|
139 |
-
|
140 |
-
Expected json output :
|
141 |
-
{
|
142 |
-
"reasoning" : "Given the user's search for 'The Effects of Aging on Skin,' it is clear that they are seeking information related to skin aging. Therefore, the ads that are relevant to skin effects should be considered. Ads 1 and 2 focus on wrinkle treatment and anti-aging solutions, making them pertinent to the user's intent. Ad 3 targets vitiligo and not general skin aging but it is related to skin effect. So it is also relevant. Ads 4 and 5 are about advanced lung cancer, which do not address the interest in skin. Ads 1 and 2, 3 are most relevant to the user's search. So ADS_DATA is relevant to INPUT. ",
|
143 |
-
"classification": 1
|
144 |
-
}
|
145 |
-
---------------------------------------
|
146 |
-
|
147 |
-
The ADS_DATA provided to you is as follows:
|
148 |
-
"""
|
149 |
-
|
150 |
-
self.bestQuestionSystemPrompt = """1. You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to form a relevant QUESTION to ask the user visiting the webpage. This question should help identify the user's intent behind visiting the webpage and should be highly attractive.
|
151 |
-
2. Now form a highly attractive/lucrative and diverse/mutually exclusive OPTION which should be both the answer for the QUESTION and related to ads in this cluster.
|
152 |
-
3. Try to generate intelligent creatives for advertising and keep QUESTION within 70 characters and either 2, 3 or 4 options with each OPTION within 4 to 6 words.
|
153 |
-
4. Provide your REASONING behind choosing the QUESTION and the OPTIONS. Now provide the QUESTION and the OPTIONS. Along with each OPTION, provide the ads from ADS_DATA that you associated with it.
|
154 |
-
|
155 |
-
---------------------------------------
|
156 |
-
|
157 |
-
<Sample INPUT>
|
158 |
-
The Effects of Aging on Skin
|
159 |
-
|
160 |
-
<Sample ADS_DATA>
|
161 |
-
{"Cluster 1 Ads": {"Ad 1": "Forget Retinol, Use This Household Item To Fill In Wrinkles - Celebrities Are Ditching Pricey Facelifts For This."}, "Cluster 2 Ads": {"Ad 2": "Stop Covering Your Wrinkles with Make Up - Do This Instead.", "Ad 3": "Living With Migraines? - Discover A Treatment Option. Learn about a type of prescription migraine treatment called CGRP receptor antagonists. Discover a range of resources that may help people dealing with migraines"}, "Cluster 3 Ads": {"Ad 4": "What is Advanced Skin Cancer? - Find Disease Information Here.Find Facts About Advanced Skin Cancer and a Potential Treatment Option.", "Ad 5": "Learn About Advanced Melanoma - Find Disease Information Here.Find Facts About Advanced Melanoma and a Potential Treatment Option.", "Ad 6": "Treatment For CKD - Reduce Risk Of Progressing CKD. Ask About A Treatment That Can Help Reduce Your Risk Of Kidney Failure", "Ad 7": "Are You Living With Vitiligo? - For Patients & Caregivers.Discover An FDA-Approved Topical Cream That May Help With Nonsegmental Vitiligo Repigmentation. Learn About A Copay Savings Card For Eligible Patients With Vitiligo."}]
|
162 |
-
|
163 |
-
<Expected json output>
|
164 |
-
{
|
165 |
-
"reasoning" : "Among the seven ads in **Sample ADS_DATA**, Ads 3 and 6 are irrelevant to the INPUT, so they should be discarded. Ad 1, 2 closely aligns with the user's intent. Ads 4, 5, and 7 are also relevant to INPUT. The question will be formed in a way to connect the PAGE content with the goals of these five relevant ads, making sure they appeal to both specific and general user interests, with the OPTIONS being the answer for QUESTION(it is ensured that no irrelevant options are formed)",
|
166 |
-
"question": "Interested in methods to combat aging skin?",
|
167 |
-
"options": {"1. Retinol Alternatives for Wrinkle Treatment." : ["Ad 1: Forget Retinol, Use This Household Item To Fill In Wrinkles - Celebrities Are Ditching Pricey Facelifts For This."], "2. Reduce Wrinkles without Makeup.": ["Ad 2: Stop Covering Your Wrinkles with Make Up - Do This Instead."], "3. Information on Skin Diseases": ["Ad 3: What is Advanced Skin Cancer? - Find Disease Information Here.Find Facts About Advanced Skin Cancer and a Potential Treatment Option.", "Ad 4: Learn About Advanced Melanoma - Find Disease Information Here.Find Facts About Advanced Melanoma and a Potential Treatment Option.", "Ad 5: Are You Living With Vitiligo? - For Patients & Caregivers.Discover An FDA-Approved Topical Cream That May Help With Nonsegmental Vitiligo Repigmentation. Learn About A Copay Savings Card For Eligible Patients With Vitiligo."]}
|
168 |
-
}
|
169 |
-
-----------------------------------------------
|
170 |
-
|
171 |
-
<Sample INPUT>
|
172 |
-
Got A Rosemary Bush? Here’re 20 Brilliant & Unusual Ways To Use All That Rosemary
|
173 |
-
|
174 |
-
<Sample ADS_DATA>
|
175 |
-
<empty>
|
176 |
-
|
177 |
-
<Expected json output>
|
178 |
-
{
|
179 |
-
"reasoning" : "No ads available",
|
180 |
-
"question": "",
|
181 |
-
"options": []
|
182 |
-
}
|
183 |
-
-----------------------------------------------
|
184 |
-
|
185 |
-
The ADS_DATA provided to you is as follows:
|
186 |
-
"""
|
187 |
-
|
188 |
-
old_system_prompt_additional_example = """
|
189 |
-
-----------------------------------------------
|
190 |
-
<Sample INPUT>
|
191 |
-
7 Signs and Symptoms of Magnesium Deficiency
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
|
207 |
def callOpenAiApi(self, messages):
|
208 |
while True:
|
@@ -232,10 +223,7 @@ Ad 3: Weak or Paralyzed Muscles? - A Common Symptom of Cataplexy. About 70% of P
|
|
232 |
bestRetreivedAdValue,
|
233 |
):
|
234 |
if adsData == "":
|
235 |
-
return ({"reasoning": "No ads data present", "classification": 0}, 0), (
|
236 |
-
{"reasoning": "", "question": "", "options": []},
|
237 |
-
0,
|
238 |
-
)
|
239 |
|
240 |
relation_answer = {"reasoning": "", "classification": 1}
|
241 |
question_answer = {"reasoning": "", "question": "", "options": []}
|
@@ -278,10 +266,12 @@ Ad 3: Weak or Paralyzed Muscles? - A Common Symptom of Cataplexy. About 70% of P
|
|
278 |
}
|
279 |
]
|
280 |
)
|
281 |
-
return
|
282 |
-
|
283 |
-
|
284 |
-
|
|
|
|
|
285 |
|
286 |
def convertDocumentsClustersToStringForApiCall(self, documents_clusters):
|
287 |
key_counter = count(1)
|
@@ -296,6 +286,31 @@ Ad 3: Weak or Paralyzed Muscles? - A Common Symptom of Cataplexy. About 70% of P
|
|
296 |
)
|
297 |
return res
|
298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
def changeDocumentsToPrintableString(self, documents_clusters):
|
300 |
res = ""
|
301 |
i = 0
|
@@ -323,113 +338,301 @@ Ad 3: Weak or Paralyzed Muscles? - A Common Symptom of Cataplexy. About 70% of P
|
|
323 |
curr_relation_prompt,
|
324 |
curr_question_prompt,
|
325 |
page_information,
|
326 |
-
|
327 |
-
question_answer,
|
328 |
):
|
329 |
print(
|
330 |
"**************************************************************************************************\n",
|
331 |
# curr_relation_prompt,
|
332 |
# curr_question_prompt,
|
333 |
-
page_information,
|
334 |
-
json.dumps(
|
335 |
-
json.dumps(question_answer, indent=4),
|
336 |
"\n************************************************************************************************\n\n",
|
337 |
)
|
338 |
|
339 |
-
def
|
340 |
-
|
341 |
-
):
|
342 |
-
curr_relation_prompt = self.bestRelationSystemPrompt
|
343 |
-
if RelationPrompt != None or len(RelationPrompt):
|
344 |
-
curr_relation_prompt = RelationPrompt
|
345 |
-
|
346 |
-
curr_question_prompt = self.bestQuestionSystemPrompt
|
347 |
-
if QuestionPrompt != None or len(QuestionPrompt):
|
348 |
-
curr_question_prompt = QuestionPrompt
|
349 |
-
|
350 |
-
documents_clusters, best_value = self.db.queryVectorDB(
|
351 |
-
page_information, threshold
|
352 |
-
)
|
353 |
-
relation_answer, question_answer = (
|
354 |
-
self.getBestQuestionOnTheBasisOfPageInformationAndAdsData(
|
355 |
-
page_information,
|
356 |
-
self.convertDocumentsClustersToStringForApiCall(documents_clusters),
|
357 |
-
curr_relation_prompt,
|
358 |
-
curr_question_prompt,
|
359 |
-
best_value,
|
360 |
-
)
|
361 |
-
)
|
362 |
self.logResult(
|
363 |
-
|
364 |
-
|
365 |
page_information,
|
366 |
-
|
367 |
-
question_answer,
|
368 |
)
|
369 |
|
370 |
docs_info = self.changeDocumentsToPrintableString(documents_clusters)
|
371 |
relation_answer_string = self.changeResponseToPrintableString(
|
372 |
-
relation_answer
|
373 |
)
|
374 |
question_answer_string = self.changeResponseToPrintableString(
|
375 |
-
question_answer
|
376 |
)
|
377 |
-
|
|
|
|
|
378 |
return full_response
|
379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
|
381 |
-
|
382 |
-
#
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
)
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
)
|
401 |
-
page_information
|
402 |
-
|
|
|
|
|
403 |
)
|
404 |
-
|
405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
)
|
407 |
-
output = gr.Textbox(label="Output")
|
408 |
-
submit_btn = gr.Button("Submit")
|
409 |
-
|
410 |
-
submit_btn.click(
|
411 |
-
rag.getRagResponse,
|
412 |
-
inputs=[RelationPrompt, QuestionPrompt, threshold, page_information],
|
413 |
-
outputs=[output],
|
414 |
-
)
|
415 |
-
page_information.submit(
|
416 |
-
rag.getRagResponse,
|
417 |
-
inputs=[RelationPrompt, QuestionPrompt, threshold, page_information],
|
418 |
-
outputs=[output],
|
419 |
-
)
|
420 |
-
with gr.Accordion("Ad Titles", open=False):
|
421 |
-
ad_titles = gr.Markdown()
|
422 |
-
|
423 |
-
demo.load(
|
424 |
-
lambda: "<br>".join(
|
425 |
-
random.sample(
|
426 |
-
[str(ad_title) for ad_title in ad_title_content],
|
427 |
-
min(100, len(ad_title_content)),
|
428 |
-
)
|
429 |
-
),
|
430 |
-
None,
|
431 |
-
ad_titles,
|
432 |
-
)
|
433 |
|
434 |
-
gr.close_all()
|
435 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from sklearn.cluster import KMeans, SpectralClustering
|
2 |
from scipy.spatial.distance import euclidean
|
3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
+
from langchain_community.vectorstores import FAISS
|
5 |
import re
|
6 |
import numpy as np
|
7 |
+
from openai import OpenAI
|
8 |
+
import json
|
9 |
from itertools import count
|
10 |
+
import time
|
11 |
+
import os
|
12 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
import pandas as pd
|
|
|
|
|
15 |
|
16 |
class CLUSTERING:
|
17 |
+
def cluster_embeddings(self, embeddings, clustering_algo, no_of_clusters, no_of_points):
|
18 |
+
if clustering_algo in {"kmeans-cc", "kmeans-sp"}:
|
19 |
+
kmeans = KMeans(n_clusters=min(no_of_clusters, len(embeddings)), random_state=42, n_init="auto")
|
|
|
|
|
|
|
|
|
|
|
20 |
kmeans.fit(embeddings)
|
21 |
cluster_centers = kmeans.cluster_centers_
|
22 |
labels = kmeans.labels_
|
23 |
|
24 |
+
if clustering_algo == "kmeans-cc":
|
25 |
clusters_indices = [[] for _ in range(no_of_clusters)]
|
26 |
for i, embedding in enumerate(embeddings):
|
27 |
cluster_idx = labels[i]
|
|
|
39 |
len(cluster) == no_of_points for cluster in clusters_indices
|
40 |
):
|
41 |
break
|
42 |
+
elif clustering_algo == "spectral":
|
43 |
spectral_clustering = SpectralClustering(
|
44 |
n_clusters=no_of_clusters, affinity="nearest_neighbors", random_state=42
|
45 |
)
|
|
|
55 |
[cluster_point[0] for cluster_point in clusters_indices[i][:no_of_points]]
|
56 |
for i in range(no_of_clusters)
|
57 |
]
|
58 |
+
|
|
|
59 |
class VECTOR_DB:
|
60 |
+
def __init__(self, default_threshold, number_of_ads_to_fetch_from_db, clustering_algo, no_of_clusters, no_of_ads_in_each_cluster, DB_FAISS_PATH, embeddings_hf):
|
61 |
+
self.default_threshold = default_threshold
|
62 |
+
self.number_of_ads_to_fetch_from_db = number_of_ads_to_fetch_from_db
|
63 |
+
self.clustering_algo = clustering_algo # ['kmeans-cc', 'kmeans-sp', 'spectral_clustering']
|
64 |
+
self.no_of_clusters = no_of_clusters
|
65 |
+
self.no_of_ads_in_each_cluster = no_of_ads_in_each_cluster
|
66 |
+
self.embeddings_hf = embeddings_hf
|
67 |
self.db = FAISS.load_local(
|
68 |
+
DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True
|
69 |
)
|
70 |
|
71 |
+
|
72 |
+
def queryVectorDB(self, page_information, threshold = None):
|
73 |
def remove_html_tags(text):
|
74 |
clean = re.compile("<.*?>")
|
75 |
return re.sub(clean, "", text)
|
76 |
|
77 |
+
if threshold == None:
|
78 |
+
threshold = self.default_threshold
|
79 |
+
|
80 |
retreived_documents = [
|
81 |
doc
|
82 |
for doc in self.db.similarity_search_with_score(
|
|
|
90 |
)
|
91 |
if len(retreived_documents):
|
92 |
embeddings = np.array(
|
93 |
+
self.embeddings_hf.embed_documents(
|
94 |
[doc[0].page_content for doc in retreived_documents]
|
95 |
)
|
96 |
)
|
97 |
|
98 |
clustered_indices = CLUSTERING().cluster_embeddings(
|
99 |
+
embeddings, self.clustering_algo, self.no_of_clusters, self.no_of_ads_in_each_cluster
|
100 |
)
|
101 |
documents_clusters = [
|
102 |
[retreived_documents[ind] for ind in cluster_indices]
|
|
|
107 |
return documents_clusters, best_value
|
108 |
return [], 1
|
109 |
|
110 |
+
class FAISS_DB:
|
|
|
111 |
def __init__(self):
|
112 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
+
def createDocs(self, content, metadata, CHUNK_SIZE = 2048, CHUNK_OVERLAP = 512):
|
115 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
116 |
+
chunk_size=CHUNK_SIZE,
|
117 |
+
chunk_overlap=CHUNK_OVERLAP,
|
118 |
+
)
|
119 |
+
split_docs = text_splitter.create_documents(content, metadata)
|
120 |
+
print(f"Documents are split into {len(split_docs)} passages")
|
121 |
+
return split_docs
|
122 |
+
|
123 |
+
def createDBFromDocs(self, split_docs, embeddings_model):
|
124 |
+
db = FAISS.from_documents(split_docs, embeddings_model)
|
125 |
+
return db
|
126 |
+
|
127 |
+
def createAndSaveDBInChunks(self, split_docs, embeddings_model, DB_FAISS_PATH, chunk_size = 1000):
|
128 |
+
one_db_docs_size = chunk_size
|
129 |
+
starting_i = 0
|
130 |
+
for i in range(starting_i, len(split_docs), one_db_docs_size):
|
131 |
+
ctime = time.time()
|
132 |
+
print(i, end = ', ')
|
133 |
+
|
134 |
+
db = FAISS.from_documents(split_docs[i:i+one_db_docs_size], embeddings_model)
|
135 |
+
self.saveDB(db, DB_FAISS_PATH, f"index_{int(i/one_db_docs_size)}")
|
136 |
+
|
137 |
+
ctime = time.time() - ctime
|
138 |
+
print("Time remaining", (ctime / one_db_docs_size * (len(split_docs) - i))/60, "minutes")
|
139 |
+
|
140 |
+
def mergeSecondDbIntoFirst(self, db1, db2):
|
141 |
+
db1.merge_from(db2)
|
142 |
+
|
143 |
+
def saveDB(self, db, DB_FAISS_PATH, index_name = "index"):
|
144 |
+
db.save_local(DB_FAISS_PATH, index_name)
|
145 |
+
|
146 |
+
def readingAndCombining(self, DB_FAISS_PATH, embeddings_hf, index_name_1, index_name_2):
|
147 |
+
db1 = FAISS.load_local(DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True, index_name = index_name_1)
|
148 |
+
db2 = FAISS.load_local(DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True, index_name = index_name_2)
|
149 |
+
db1.merge_from(db2)
|
150 |
+
|
151 |
+
# db1.index
|
152 |
+
# db1.docstore.search(target_id)
|
153 |
+
# len(db0.index_to_docstore_id)
|
154 |
+
|
155 |
+
return db1
|
156 |
+
|
157 |
+
def combineChunksDbs(self, DB_FAISS_PATH, embeddings_hf):
|
158 |
+
files = os.listdir(DB_FAISS_PATH)
|
159 |
+
ind = 0
|
160 |
+
for fl in files:
|
161 |
+
if fl[-6:] == '.faiss':
|
162 |
+
cv = fl[6:-6]
|
163 |
+
ind = max(ind, int(cv))
|
164 |
+
|
165 |
+
all_dbs = []
|
166 |
+
for i in range(0, ind+1, 2):
|
167 |
+
print(i)
|
168 |
+
db1 = FAISS.load_local(DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True, index_name = f'index_{i}')
|
169 |
+
all_dbs.append(db1)
|
170 |
+
if os.path.exists(DB_FAISS_PATH + f"/index_{i+1}.faiss"):
|
171 |
+
db2 = FAISS.load_local(DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True, index_name = f'index_{i+1}')
|
172 |
+
all_dbs.append(db2)
|
173 |
+
|
174 |
+
while(len(all_dbs) != 1):
|
175 |
+
processed_dbs = []
|
176 |
+
print("ITERATION ----------->")
|
177 |
+
for i in range(0, len(all_dbs), 2):
|
178 |
+
db1 = all_dbs[i]
|
179 |
+
print(f"For {i} before length is ", len(db1.index_to_docstore_id), end = ", ")
|
180 |
+
if i+1 != len(all_dbs):
|
181 |
+
db2 = all_dbs[i+1]
|
182 |
+
self.mergeSecondDbIntoFirst(db1, db2)
|
183 |
+
processed_dbs.append(db1)
|
184 |
+
print(f"After length is ", len(db1.index_to_docstore_id))
|
185 |
+
all_dbs = processed_dbs
|
186 |
+
|
187 |
+
return all_dbs[0]
|
188 |
|
189 |
+
class ADS_RAG:
|
190 |
+
    def __init__(self, db, qa_model_name, relation_check_best_value_thresh, bestRelationSystemPrompt, bestQuestionSystemPrompt):
        """Bundle everything the ads-RAG answer flow needs.

        db: vector-store wrapper (VECTOR_DB) used to retrieve ad clusters.
        qa_model_name: OpenAI chat model name used for the LLM calls.
        relation_check_best_value_thresh: threshold compared against the best
            retrieval score — presumably to skip the relation check when the
            match is strong enough; TODO confirm against the caller.
        bestRelationSystemPrompt: default system prompt for the ad/page
            relation-classification call.
        bestQuestionSystemPrompt: default system prompt for the
            question/options generation call.
        """
        # NOTE(review): OpenAI() presumably reads OPENAI_API_KEY from the
        # environment (.env is loaded elsewhere via load_dotenv) — confirm.
        self.client = OpenAI()
        self.db = db
        self.qa_model_name = qa_model_name
        self.relation_check_best_value_thresh = relation_check_best_value_thresh
        self.bestRelationSystemPrompt = bestRelationSystemPrompt
        self.bestQuestionSystemPrompt = bestQuestionSystemPrompt
198 |
def callOpenAiApi(self, messages):
|
199 |
while True:
|
|
|
223 |
bestRetreivedAdValue,
|
224 |
):
|
225 |
if adsData == "":
|
226 |
+
return ({"reasoning": "No ads data present", "classification": 0}, 0), ({"reasoning": "", "question": "", "options": []}, 0)
|
|
|
|
|
|
|
227 |
|
228 |
relation_answer = {"reasoning": "", "classification": 1}
|
229 |
question_answer = {"reasoning": "", "question": "", "options": []}
|
|
|
266 |
}
|
267 |
]
|
268 |
)
|
269 |
+
return {
|
270 |
+
"relation_answer": relation_answer,
|
271 |
+
"tokens_used_relation": tokens_used_relation,
|
272 |
+
"question_answer": question_answer,
|
273 |
+
"tokens_used_question": tokens_used_question
|
274 |
+
}
|
275 |
|
276 |
def convertDocumentsClustersToStringForApiCall(self, documents_clusters):
|
277 |
key_counter = count(1)
|
|
|
286 |
)
|
287 |
return res
|
288 |
|
289 |
+
def getRagResponse(
|
290 |
+
self, page_information, threshold = None, RelationPrompt = None, QuestionPrompt = None
|
291 |
+
):
|
292 |
+
curr_relation_prompt = self.bestRelationSystemPrompt
|
293 |
+
if RelationPrompt != None and len(RelationPrompt):
|
294 |
+
curr_relation_prompt = RelationPrompt
|
295 |
+
|
296 |
+
curr_question_prompt = self.bestQuestionSystemPrompt
|
297 |
+
if QuestionPrompt != None and len(QuestionPrompt):
|
298 |
+
curr_question_prompt = QuestionPrompt
|
299 |
+
|
300 |
+
documents_clusters, best_value = self.db.queryVectorDB(
|
301 |
+
page_information, threshold
|
302 |
+
)
|
303 |
+
answer = self.getBestQuestionOnTheBasisOfPageInformationAndAdsData(
|
304 |
+
page_information,
|
305 |
+
self.convertDocumentsClustersToStringForApiCall(documents_clusters),
|
306 |
+
curr_relation_prompt,
|
307 |
+
curr_question_prompt,
|
308 |
+
best_value,
|
309 |
+
)
|
310 |
+
|
311 |
+
|
312 |
+
return answer, documents_clusters
|
313 |
+
|
314 |
def changeDocumentsToPrintableString(self, documents_clusters):
|
315 |
res = ""
|
316 |
i = 0
|
|
|
338 |
curr_relation_prompt,
|
339 |
curr_question_prompt,
|
340 |
page_information,
|
341 |
+
answer
|
|
|
342 |
):
|
343 |
print(
|
344 |
"**************************************************************************************************\n",
|
345 |
# curr_relation_prompt,
|
346 |
# curr_question_prompt,
|
347 |
+
page_information + "\n",
|
348 |
+
json.dumps(answer,indent=4),
|
|
|
349 |
"\n************************************************************************************************\n\n",
|
350 |
)
|
351 |
|
352 |
+
def getRagGradioResponse(self, page_information, RelationPrompt, QuestionPrompt, threshold):
|
353 |
+
answer, documents_clusters = self.getRagResponse(page_information, threshold, RelationPrompt, QuestionPrompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
self.logResult(
|
355 |
+
RelationPrompt,
|
356 |
+
QuestionPrompt,
|
357 |
page_information,
|
358 |
+
answer
|
|
|
359 |
)
|
360 |
|
361 |
docs_info = self.changeDocumentsToPrintableString(documents_clusters)
|
362 |
relation_answer_string = self.changeResponseToPrintableString(
|
363 |
+
answer["relation_answer"], "relation"
|
364 |
)
|
365 |
question_answer_string = self.changeResponseToPrintableString(
|
366 |
+
answer["question_answer"], "question"
|
367 |
)
|
368 |
+
question_tokens = answer["tokens_used_question"]
|
369 |
+
relation_tokens = answer["tokens_used_relation"]
|
370 |
+
full_response = f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n Question answer:\n {question_answer_string}\n\n**RETREIVED DOCUMENTS CLUSTERS**:\n{docs_info}\n\n**TOKENS USED**:\nQuestion api call: {question_tokens}\nRelation api call: {relation_tokens}"
|
371 |
return full_response
|
372 |
|
373 |
+
class VARIABLE_MANAGER:
    """Central place for configuration, prompt templates and object wiring.

    Builds the embedding model / vector DB / ADS_RAG stack and exposes the
    two system prompts plus the data-preprocessing helpers used by the
    Gradio demo and the sample-generation notebook.
    """

    def __init__(self):
        # Load API keys etc. from .env; override any already-set variables.
        load_dotenv(override=True)
        # Silence the HuggingFace tokenizers fork warning.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # Embedding model used for both index build and query time — must
        # match the model the saved FAISS index was built with.
        self.embedding_model_hf = "BAAI/bge-m3"
        # embedding_model_hf = "sentence-transformers/all-mpnet-base-v2"
        self.DB_FAISS_PATH = "./vectorstore/db_faiss_ads_20May_20Jun_webmd_healthline_Health_dupRemoved0.8"

    def getRag(self):
        """Assemble and return a ready-to-use ADS_RAG instance."""
        # embeddings_oa = OpenAIEmbeddings(model=embedding_model_oa)
        # embeddings_hf = HuggingFaceEmbeddings(model_name = embedding_model_hf, show_progress = True)
        embeddings_hf = HuggingFaceEmbeddings(model_name=self.embedding_model_hf)
        # VECTOR_DB positional args: threshold=0.75, ..., "kmeans-cc"
        # clustering, 3/6 cluster bounds — see VECTOR_DB.__init__ for the
        # exact parameter meanings (TODO confirm; signature not visible here).
        vector_db = VECTOR_DB(0.75, 50, "kmeans-cc", 3, 6, self.DB_FAISS_PATH, embeddings_hf)
        rag = ADS_RAG(vector_db, "gpt-3.5-turbo", 0.6, self.getRelationSystemPrompt(), self.getQuestionSystemPrompt())
        return rag

    def QnAAdsSampleGenerationPreProcessing(self):
        """Load the scraped-URL TSV and add a trimmed 'core_content' column.

        Returns the cleaned DataFrame; rows with any NaN are dropped.
        """
        data_file_path = "./data/144_webmd_healthline_12Jun-18Jun_top1000each_urlsContent.tsv"
        data = pd.read_csv(data_file_path, sep='\t')
        data.dropna(axis=0, how='any', inplace=True)
        # data.drop_duplicates(subset = ['ad_title', 'ad_desc'], inplace=True)
        # ad_title_content = list(data["ad_title"].values)
        def get_core_content(row):
            # Keep only the title plus the first 7 sentences of page content
            # to stay within prompt budget.
            url_content = row['url_content']
            url_title = row['url_title']
            return 'Page Title -: ' + url_title + '\nPage Content -: ' + '. '.join(url_content.split('. ')[:7])

        data['core_content'] = data.apply(get_core_content, axis = 1)
        # for i in range(len(data)):
        #     print(data.loc[i, 'url'])
        #     print(data.loc[i, 'url_content'])
        #     print(data.loc[i, 'core_content'])
        #     print()
        #     if(i > 10):
        #         break
        return data

    def GradioRagPreProcessing(self):
        """Return the de-duplicated list of ad titles shown in the demo UI."""
        data_file_path = "./data/142_adclick_20May_20Jun_webmd_healthline_Health_dupRemoved0.8_someAdsCampaign.tsv"
        data = pd.read_csv(data_file_path, sep="\t")
        # data.dropna(axis=0, how="any", inplace=True)
        data.drop_duplicates(subset=["ad_title", "ad_desc"], inplace=True)
        ad_title_content = list(data["ad_title"].values)
        return ad_title_content

    def getQuestionSystemPrompt(self):
        """Return the system prompt for the question/options generation call.

        The prompt text is runtime data sent to the LLM verbatim — do not
        reformat it.
        """
        bestQuestionSystemPrompt = """1. You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to form a relevant QUESTION to ask the user visiting the webpage. This question should help identify the user's intent behind visiting the webpage and should be highly attractive.
2. Now form a highly attractive/lucrative and diverse/mutually exclusive OPTION which should be both the answer for the QUESTION and related to ads in this cluster.
3. Try to generate intelligent creatives for advertising and keep QUESTION within 70 characters and either 2, 3 or 4 options with each OPTION within 4 to 6 words.
4. Provide your REASONING behind choosing the QUESTION and the OPTIONS. Now provide the QUESTION and the OPTIONS. Along with each OPTION, provide the ads from ADS_DATA that you associated with it.

---------------------------------------

<Sample INPUT>
The Effects of Aging on Skin

<Sample ADS_DATA>
{"Cluster 1 Ads": {"Ad 1": "Forget Retinol, Use This Household Item To Fill In Wrinkles - Celebrities Are Ditching Pricey Facelifts For This."}, "Cluster 2 Ads": {"Ad 2": "Stop Covering Your Wrinkles with Make Up - Do This Instead.", "Ad 3": "Living With Migraines? - Discover A Treatment Option. Learn about a type of prescription migraine treatment called CGRP receptor antagonists. Discover a range of resources that may help people dealing with migraines"}, "Cluster 3 Ads": {"Ad 4": "What is Advanced Skin Cancer? - Find Disease Information Here.Find Facts About Advanced Skin Cancer and a Potential Treatment Option.", "Ad 5": "Learn About Advanced Melanoma - Find Disease Information Here.Find Facts About Advanced Melanoma and a Potential Treatment Option.", "Ad 6": "Treatment For CKD - Reduce Risk Of Progressing CKD. Ask About A Treatment That Can Help Reduce Your Risk Of Kidney Failure", "Ad 7": "Are You Living With Vitiligo? - For Patients & Caregivers.Discover An FDA-Approved Topical Cream That May Help With Nonsegmental Vitiligo Repigmentation. Learn About A Copay Savings Card For Eligible Patients With Vitiligo."}]

<Expected json output>
{
"reasoning" : "Among the seven ads in **Sample ADS_DATA**, Ads 3 and 6 are irrelevant to the INPUT, so they should be discarded. Ad 1, 2 closely aligns with the user's intent. Ads 4, 5, and 7 are also relevant to INPUT. The question will be formed in a way to connect the PAGE content with the goals of these five relevant ads, making sure they appeal to both specific and general user interests, with the OPTIONS being the answer for QUESTION(it is ensured that no irrelevant options are formed)",
"question": "Interested in methods to combat aging skin?",
"options": {"1. Retinol Alternatives for Wrinkle Treatment." : ["Ad 1: Forget Retinol, Use This Household Item To Fill In Wrinkles - Celebrities Are Ditching Pricey Facelifts For This."], "2. Reduce Wrinkles without Makeup.": ["Ad 2: Stop Covering Your Wrinkles with Make Up - Do This Instead."], "3. Information on Skin Diseases": ["Ad 3: What is Advanced Skin Cancer? - Find Disease Information Here.Find Facts About Advanced Skin Cancer and a Potential Treatment Option.", "Ad 4: Learn About Advanced Melanoma - Find Disease Information Here.Find Facts About Advanced Melanoma and a Potential Treatment Option.", "Ad 5: Are You Living With Vitiligo? - For Patients & Caregivers.Discover An FDA-Approved Topical Cream That May Help With Nonsegmental Vitiligo Repigmentation. Learn About A Copay Savings Card For Eligible Patients With Vitiligo."]}
}
-----------------------------------------------

<Sample INPUT>
Got A Rosemary Bush? Here’re 20 Brilliant & Unusual Ways To Use All That Rosemary

<Sample ADS_DATA>
<empty>

<Expected json output>
{
"reasoning" : "No ads available",
"question": "",
"options": []
}
-----------------------------------------------

The ADS_DATA provided to you is as follows:
"""
        # old_system_prompt_additional_example = """
        # -----------------------------------------------
        # <Sample INPUT>
        # 7 Signs and Symptoms of Magnesium Deficiency

        # <Sample ADS_DATA>
        # Ad 1: 4 Warning Signs Of Dementia - Fight Dementia and Memory Loss. 100% Natural Program To Prevent Cognitive Decline. Developed By Dr. Will Mitchell. Read The Reviews-Get a Special Offer. Doctor Recommended. High Quality Standards. 60-Day Refund.
        # Ad 2: About Hyperkalemia - Learn About The Symptoms. High Potassium Can Be A Serious Condition. Learn More About Hyperkalemia Today.
        # Ad 3: Weak or Paralyzed Muscles? - A Common Symptom of Cataplexy. About 70% of People With Narcolepsy Are Believed to Have Cataplexy Symptoms. Learn More. Download the Doctor Discussion Guide to Have a Informed Conversation About Your Health.

        # <Expected json output>
        # {
        # "reasoning" : "Given the input '7 Signs and Symptoms of Magnesium Deficiency,' it is evident that the user is looking for information specifically about magnesium deficiency. Ads 1, 2, and 3 discuss topics such as dementia, hyperkalemia, weak muscles, which are not related to magnesium deficiency in any way. Therefore, all the ads in the ADS_DATA are not suitable for the user's query and will be discarded.",
        # "question": "No related ads available to form question and options.",
        # "options": []
        # }
        # ------------------------------------------------
        # """
        return bestQuestionSystemPrompt

    def getRelationSystemPrompt(self):
        """Return the system prompt for the page/ads relation-classification call.

        The prompt text is runtime data sent to the LLM verbatim — do not
        reformat it.
        """
        bestRelationSystemPrompt = """You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to determine whether there are some relevant ADS to INPUT are present in ADS_DATA. ADS WHICH DON'T MATCH USER'S INTENT SHOULD BE CONSIDERED IRRELEVANT

---------------------------------------

**Sample INPUT***: What Causes Bright-Yellow Urine and Other Changes in Color?

Expected json output :
{
"reasoning" : "Given the user's search for 'What Causes Bright-Yellow Urine and Other Changes in Color?', it is evident that they are looking for information related to the causes and implications of changes in urine color. Therefore, ads that are related to urine color. However, none of the provided ads are related to urine. Ads 1, 2, and 3 are focused on chronic lymphocytic leukemia (CLL) treatment, Ad 4 addresses high blood pressure, and Ad 5 is about migraine treatment. Since none of these ads are not at all relevant to the urine, they can all be considered irrelevant to the user's intent.",
"classification": 0
}
------------------------------------------------

**Sample INPUT**: The Effects of Aging on Skin

Expected json output :
{
"reasoning" : "Given the user's search for 'The Effects of Aging on Skin,' it is clear that they are seeking information related to skin aging. Therefore, the ads that are relevant to skin effects should be considered. Ads 1 and 2 focus on wrinkle treatment and anti-aging solutions, making them pertinent to the user's intent. Ad 3 targets vitiligo and not general skin aging but it is related to skin effect. So it is also relevant. Ads 4 and 5 are about advanced lung cancer, which do not address the interest in skin. Ads 1 and 2, 3 are most relevant to the user's search. So ADS_DATA is relevant to INPUT. ",
"classification": 1
}
---------------------------------------

The ADS_DATA provided to you is as follows:
"""
        return bestRelationSystemPrompt
|
504 |
+
|
505 |
+
# *********************** DB GENERATION ******************************
|
506 |
+
# df = pd.read_csv(data_file_path, sep="\t")
|
507 |
+
|
508 |
+
# --------------------------------
|
509 |
+
# WEB DATA PROCESSING
|
510 |
+
# from urllib.parse import urlparse
|
511 |
+
# import re
|
512 |
+
# def get_cleaned_url(url):
|
513 |
+
# path = urlparse(url).path.strip()
|
514 |
+
# cleaned_path = re.sub(r'[^a-zA-Z0-9\s-]', ' ', path).replace('/', '')
|
515 |
+
# cleaned_path = re.sub(r'[^a-zA-Z0-9\s]', ' ', path).replace('-', '')
|
516 |
+
# return cleaned_path.strip()
|
517 |
+
|
518 |
+
# df['cleaned_url'] = df['url'].map(get_cleaned_url)
|
519 |
+
# df.dropna(subset=['cleaned_url', 'url_content', 'url_title'], inplace=True)
|
520 |
+
# df['combined'] = df['cleaned_url'] + ". " + df['url_title'] + ". " + df['url_content']
|
521 |
+
# content = df["combined"].tolist()
|
522 |
+
# metadata = [
|
523 |
+
# {"title": row["url_title"], "url": row["url"]}
|
524 |
+
# for _, row in df.iterrows()
|
525 |
+
# ]
|
526 |
+
# ------------------------------
|
527 |
+
# ADS DATA PROCESSING
|
528 |
+
# # df.dropna(axis=0, how='any', inplace=True)
|
529 |
+
# df.drop_duplicates(subset = ['ad_title', 'ad_desc'], inplace=True)
|
530 |
+
# dfRPC = df[df['RPC'] > 0]
|
531 |
+
# dfRPC.dropna(how = 'any', inplace=True)
|
532 |
+
# dfCampaign = df[df['type'] == 'campaign']
|
533 |
+
# dfCampaign.fillna('', inplace=True)
|
534 |
+
# df = pd.concat([dfRPC, dfCampaign])
|
535 |
+
# df
|
536 |
+
|
537 |
+
# content = (df["ad_title"] + ". " + df["ad_desc"]).tolist()
|
538 |
+
# metadata = [
|
539 |
+
# {"publisher_url": row["publisher_url"], "keyword_term": row["keyword_term"], "ad_display_url": row["ad_display_url"], "revenue": row["revenue"], "ad_click_count": row["ad_click_count"], "RPC": row["RPC"], "Type": row["type"]}
|
540 |
+
# # {"revenue": row["revenue"], "ad_click_count": row["ad_click_count"]}
|
541 |
+
# for _, row in df.iterrows()
|
542 |
+
# ]
|
543 |
+
# --------------------------------
|
544 |
+
|
545 |
+
# faiss_db = FAISS_DB()
|
546 |
+
# db = faiss_db.createDBFromDocs(content, metadata)
|
547 |
+
# faiss_db.saveDB(db, '.')
|
548 |
+
|
549 |
+
# ************************************************************************
|
550 |
+
# PARALLELY CREATING DB - BACKUP FOR FUTURE USE
|
551 |
+
# import time
|
552 |
+
# import threading
|
553 |
+
# import os
|
554 |
+
# one_db_docs_size = 1000
|
555 |
+
# starting_i = 0
|
556 |
+
# parallel_processes = 3
|
557 |
+
# def split_list(lst, n):
|
558 |
+
# k, m = divmod(len(lst), n)
|
559 |
+
# return (lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))
|
560 |
+
# def createDBForIndexes(inds):
|
561 |
+
# for i in inds:
|
562 |
+
# ctime = time.time()
|
563 |
+
# print(f"Processing {i}")
|
564 |
+
# if not os.path.exists(DB_FAISS_PATH + "/index_{int(i/one_db_docs_size)}.faiss"):
|
565 |
+
# db = FAISS.from_documents(split_docs[i:i+one_db_docs_size], embeddings_hf)
|
566 |
+
# db.save_local(DB_FAISS_PATH, index_name = f"index_{int(i/one_db_docs_size)}")
|
567 |
+
# ctime = time.time() - ctime
|
568 |
+
# print(f"{i})Time taken", ctime)
|
569 |
+
# indexes = split_list(range(starting_i, len(split_docs), one_db_docs_size), parallel_processes)
|
570 |
+
# threads = []
|
571 |
+
# for i, one_process_indexes in enumerate(indexes):
|
572 |
+
# thread = threading.Thread(target=createDBForIndexes, args=(one_process_indexes,))
|
573 |
+
# thread.start()
|
574 |
+
# threads.append(thread)
|
575 |
+
# for thread in threads:
|
576 |
+
# thread.join()
|
577 |
+
# print("All threads completed.")
|
578 |
+
# ************************************************************************
|
579 |
+
|
580 |
+
if __name__ == "__main__":
    # Script entry point: stand up an interactive Gradio demo around the
    # ads-RAG pipeline.
    import pandas as pd
    import gradio as gr
    import random

    vm = VARIABLE_MANAGER()
    rag = vm.getRag()
    # Ad titles displayed in the collapsible accordion for manual inspection.
    ad_title_content = vm.GradioRagPreProcessing()

    with gr.Blocks() as demo:
        gr.Markdown("# RAG on ads data")
        with gr.Row():
            # Both system prompts are editable from the UI so prompt variants
            # can be tried without code changes.
            RelationPrompt = gr.Textbox(
                vm.getRelationSystemPrompt(),
                lines=1,
                placeholder="Enter the relation system prompt for relation check",
                label="Relation System prompt",
            )
            QuestionPrompt = gr.Textbox(
                vm.getQuestionSystemPrompt(),
                lines=1,
                placeholder="Enter the question system prompt for question formulation",
                label="Question System prompt",
            )
        page_information = gr.Textbox(
            lines=1, placeholder="Enter the page information", label="Page Information"
        )
        # Retrieval threshold, pre-filled with the vector DB's default.
        threshold = gr.Number(
            value=rag.db.default_threshold, label="Threshold", interactive=True
        )
        output = gr.Textbox(label="Output")
        submit_btn = gr.Button("Submit")

        # Both the button click and pressing Enter in the page-information
        # box trigger the same RAG call.
        submit_btn.click(
            rag.getRagGradioResponse,
            inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
            outputs=[output],
        )
        page_information.submit(
            rag.getRagGradioResponse,
            inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
            outputs=[output],
        )
        with gr.Accordion("Ad Titles", open=False):
            ad_titles = gr.Markdown()

        # On page load, show a random sample of up to 100 ad titles.
        demo.load(
            lambda: "<br>".join(
                random.sample(
                    [str(ad_title) for ad_title in ad_title_content],
                    min(100, len(ad_title_content)),
                )
            ),
            None,
            ad_titles,
        )

    gr.close_all()
    demo.launch()