|
import os

import streamlit as st
import openai
import nest_asyncio

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import BaseRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llmlingua import PromptCompressor
from rouge_score import rouge_scorer
from semantic_text_similarity.models import WebBertSimilarity

# The OpenAI client reads OPENAI_API_KEY from the environment; setting it
# explicitly here just makes the dependency visible.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Streamlit already runs an event loop; nest_asyncio lets the async calls
# inside LlamaIndex run on top of it.
nest_asyncio.apply()
|
st.title("Prompt Optimization for a Policy Bot")

uploaded_files = st.file_uploader(
    "Upload a policy document in PDF format", type="pdf", accept_multiple_files=True
)
|
if uploaded_files:
    # Make sure the target folder exists before writing the uploads to disk.
    os.makedirs("./data", exist_ok=True)
    documents = []
    for uploaded_file in uploaded_files:
        with open(f"./data/{uploaded_file.name}", "wb") as f:
            f.write(uploaded_file.getbuffer())
        # The default PDF reader yields one Document per page; extend rather
        # than assign so earlier files aren't overwritten on each iteration.
        reader = SimpleDirectoryReader(input_files=[f"./data/{uploaded_file.name}"])
        documents.extend(reader.load_data())
        st.success("File uploaded...")
else:
    # Nothing to index yet; stop the script so the code below doesn't
    # reference `documents` before it exists.
    st.info("Upload a policy document to get started.")
    st.stop()
|
# Chunk the documents and register the nodes in a docstore: BM25 retrieves
# straight from the docstore, while the vector index embeds the same nodes.
splitter = SentenceSplitter(chunk_size=256)
nodes = splitter.get_nodes_from_documents(documents)
storage_context = StorageContext.from_defaults()
storage_context.docstore.add_documents(nodes)
index = VectorStoreIndex(nodes=nodes, storage_context=storage_context)

bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)
vector_retriever = index.as_retriever(similarity_top_k=10)
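
# Illustrative comparison of the two retrievers (the query is hypothetical):
#
#   kw_hits = bm25_retriever.retrieve("probation period duration")
#   sem_hits = vector_retriever.retrieve("probation period duration")
#
# BM25 ranks nodes by exact keyword overlap, while the vector retriever also
# surfaces paraphrases such as "trial period"; the class below merges both.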
|
|
class HybridRetriever(BaseRetriever):
    """Merge BM25 and vector results, de-duplicated by node ID."""

    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)

        # Keep the first occurrence of each node across both result lists.
        all_nodes = []
        node_ids = set()
        for n in bm25_nodes + vector_nodes:
            if n.node.node_id not in node_ids:
                all_nodes.append(n)
                node_ids.add(n.node.node_id)
        return all_nodes


hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)
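
# Quick sanity check, purely illustrative (the query is hypothetical):
#
#   results = hybrid_retriever.retrieve("How many days of annual leave do I get?")
#   # Up to 20 unique nodes: 10 from BM25 plus 10 from the vector retriever,
#   # minus any overlap between the two.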
|
# Chat model used to answer with both the original and the compressed prompt.
model = "gpt-3.5-turbo"


def get_context(query):
    """Return the text content of every node retrieved for `query`."""
    contexts = hybrid_retriever.retrieve(query)
    context_list = [n.get_content() for n in contexts]
    return context_list
|
def res(prompt):
    """Query the chat model and return token usage plus the answer text."""
    response = openai.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": (
                    "You are a helpful assistant who answers from the following "
                    "context. If the answer can't be found in the context, "
                    "politely refuse."
                ),
            },
            {"role": "user", "content": prompt},
        ],
    )
    return [
        response.usage.prompt_tokens,
        response.usage.completion_tokens,
        response.usage.total_tokens,
        response.choices[0].message.content,
    ]
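
# Minimal usage sketch; the numbers are illustrative, not real output:
#
#   p_tokens, c_tokens, t_tokens, answer = res("Context...\n\nWhat is the notice period?")
#   print(p_tokens, c_tokens, t_tokens)   # e.g. 850 120 970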
|
if "token_summary" not in st.session_state:
    st.session_state.token_summary = []
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation history on every Streamlit rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
|
if prompt := st.chat_input("Enter your query:"):
    st.success("Fetching info...")

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Original (uncompressed) prompt: retrieved context followed by the query.
    context_list = get_context(prompt)
    context = " ".join(context_list)
    full_prompt = "\n\n".join([context, prompt])

    st.session_state.messages.append({"role": "assistant", "content": "Generating Original prompt response..."})
    st.success("Generating Original prompt response...")
    orig_res = res(full_prompt)
    st.session_state.messages.append({"role": "assistant", "content": orig_res[3]})
    with st.chat_message("assistant"):
        st.markdown(orig_res[3])
|
    st.session_state.messages.append({"role": "assistant", "content": "Generating Optimized prompt response..."})
    st.success("Generating Optimized prompt response...")

    # Cache the LLMLingua-2 compressor so the model is loaded once, not on
    # every chat turn.
    @st.cache_resource(show_spinner="Loading LLMLingua-2 compressor...")
    def load_compressor():
        return PromptCompressor(
            model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            use_llmlingua2=True,
            device_map="cpu",
        )

    llm_lingua = load_compressor()

    def prompt_compression(context, rate=0.5):
        # Keep roughly `rate` of the tokens; punctuation and newlines are
        # forced to survive so the compressed context stays readable.
        compressed_context = llm_lingua.compress_prompt(
            context,
            rate=rate,
            force_tokens=["!", ".", "?", "\n"],
            drop_consecutive=True,
        )
        return compressed_context
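
    # compress_prompt returns a dict; only the 'compressed_prompt' field is
    # used below. Token counts for the before/after text are also included
    # in the result (illustrative sketch, values made up):
    #
    #   {"compressed_prompt": "...shortened context...",
    #    "origin_tokens": 1024, "compressed_tokens": 512, ...}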
|
    compressed_context = prompt_compression(context)
    full_opt_prompt = "\n\n".join([compressed_context["compressed_prompt"], prompt])
    compressed_res = res(full_opt_prompt)
    st.session_state.messages.append({"role": "assistant", "content": compressed_res[3]})
    with st.chat_message("assistant"):
        st.markdown(compressed_res[3])
|
    # Score how much of the original answer the compressed-prompt answer
    # preserves: ROUGE-L for token overlap, WebBERT for semantic similarity.
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    scores = scorer.score(compressed_res[3], orig_res[3])

    @st.cache_resource(show_spinner="Loading similarity model...")
    def load_similarity_model():
        return WebBertSimilarity(device="cpu")

    webert_model = load_similarity_model()
    # WebBertSimilarity predicts on a 0-5 scale; rescale to a percentage.
    similarity_score = webert_model.predict([(compressed_res[3], orig_res[3])])[0] / 5 * 100
|
    st.session_state.messages.append({"role": "assistant", "content": "Token Length Summary..."})
    st.success("Token Length Summary...")
    st.session_state.messages.append({"role": "assistant", "content": f"Original prompt has {orig_res[0]} tokens"})
    st.write(f"Original prompt has {orig_res[0]} tokens")
    st.session_state.messages.append({"role": "assistant", "content": f"Optimized prompt has {compressed_res[0]} tokens"})
    st.write(f"Optimized prompt has {compressed_res[0]} tokens")

    st.session_state.messages.append({"role": "assistant", "content": "Comparing Original and Optimized Prompt Response..."})
    st.success("Comparing Original and Optimized Prompt Response...")
    st.session_state.messages.append({"role": "assistant", "content": f"ROUGE-L score: {scores['rougeL'].fmeasure * 100:.2f}"})
    st.write(f"ROUGE-L score: {scores['rougeL'].fmeasure * 100:.2f}")
    st.session_state.messages.append({"role": "assistant", "content": f"Semantic text similarity score: {similarity_score:.2f}"})
    st.write(f"Semantic text similarity score: {similarity_score:.2f}")

    st.write(" ")
|
    # Cost saved by the token reduction, using illustrative per-1K-token
    # prices for GPT-4, Claude, and Mistral.
    origin_tokens = orig_res[0]
    compressed_tokens = compressed_res[0]
    gpt_saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
    claude_saving = (origin_tokens - compressed_tokens) * 0.015 / 1000
    mistral_saving = (origin_tokens - compressed_tokens) * 0.004 / 1000
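
    # Worked example with made-up numbers: if the original prompt used 1,200
    # tokens and the compressed one 600, the saving at the GPT-4 rate above is
    # (1200 - 600) * 0.06 / 1000 = $0.036 per query.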
|
    st.session_state.messages.append({"role": "assistant", "content": f"The optimized prompt saves ${gpt_saving:.4f} at GPT-4 pricing."})
    st.success(f"The optimized prompt saves ${gpt_saving:.4f} at GPT-4 pricing.")
    st.success(f"It saves ${claude_saving:.4f} at Claude pricing and ${mistral_saving:.4f} at Mistral pricing.")
|
    st.success("Optimized prompt ready for download...")
    # file_name is the name suggested to the browser, not a server-side path.
    st.download_button(
        label="Download Optimized Prompt",
        data=full_opt_prompt,
        file_name="optimized_prompt.txt",
    )
|
|
|
|