# Ads_Rag / rag.py
# Author: Rajat.bans
# Commit afcfa9b: "Updated the code no bugs now"
# (Hugging Face file-viewer residue: raw | history | blame | contribute | delete | No virus | 39.1 kB)
from sklearn.cluster import KMeans, SpectralClustering
from scipy.spatial.distance import euclidean
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
import re
import numpy as np
from openai import OpenAI
import json
from itertools import count
import time
import os
from langchain_community.embeddings import HuggingFaceEmbeddings
from dotenv import load_dotenv
from typing import List, Tuple, Dict, Any
from numpy import ndarray
from langchain_core.documents import Document
import gradio as gr
import random
import pandas as pd
class CLUSTERING:
    """Cluster embedding vectors and pick a bounded set of point indices per cluster."""

    # Algorithm names accepted by cluster_embeddings.
    _SUPPORTED_ALGOS = ("kmeans-cc", "kmeans-sp", "spectral")

    @staticmethod
    def _first_points_per_label(
        labels, no_of_clusters: int, no_of_points: int
    ) -> List[List[int]]:
        """
        Collect, for every cluster label, the first ``no_of_points`` point indices in input order.
        Parameters:
        labels: Sequence of integer cluster labels, one per point, each in ``range(no_of_clusters)``.
        no_of_clusters (int): Number of output lists to produce.
        no_of_points (int): Maximum number of indices kept per cluster.
        Returns:
        List[List[int]]: ``no_of_clusters`` lists of point indices; labels that never occur give empty lists.
        """
        clusters_indices: List[List[int]] = [[] for _ in range(no_of_clusters)]
        for i, label in enumerate(labels):
            if len(clusters_indices[label]) < no_of_points:
                clusters_indices[label].append(i)
        return clusters_indices

    @staticmethod
    def _closest_points_per_label(
        embeddings: ndarray,
        labels,
        cluster_centers: ndarray,
        no_of_clusters: int,
        no_of_points: int,
    ) -> List[List[int]]:
        """
        Collect, for every cluster label, the ``no_of_points`` point indices closest to that cluster's center.
        Parameters:
        embeddings (ndarray): One row per point.
        labels: Integer cluster label of each point, each in ``range(no_of_clusters)``.
        cluster_centers (ndarray): Row ``k`` is the center of cluster ``k``.
        no_of_clusters (int): Number of output lists to produce.
        no_of_points (int): Maximum number of indices kept per cluster.
        Returns:
        List[List[int]]: ``no_of_clusters`` lists of point indices, each ordered by increasing
            Euclidean distance to its center; labels that never occur give empty lists.
        """
        scored: List[List[Tuple[int, float]]] = [[] for _ in range(no_of_clusters)]
        for i, (embedding, label) in enumerate(zip(embeddings, labels)):
            scored[label].append((i, euclidean(embedding, cluster_centers[label])))
        # sorted() is stable, so equally distant points keep their input order.
        return [
            [index for index, _ in sorted(cluster, key=lambda point: point[1])[:no_of_points]]
            for cluster in scored
        ]

    def cluster_embeddings(
        self,
        embeddings: ndarray,
        clustering_algo: str,
        no_of_clusters: int,
        no_of_points: int,
    ) -> List[List[int]]:
        """
        Clusters embeddings using the specified clustering algorithm and returns the indices of points in each cluster.
        Parameters:
        embeddings (ndarray): The input embeddings to cluster, one row per point.
        clustering_algo (str): One of
            "kmeans-cc"  - KMeans; keep the points closest to each cluster center,
            "kmeans-sp"  - KMeans; keep the first points of each cluster in input order,
            "spectral"   - spectral clustering; keep the first points of each cluster in input order.
        no_of_clusters (int): The number of clusters to form (capped at the number of points).
        no_of_points (int): The maximum number of points to include in each cluster.
        Returns:
        List[List[int]]: Exactly ``no_of_clusters`` lists of point indices. A list is empty when
            that cluster could not be formed (e.g. fewer points than clusters).
        Raises:
        ValueError: If ``clustering_algo`` is not supported or ``no_of_clusters`` < 1.
        """
        if clustering_algo not in self._SUPPORTED_ALGOS:
            raise ValueError(
                f"Unknown clustering_algo {clustering_algo!r}; expected one of {self._SUPPORTED_ALGOS}"
            )
        if no_of_clusters < 1:
            raise ValueError(f"no_of_clusters must be >= 1, got {no_of_clusters}")

        n_points = len(embeddings)
        if n_points == 0:
            return [[] for _ in range(no_of_clusters)]
        # The clustering estimators reject more clusters than samples; any extra requested
        # clusters simply stay empty in the result.
        n_fit_clusters = min(no_of_clusters, n_points)

        if clustering_algo == "spectral":
            spectral_clustering = SpectralClustering(
                n_clusters=n_fit_clusters,
                affinity="nearest_neighbors",
                # scikit-learn's default is 10 neighbours, which errors when there are fewer
                # points than that; for >= 10 points this is identical to the default.
                n_neighbors=min(10, n_points),
                random_state=42,
            )
            labels = spectral_clustering.fit_predict(embeddings)
            return self._first_points_per_label(labels, no_of_clusters, no_of_points)

        kmeans = KMeans(n_clusters=n_fit_clusters, random_state=42, n_init="auto")
        kmeans.fit(embeddings)
        if clustering_algo == "kmeans-cc":
            return self._closest_points_per_label(
                embeddings,
                kmeans.labels_,
                kmeans.cluster_centers_,
                no_of_clusters,
                no_of_points,
            )
        return self._first_points_per_label(kmeans.labels_, no_of_clusters, no_of_points)
class VECTOR_DB:
    """Retrieve ads from a local FAISS store and group them into clusters."""

    # Removes anything between "<" and the next ">" on the same line (no re.DOTALL);
    # compiled once instead of on every query.
    _HTML_TAG_RE = re.compile("<.*?>")

    def __init__(
        self,
        default_threshold: float,
        number_of_ads_to_fetch_from_db: int,
        clustering_algo: str,
        no_of_clusters: int,
        no_of_ads_in_each_cluster: int,
        DB_FAISS_PATH: str,
        embeddings_hf: HuggingFaceEmbeddings,
    ) -> None:
        """
        Initialize the VECTOR_DB with the specified parameters and load the FAISS database.
        Parameters:
        default_threshold (float): Score cutoff used when ``queryVectorDB`` receives no threshold.
            Only documents whose score is strictly below it are kept, so lower scores are closer matches.
        number_of_ads_to_fetch_from_db (int): How many nearest documents (``k``) to request per query.
        clustering_algo (str): The clustering algorithm to use ('kmeans-cc', 'kmeans-sp', or 'spectral').
        no_of_clusters (int): The number of clusters to form.
        no_of_ads_in_each_cluster (int): The maximum number of ads kept in each cluster.
        DB_FAISS_PATH (str): Directory containing the saved FAISS database.
        embeddings_hf (HuggingFaceEmbeddings): Embeddings model used to load the store and to
            re-embed retrieved ads for clustering.
        """
        self.default_threshold = default_threshold
        self.number_of_ads_to_fetch_from_db = number_of_ads_to_fetch_from_db
        self.clustering_algo = clustering_algo
        self.no_of_clusters = no_of_clusters
        self.no_of_ads_in_each_cluster = no_of_ads_in_each_cluster
        self.embeddings_hf = embeddings_hf
        # NOTE: allow_dangerous_deserialization lets load_local unpickle the stored docstore;
        # only point DB_FAISS_PATH at stores you built yourself.
        self.db = FAISS.load_local(
            DB_FAISS_PATH, embeddings_hf, allow_dangerous_deserialization=True
        )

    def queryVectorDB(
        self, page_information: str, threshold: float = None
    ) -> Tuple[List[List[Tuple[Document, float]]], float]:
        """
        Query the vector database and cluster the retrieved documents.
        Parameters:
        page_information (str): The page text used as the similarity-search query.
        threshold (float | None): Score cutoff; documents with a score strictly below it are kept.
            If None, ``default_threshold`` is used.
        Returns:
        Tuple[List[List[Tuple[Document, float]]], float]: The non-empty clusters of
            (document, score) pairs, and the lowest (best) score among the kept documents.
            When nothing passes the threshold, ``([], 1.0)`` is returned.
        """
        if threshold is None:
            threshold = self.default_threshold

        hits = self.db.similarity_search_with_score(
            page_information, k=self.number_of_ads_to_fetch_from_db
        )
        retrieved_documents = [hit for hit in hits if hit[1] < threshold]
        if not retrieved_documents:
            return [], 1.0

        # NOTE: this rewrites page_content on the Document objects returned by the store
        # (possibly the docstore's own objects); stripping tags twice is harmless.
        for document, _ in retrieved_documents:
            document.page_content = self._HTML_TAG_RE.sub("", document.page_content)

        embeddings = np.array(
            self.embeddings_hf.embed_documents(
                [document.page_content for document, _ in retrieved_documents]
            )
        )
        clustered_indices = CLUSTERING().cluster_embeddings(
            embeddings,
            self.clustering_algo,
            self.no_of_clusters,
            self.no_of_ads_in_each_cluster,
        )
        # Clusters come back empty when there were fewer hits than clusters; they would only
        # produce empty "Option" groups downstream, so drop them.
        documents_clusters = [
            [retrieved_documents[ind] for ind in cluster_indices]
            for cluster_indices in clustered_indices
            if cluster_indices
        ]
        # Lower is better (see the `< threshold` filter); min() does not rely on result order.
        best_value = min(score for _, score in retrieved_documents)
        return documents_clusters, best_value
class FAISS_DB:
    """Helpers for building, saving, loading and merging LangChain FAISS stores."""

    def __init__(self) -> None:
        """Initialize the FAISS_DB class; it keeps no state."""

    def createDocs(
        self,
        content: List[str],
        metadata: List[Dict[str, Any]],
        CHUNK_SIZE: int = 2048,
        CHUNK_OVERLAP: int = 512,
    ) -> List[Document]:
        """
        Split the provided content into chunks with metadata.
        Parameters:
        content (List[str]): The texts to split.
        metadata (List[Dict[str, Any]]): Per-text metadata, passed to ``create_documents`` alongside ``content``.
        CHUNK_SIZE (int): The maximum size of each chunk. Default is 2048.
        CHUNK_OVERLAP (int): The overlap between consecutive chunks. Default is 512.
        Returns:
        List[Document]: The split documents with metadata.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
        )
        split_docs = text_splitter.create_documents(content, metadata)
        print(f"Documents are split into {len(split_docs)} passages")
        return split_docs

    def createDBFromDocs(
        self, split_docs: List[Document], embeddings_model: HuggingFaceEmbeddings
    ) -> FAISS:
        """
        Create a FAISS database from the provided documents and embeddings model.
        Parameters:
        split_docs (List[Document]): The split documents.
        embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
        Returns:
        FAISS: The created FAISS database.
        """
        return FAISS.from_documents(split_docs, embeddings_model)

    def createAndSaveDBInChunks(
        self,
        split_docs: List[Document],
        embeddings_model: HuggingFaceEmbeddings,
        DB_FAISS_PATH: str,
        chunk_size: int = 1000,
    ) -> None:
        """
        Embed the documents ``chunk_size`` at a time and save each batch as its own index.
        Batch ``n`` is saved under index name ``index_{n}`` in DB_FAISS_PATH (the naming that
        ``combineChunksDbs`` reads back). The start offset of each batch and an estimated time
        remaining are printed.
        Parameters:
        split_docs (List[Document]): The split documents.
        embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
        DB_FAISS_PATH (str): The directory to save the FAISS indexes in.
        chunk_size (int): Number of documents per saved index. Default is 1000.
        Raises:
        ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        total_docs = len(split_docs)
        for start in range(0, total_docs, chunk_size):
            chunk_started_at = time.time()
            print(start, end=", ")
            chunk = split_docs[start : start + chunk_size]
            db = FAISS.from_documents(chunk, embeddings_model)
            self.saveDB(db, DB_FAISS_PATH, f"index_{start // chunk_size}")
            seconds_per_doc = (time.time() - chunk_started_at) / len(chunk)
            # Estimate from the documents still left AFTER this batch; the last batch may be short.
            remaining_docs = total_docs - (start + len(chunk))
            print(
                "Time remaining",
                seconds_per_doc * remaining_docs / 60,
                "minutes",
            )

    def mergeSecondDbIntoFirst(self, db1: FAISS, db2: FAISS) -> None:
        """
        Merge the second FAISS database into the first one, in place.
        Parameters:
        db1 (FAISS): The database that receives the vectors/documents (mutated).
        db2 (FAISS): The database whose contents are merged into ``db1``.
        """
        db1.merge_from(db2)

    def saveDB(self, db: FAISS, DB_FAISS_PATH: str, index_name: str = "index") -> None:
        """
        Save the FAISS database locally.
        Parameters:
        db (FAISS): The FAISS database to save.
        DB_FAISS_PATH (str): The directory to save the FAISS database in.
        index_name (str): The name of the index. Default is "index".
        """
        db.save_local(DB_FAISS_PATH, index_name)

    @staticmethod
    def _loadIndex(
        DB_FAISS_PATH: str, embeddings_hf: HuggingFaceEmbeddings, index_name: str
    ) -> FAISS:
        """Load one saved index from DB_FAISS_PATH (unpickles its docstore: trusted paths only)."""
        return FAISS.load_local(
            DB_FAISS_PATH,
            embeddings_hf,
            allow_dangerous_deserialization=True,
            index_name=index_name,
        )

    def readingAndCombining(
        self,
        DB_FAISS_PATH: str,
        embeddings_hf: HuggingFaceEmbeddings,
        index_name_1: str,
        index_name_2: str,
    ) -> FAISS:
        """
        Load two FAISS databases and merge the second into the first.
        Parameters:
        DB_FAISS_PATH (str): The directory holding both indexes.
        embeddings_hf (HuggingFaceEmbeddings): The embeddings model to use.
        index_name_1 (str): The name of the first index.
        index_name_2 (str): The name of the second index.
        Returns:
        FAISS: The first database with the second merged into it.
        """
        db1 = self._loadIndex(DB_FAISS_PATH, embeddings_hf, index_name_1)
        db2 = self._loadIndex(DB_FAISS_PATH, embeddings_hf, index_name_2)
        self.mergeSecondDbIntoFirst(db1, db2)
        return db1

    @staticmethod
    def _chunkIndexNumbers(DB_FAISS_PATH: str) -> List[int]:
        """Return the sorted numbers N of every ``index_N.faiss`` file in DB_FAISS_PATH."""
        numbers = []
        for file_name in os.listdir(DB_FAISS_PATH):
            match = re.fullmatch(r"index_(\d+)\.faiss", file_name)
            if match:
                numbers.append(int(match.group(1)))
        return sorted(numbers)

    def combineChunksDbs(
        self, DB_FAISS_PATH: str, embeddings_hf: HuggingFaceEmbeddings
    ) -> FAISS:
        """
        Combine the chunked ``index_N`` databases in DB_FAISS_PATH into a single database.
        Indexes are loaded in increasing N and merged pairwise (neighbours first) until one
        remains, so documents keep their chunk order.
        Parameters:
        DB_FAISS_PATH (str): The directory holding the ``index_N.faiss`` files.
        embeddings_hf (HuggingFaceEmbeddings): The embeddings model to use.
        Returns:
        FAISS: The combined FAISS database.
        Raises:
        FileNotFoundError: If no ``index_N.faiss`` file exists in DB_FAISS_PATH.
        """
        index_numbers = self._chunkIndexNumbers(DB_FAISS_PATH)
        if not index_numbers:
            # Without this guard the merge loop below would never terminate.
            raise FileNotFoundError(f"No index_N.faiss files found in {DB_FAISS_PATH}")

        all_dbs: List[FAISS] = []
        for number in index_numbers:
            print(number)
            all_dbs.append(self._loadIndex(DB_FAISS_PATH, embeddings_hf, f"index_{number}"))

        while len(all_dbs) > 1:
            processed_dbs = []
            print("ITERATION ----------->")
            for i in range(0, len(all_dbs), 2):
                db1 = all_dbs[i]
                print(
                    f"For {i} before length is ",
                    len(db1.index_to_docstore_id),
                    end=", ",
                )
                if i + 1 < len(all_dbs):
                    self.mergeSecondDbIntoFirst(db1, all_dbs[i + 1])
                processed_dbs.append(db1)
                print("After length is ", len(db1.index_to_docstore_id))
            all_dbs = processed_dbs
        return all_dbs[0]
class ADS_RAG:
    """
    Retrieval-augmented question generation over an ad inventory.
    Retrieves clustered ads for a page from a VECTOR_DB, optionally asks an OpenAI chat model
    whether those ads relate to the page, then asks it for a QUESTION whose OPTIONS map onto
    the retrieved ads.
    """

    def __init__(
        self,
        db: VECTOR_DB,
        qa_model_name: str,
        relation_check_best_value_thresh: float,
        bestRelationSystemPrompt: str,
        bestQuestionSystemPrompt: str,
    ) -> None:
        """
        Initialize the ADS_RAG class with the given parameters.
        Parameters:
        db (VECTOR_DB): Store queried via ``queryVectorDB`` for clustered ads.
        qa_model_name (str): OpenAI chat model used for both the relation and question calls.
        relation_check_best_value_thresh (float): The relation-check API call is made only when
            the best retrieved score is greater than this value; otherwise relation is assumed.
        bestRelationSystemPrompt (str): Default system prompt for the relation check.
        bestQuestionSystemPrompt (str): Default system prompt for question generation.
        """
        # With no arguments the OpenAI client reads its API key from the environment.
        self.client = OpenAI()
        self.db = db
        self.qa_model_name = qa_model_name
        self.relation_check_best_value_thresh = relation_check_best_value_thresh
        self.bestRelationSystemPrompt = bestRelationSystemPrompt
        self.bestQuestionSystemPrompt = bestQuestionSystemPrompt

    def callOpenAiApi(
        self, messages: List[Dict[str, str]], max_retries: int = 5
    ) -> Tuple[Dict[str, Any], int]:
        """
        Call the OpenAI API with the given messages and return the parsed JSON response.
        Parameters:
        messages (List[Dict[str, str]]): The chat messages to send to the OpenAI API.
        max_retries (int): Maximum number of attempts before giving up (at least 1). Default is 5.
        Returns:
        Tuple[Dict[str, Any], int]: The JSON-decoded message content and the total number of tokens used.
        Raises:
        RuntimeError: If every attempt fails (API error or a reply that is not valid JSON);
            the last underlying exception is chained as the cause.
        """
        attempts = max(1, max_retries)
        last_error = None
        for attempt in range(1, attempts + 1):
            raw_content = None
            try:
                response = self.client.chat.completions.create(
                    model=self.qa_model_name,
                    messages=messages,
                    temperature=0,
                    seed=42,
                    max_tokens=1200,
                    response_format={"type": "json_object"},
                )
                tokens_used = response.usage.total_tokens
                raw_content = response.choices[0].message.content
                return json.loads(raw_content), tokens_used
            except Exception as e:  # API/network failures and JSON decoding errors are both retried
                last_error = e
                # `response` does not exist when create() itself raised, so only print the
                # model output when this attempt actually produced one.
                if raw_content is not None:
                    print(raw_content)
                print("Error-: ", e)
                if attempt < attempts:
                    print("Trying Again")
                    # temperature=0 + a fixed seed tends to repeat the same reply, so back off
                    # (1s, 2s, 4s, ... capped at 30s) and give up eventually instead of looping forever.
                    time.sleep(min(2 ** (attempt - 1), 30))
        raise RuntimeError(
            f"OpenAI call failed after {attempts} attempt(s)"
        ) from last_error

    @staticmethod
    def _buildMessages(
        systemPrompt: str, adsData: str, page_information: str
    ) -> List[Dict[str, str]]:
        """Build the chat: system prompt followed by the ads JSON, then the page text as the user turn."""
        return [
            {"role": "system", "content": systemPrompt + adsData},
            {"role": "user", "content": page_information + "\nThe JSON response: "},
        ]

    def getBestQuestionOnTheBasisOfPageInformationAndAdsData(
        self,
        page_information: str,
        adsData: str,
        relationSystemPrompt: str,
        questionSystemPrompt: str,
        bestRetreivedAdValue: float,
    ) -> Dict[str, Any]:
        """
        Get the best question based on page information and ads data.
        Parameters:
        page_information (str): The information about the page.
        adsData (str): The ads as a JSON string (see ``convertDocumentsClustersToStringForApiCall``).
        relationSystemPrompt (str): The system prompt for relation checking.
        questionSystemPrompt (str): The system prompt for question generation.
        bestRetreivedAdValue (float): The best retrieved score; the relation check is only
            sent to the API when it is greater than ``relation_check_best_value_thresh``.
        Returns:
        Dict[str, Any]: Keys "relation_answer", "tokens_used_relation", "question_answer" and
            "tokens_used_question". The question API call is skipped when the relation
            classification is 0.
        """
        question_answer: Dict[str, Any] = {"reasoning": "", "question": "", "options": []}
        tokens_used_relation = 0
        tokens_used_question = 0

        # The converter emits "{}" when nothing was retrieved; treat that like an empty string
        # and skip both API calls, returning the same dict shape as the normal path.
        if not adsData or adsData.strip() in ("", "{}"):
            return {
                "relation_answer": {"reasoning": "No ads data present", "classification": 0},
                "tokens_used_relation": tokens_used_relation,
                "question_answer": question_answer,
                "tokens_used_question": tokens_used_question,
            }

        if bestRetreivedAdValue > self.relation_check_best_value_thresh:
            relation_answer, tokens_used_relation = self.callOpenAiApi(
                self._buildMessages(relationSystemPrompt, adsData, page_information)
            )
        else:
            relation_answer = {
                "reasoning": "First retrieved document value less than threshold so no need to check relation",
                "classification": 1,
            }

        if relation_answer["classification"] != 0:
            question_answer, tokens_used_question = self.callOpenAiApi(
                self._buildMessages(questionSystemPrompt, adsData, page_information)
            )

        return {
            "relation_answer": relation_answer,
            "tokens_used_relation": tokens_used_relation,
            "question_answer": question_answer,
            "tokens_used_question": tokens_used_question,
        }

    def convertDocumentsClustersToStringForApiCall(
        self, documents_clusters: List[List[Tuple[Document, float]]]
    ) -> str:
        """
        Convert document clusters to a JSON string for the API calls.
        Ads are numbered consecutively across all clusters, e.g.
        {"Option 1 Ads": {"Ad 1": ..., "Ad 2": ...}, "Option 2 Ads": {"Ad 3": ...}}.
        Parameters:
        documents_clusters (List[List[Tuple[Document, float]]]): The document clusters.
        Returns:
        str: The document clusters as JSON ("{}" when there are no clusters).
        """
        ad_numbers = count(1)
        payload = {
            f"Option {i+1} Ads": {
                f"Ad {next(ad_numbers)}": document[0].page_content
                for document in documents_cluster
            }
            for i, documents_cluster in enumerate(documents_clusters)
        }
        return json.dumps(payload)

    def getRagResponse(
        self,
        page_information: str,
        threshold: float = None,
        RelationPrompt: str = None,
        QuestionPrompt: str = None,
    ) -> Tuple[Dict[str, Any], List[List[Tuple[Document, float]]]]:
        """
        Get the RAG response based on the page information and optional prompts.
        Parameters:
        page_information (str): The information about the page.
        threshold (float): The threshold for querying the database. Default is None (store default).
        RelationPrompt (str): The prompt for relation checking; None or "" uses the default.
        QuestionPrompt (str): The prompt for question generation; None or "" uses the default.
        Returns:
        Tuple[Dict[str, Any], List[List[Tuple[Document, float]]]]: The RAG response and the document clusters.
        """
        curr_relation_prompt = RelationPrompt or self.bestRelationSystemPrompt
        curr_question_prompt = QuestionPrompt or self.bestQuestionSystemPrompt
        documents_clusters, best_value = self.db.queryVectorDB(
            page_information, threshold
        )
        answer = self.getBestQuestionOnTheBasisOfPageInformationAndAdsData(
            page_information,
            self.convertDocumentsClustersToStringForApiCall(documents_clusters),
            curr_relation_prompt,
            curr_question_prompt,
            best_value,
        )
        return answer, documents_clusters

    def changeDocumentsToPrintableString(
        self, documents_clusters: List[List[Tuple[Document, float]]]
    ) -> str:
        """
        Convert document clusters to a printable string.
        Each ad shows its content, its "revenue" and "ad_click_count" metadata and its score;
        ads are numbered consecutively across clusters.
        Parameters:
        documents_clusters (List[List[Tuple[Document, float]]]): The document clusters.
        Returns:
        str: The document clusters converted to a printable string.
        """
        parts: List[str] = []
        ad_number = 0
        for ind, documents_cluster in enumerate(documents_clusters):
            parts.append(f"Option {ind+1} Ads-:\n")
            for document in documents_cluster:
                ad_number += 1
                parts.append(
                    f"[Ad {ad_number}] Content: {document[0].page_content}\n"
                    f"Revenue: {document[0].metadata['revenue']}\n"
                    f"Ad Click Count: {document[0].metadata['ad_click_count']}\n"
                    f"Value: {document[1]}\n"
                )
            parts.append("\n")
        return "".join(parts)

    def changeResponseToPrintableString(
        self, response: Dict[str, Any], task: str
    ) -> str:
        """
        Convert a model response to a printable string.
        Parameters:
        response (Dict[str, Any]): The response to convert.
        task (str): "relation" for a relation answer; anything else is treated as a question answer.
        Returns:
        str: The response converted to a printable string (missing keys print as empty).
        """
        if task == "relation":
            return f"Reasoning: {response.get('reasoning', '')}\n\nClassification: {response.get('classification', '')}\n"
        parts = [
            f"Reasoning: {response.get('reasoning', '')}\n\nQuestion: {response.get('question', '')}\n\nOptions: \n"
        ]
        options = response.get("options") or []
        # The prompt asks for a dict {option: [ads]}, but the empty/fallback answers use a list,
        # so accept both shapes instead of crashing on a list.
        if isinstance(options, dict):
            option_items = list(options.items())
        else:
            option_items = [(option, []) for option in options]
        for option, ads in option_items:
            parts.append(f"{option}\n")
            if isinstance(ads, str):
                ads = [ads]
            parts.extend(f"{ad}\n" for ad in ads)
            parts.append("\n")
        return "".join(parts)

    def logResult(
        self,
        curr_relation_prompt: str,
        curr_question_prompt: str,
        page_information: str,
        answer: Dict[str, Any],
    ) -> None:
        """
        Print the page information and the RAG response between separator lines.
        Parameters:
        curr_relation_prompt (str): The current relation prompt (currently not printed).
        curr_question_prompt (str): The current question prompt (currently not printed).
        page_information (str): The information about the page.
        answer (Dict[str, Any]): The RAG response.
        """
        print(
            "**************************************************************************************************\n",
            page_information + "\n",
            json.dumps(answer, indent=4),
            "\n************************************************************************************************\n\n",
        )

    def getRagGradioResponse(
        self,
        page_information: str,
        RelationPrompt: str,
        QuestionPrompt: str,
        threshold: float,
    ) -> str:
        """
        Get the RAG response formatted as a single string for the Gradio output box.
        Parameters:
        page_information (str): The information about the page.
        RelationPrompt (str): The prompt for relation checking.
        QuestionPrompt (str): The prompt for question generation.
        threshold (float): The threshold for querying the database.
        Returns:
        str: Answers, retrieved clusters and token usage formatted for Gradio.
        """
        answer, documents_clusters = self.getRagResponse(
            page_information, threshold, RelationPrompt, QuestionPrompt
        )
        self.logResult(RelationPrompt, QuestionPrompt, page_information, answer)
        docs_info = self.changeDocumentsToPrintableString(documents_clusters)
        relation_answer_string = self.changeResponseToPrintableString(
            answer["relation_answer"], "relation"
        )
        question_answer_string = self.changeResponseToPrintableString(
            answer["question_answer"], "question"
        )
        question_tokens = answer["tokens_used_question"]
        relation_tokens = answer["tokens_used_relation"]
        full_response = (
            f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n "
            f"Question answer:\n {question_answer_string}\n\n"
            f"**RETRIEVED DOCUMENTS CLUSTERS**:\n{docs_info}\n\n"
            f"**TOKENS USED**:\nQuestion api call: {question_tokens}\n"
            f"Relation api call: {relation_tokens}"
        )
        return full_response
class Helper:
    """Environment setup plus the default configuration and system prompts for ADS_RAG."""

    def __init__(self, DB_FAISS_PATH: str) -> None:
        """
        Load ``.env`` (overriding existing variables), disable tokenizers parallelism and store the FAISS path.
        Parameters:
        DB_FAISS_PATH (str): Directory of the saved FAISS database used by ``getRag``.
        """
        load_dotenv(override=True)
        # Presumably to silence the HuggingFace tokenizers fork/parallelism warning - confirm.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.DB_FAISS_PATH = DB_FAISS_PATH

    def getRag(self) -> ADS_RAG:
        """
        Create and return an ADS_RAG wired to the FAISS store at ``DB_FAISS_PATH``.
        Uses the "BAAI/bge-m3" embeddings, kmeans-cc clustering into 3 clusters of at most 6 ads
        from the 50 nearest hits, and "gpt-3.5-turbo" for the relation/question calls.
        Returns:
        ADS_RAG: An instance of the ADS_RAG class.
        """
        embeddings_hf = HuggingFaceEmbeddings(
            model_name="BAAI/bge-m3"
        )  # alternative: "sentence-transformers/all-mpnet-base-v2"
        vector_db = VECTOR_DB(
            default_threshold=0.75,
            number_of_ads_to_fetch_from_db=50,
            clustering_algo="kmeans-cc",
            no_of_clusters=3,
            no_of_ads_in_each_cluster=6,
            DB_FAISS_PATH=self.DB_FAISS_PATH,
            embeddings_hf=embeddings_hf,
        )
        return ADS_RAG(
            db=vector_db,
            qa_model_name="gpt-3.5-turbo",
            relation_check_best_value_thresh=0.6,
            bestRelationSystemPrompt=self.getRelationSystemPrompt(),
            bestQuestionSystemPrompt=self.getQuestionSystemPrompt(),
        )

    def getQuestionSystemPrompt(self) -> str:
        """
        Return the system prompt for question generation.
        The ads JSON is appended directly after the prompt's final line by the caller.
        Returns:
        str: The question system prompt.
        """
        # Few-shot fixes: the sample ADS_DATA is now valid JSON (it ended in "}]"), and the
        # expected options cite the sample's real ad numbers (4, 5, 7) instead of renumbering
        # them as 3, 4, 5, which taught the model to break the option -> ad mapping.
        bestQuestionSystemPrompt = """1. You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to form a relevant QUESTION to ask the user visiting the webpage. This question should help identify the user's intent behind visiting the webpage and should be highly attractive.
2. Now form a highly attractive/lucrative and diverse/mutually exclusive OPTION which should be both the answer for the QUESTION and related to ads in this cluster.
3. Try to generate intelligent creatives for advertising and keep QUESTION within 70 characters and either 2, 3 or 4 options with each OPTION within 4 to 6 words.
4. Provide your REASONING behind choosing the QUESTION and the OPTIONS. Now provide the QUESTION and the OPTIONS. Along with each OPTION, provide the ads from ADS_DATA that you associated with it.
---------------------------------------
<Sample INPUT>
The Effects of Aging on Skin
<Sample ADS_DATA>
{"Cluster 1 Ads": {"Ad 1": "Forget Retinol, Use This Household Item To Fill In Wrinkles - Celebrities Are Ditching Pricey Facelifts For This."}, "Cluster 2 Ads": {"Ad 2": "Stop Covering Your Wrinkles with Make Up - Do This Instead.", "Ad 3": "Living With Migraines? - Discover A Treatment Option. Learn about a type of prescription migraine treatment called CGRP receptor antagonists. Discover a range of resources that may help people dealing with migraines"}, "Cluster 3 Ads": {"Ad 4": "What is Advanced Skin Cancer? - Find Disease Information Here.Find Facts About Advanced Skin Cancer and a Potential Treatment Option.", "Ad 5": "Learn About Advanced Melanoma - Find Disease Information Here.Find Facts About Advanced Melanoma and a Potential Treatment Option.", "Ad 6": "Treatment For CKD - Reduce Risk Of Progressing CKD. Ask About A Treatment That Can Help Reduce Your Risk Of Kidney Failure", "Ad 7": "Are You Living With Vitiligo? - For Patients & Caregivers.Discover An FDA-Approved Topical Cream That May Help With Nonsegmental Vitiligo Repigmentation. Learn About A Copay Savings Card For Eligible Patients With Vitiligo."}}
<Expected json output>
{
"reasoning" : "Among the seven ads in **Sample ADS_DATA**, Ads 3 and 6 are irrelevant to the INPUT, so they should be discarded. Ad 1, 2 closely aligns with the user's intent. Ads 4, 5, and 7 are also relevant to INPUT. The question will be formed in a way to connect the PAGE content with the goals of these five relevant ads, making sure they appeal to both specific and general user interests, with the OPTIONS being the answer for QUESTION(it is ensured that no irrelevant options are formed)",
"question": "Interested in methods to combat aging skin?",
"options": {"1. Retinol Alternatives for Wrinkle Treatment." : ["Ad 1: Forget Retinol, Use This Household Item To Fill In Wrinkles - Celebrities Are Ditching Pricey Facelifts For This."], "2. Reduce Wrinkles without Makeup.": ["Ad 2: Stop Covering Your Wrinkles with Make Up - Do This Instead."], "3. Information on Skin Diseases": ["Ad 4: What is Advanced Skin Cancer? - Find Disease Information Here.Find Facts About Advanced Skin Cancer and a Potential Treatment Option.", "Ad 5: Learn About Advanced Melanoma - Find Disease Information Here.Find Facts About Advanced Melanoma and a Potential Treatment Option.", "Ad 7: Are You Living With Vitiligo? - For Patients & Caregivers.Discover An FDA-Approved Topical Cream That May Help With Nonsegmental Vitiligo Repigmentation. Learn About A Copay Savings Card For Eligible Patients With Vitiligo."]}
}
-----------------------------------------------
<Sample INPUT>
Got A Rosemary Bush? Here’re 20 Brilliant & Unusual Ways To Use All That Rosemary
<Sample ADS_DATA>
<empty>
<Expected json output>
{
"reasoning" : "No ads available",
"question": "",
"options": []
}
-----------------------------------------------
The ADS_DATA provided to you is as follows:
"""
        # old_system_prompt_additional_example = """
        # -----------------------------------------------
        # <Sample INPUT>
        # 7 Signs and Symptoms of Magnesium Deficiency
        # <Sample ADS_DATA>
        # Ad 1: 4 Warning Signs Of Dementia - Fight Dementia and Memory Loss. 100% Natural Program To Prevent Cognitive Decline. Developed By Dr. Will Mitchell. Read The Reviews-Get a Special Offer. Doctor Recommended. High Quality Standards. 60-Day Refund.
        # Ad 2: About Hyperkalemia - Learn About The Symptoms. High Potassium Can Be A Serious Condition. Learn More About Hyperkalemia Today.
        # Ad 3: Weak or Paralyzed Muscles? - A Common Symptom of Cataplexy. About 70% of People With Narcolepsy Are Believed to Have Cataplexy Symptoms. Learn More. Download the Doctor Discussion Guide to Have a Informed Conversation About Your Health.
        # <Expected json output>
        # {
        # "reasoning" : "Given the input '7 Signs and Symptoms of Magnesium Deficiency,' it is evident that the user is looking for information specifically about magnesium deficiency. Ads 1, 2, and 3 discuss topics such as dementia, hyperkalemia, weak muscles, which are not related to magnesium deficiency in any way. Therefore, all the ads in the ADS_DATA are not suitable for the user's query and will be discarded.",
        # "question": "No related ads available to form question and options.",
        # "options": []
        # }
        # ------------------------------------------------
        # """
        return bestQuestionSystemPrompt

    def getRelationSystemPrompt(self) -> str:
        """
        Return the system prompt for relation checking.
        The ads JSON is appended directly after the prompt's final line by the caller.
        Returns:
        str: The relation system prompt.
        """
        # Few-shot fix: the first example's reasoning said the ads were "not at all" NOT relevant
        # (a double negative contradicting its classification of 0); also fixed a stray "*".
        bestRelationSystemPrompt = """You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to determine whether there are some relevant ADS to INPUT are present in ADS_DATA. ADS WHICH DON'T MATCH USER'S INTENT SHOULD BE CONSIDERED IRRELEVANT
---------------------------------------
**Sample INPUT**: What Causes Bright-Yellow Urine and Other Changes in Color?
Expected json output :
{
"reasoning" : "Given the user's search for 'What Causes Bright-Yellow Urine and Other Changes in Color?', it is evident that they are looking for information related to the causes and implications of changes in urine color. Therefore, ads that are related to urine color would be relevant. However, none of the provided ads are related to urine. Ads 1, 2, and 3 are focused on chronic lymphocytic leukemia (CLL) treatment, Ad 4 addresses high blood pressure, and Ad 5 is about migraine treatment. Since none of these ads are at all relevant to the urine, they can all be considered irrelevant to the user's intent.",
"classification": 0
}
------------------------------------------------
**Sample INPUT**: The Effects of Aging on Skin
Expected json output :
{
"reasoning" : "Given the user's search for 'The Effects of Aging on Skin,' it is clear that they are seeking information related to skin aging. Therefore, the ads that are relevant to skin effects should be considered. Ads 1 and 2 focus on wrinkle treatment and anti-aging solutions, making them pertinent to the user's intent. Ad 3 targets vitiligo and not general skin aging but it is related to skin effect. So it is also relevant. Ads 4 and 5 are about advanced lung cancer, which do not address the interest in skin. Ads 1 and 2, 3 are most relevant to the user's search. So ADS_DATA is relevant to INPUT. ",
"classification": 1
}
---------------------------------------
The ADS_DATA provided to you is as follows:
"""
        return bestRelationSystemPrompt
class RAGGradioApp:
    """Gradio web UI that runs page information through an ADS_RAG pipeline."""

    def __init__(self, helper: Helper) -> None:
        """
        Build the RAG pipeline and fetch the default system prompts from ``helper``.
        Args:
            helper (Helper): Provides the ADS_RAG instance (``getRag``) and the default
                relation and question system prompts.
        """
        self.my_rag = helper.getRag()
        self.relationSystemPrompt = helper.getRelationSystemPrompt()
        # Attribute name (including the "Systemp" spelling) is kept for compatibility.
        self.questionSystempPrompt = helper.getQuestionSystemPrompt()

    def launch(self, ad_title_content: List) -> None:
        """
        Build the Gradio Blocks interface and start serving it.
        Args:
            ad_title_content (List): Ad titles; on every page load up to 100 of them, picked at
                random, are shown in the collapsed "Ad Titles" accordion.
        """

        def random_ad_titles_markdown():
            # Registered with demo.load below, so it is re-sampled on each page load.
            titles = [str(ad_title) for ad_title in ad_title_content]
            sample_size = min(100, len(ad_title_content))
            return "<br>".join(random.sample(titles, sample_size))

        rag_callback = self.my_rag.getRagGradioResponse

        with gr.Blocks() as demo:
            gr.Markdown("# RAG on ads data")
            # Editable system prompts, pre-filled with the defaults.
            with gr.Row():
                RelationPrompt = gr.Textbox(
                    self.relationSystemPrompt,
                    lines=1,
                    placeholder="Enter the relation system prompt for relation check",
                    label="Relation System prompt",
                )
                QuestionPrompt = gr.Textbox(
                    self.questionSystempPrompt,
                    lines=1,
                    placeholder="Enter the question system prompt for question formulation",
                    label="Question System prompt",
                )
            page_information = gr.Textbox(
                lines=1,
                placeholder="Enter the page information",
                label="Page Information",
            )
            threshold = gr.Number(
                value=self.my_rag.db.default_threshold, label="Threshold", interactive=True
            )
            output = gr.Textbox(label="Output")
            submit_btn = gr.Button("Submit")

            # Clicking Submit and pressing Enter in the page box run the same handler.
            for register_event in (submit_btn.click, page_information.submit):
                register_event(
                    rag_callback,
                    inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
                    outputs=[output],
                )

            with gr.Accordion("Ad Titles", open=False):
                ad_titles = gr.Markdown()
            demo.load(random_ad_titles_markdown, None, ad_titles)

        # Presumably closes any previously launched Gradio apps before serving this one.
        gr.close_all()
        demo.launch()
if __name__ == "__main__":
    # Alternative store/dataset pair (Jun facty/activebeat):
    #   "./vectorstore/db_faiss_ads_Jun_facty_activebeat_Health_dupRemoved0.85"
    #   "./data/149_adclick_Jun_facty_activeBeat_Health_dupRemoved0.85_campaign.tsv"
    db_faiss_path = "./vectorstore/db_faiss_ads_20May_20Jun_webmd_healthline_Health_dupRemoved0.8"
    ads_tsv_path = "./data/142_adclick_20May_20Jun_webmd_healthline_Health_dupRemoved0.8_someAdsCampaign.tsv"

    # Helper first: it loads .env before anything else runs.
    helper = Helper(db_faiss_path)
    data = pd.read_csv(ads_tsv_path, sep="\t")
    # data.dropna(axis=0, how="any", inplace=True)
    unique_ads = data.drop_duplicates(subset=["ad_title", "ad_desc"])
    ad_title_content = list(unique_ads["ad_title"].values)

    app = RAGGradioApp(helper)
    app.launch(ad_title_content)