temp / DV-AGENT /generate_plot.py
NEXAS's picture
Upload 22 files
182219d verified
raw
history blame
No virus
4.11 kB
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import pandas as pd
import os
import streamlit as st
#from langchain_community.llms import HuggingFaceHub
from typing import List
from langchain_groq import ChatGroq
from dotenv import load_dotenv
# Populate os.environ from a local .env file (if present) before reading keys.
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
# Model is overridable via the environment; the default is the model this
# module has always used.
GROQ_MODEL_NAME = os.getenv("GROQ_MODEL_NAME", "mixtral-8x7b-32768")
# Shared chat model used by generate_plot() and retry_generate_plot().
# temperature=0 requests the least random sampling the backend offers.
# The key read above is passed explicitly rather than left unused.
llm1 = ChatGroq(
    temperature=0,
    model_name=GROQ_MODEL_NAME,
    groq_api_key=groq_api_key,
)
def read_first_3_rows(dataset_path="dataset.csv", n_rows=3):
    """Return the leading rows of the dataset CSV rendered as plain text.

    The text is embedded in the LLM prompts so the model can see the column
    names and a sample of values.

    Args:
        dataset_path: CSV file to read. Defaults to ``"dataset.csv"``
            (relative to the current working directory), the same file the
            example plotting programs in the prompts read.
        n_rows: Number of leading rows to include. Defaults to 3.

    Returns:
        The rows formatted with ``DataFrame.to_string(index=False)``.
        A missing file or an empty file is not raised. Instead, an
        ``"Error: ..."`` string is returned so the caller can still build
        a prompt. Other read errors still propagate.
    """
    try:
        # The whole file is read, not just ``nrows``, so pandas infers column
        # dtypes from every row. This keeps the sample's formatting consistent
        # with what code loading the full file will see.
        df = pd.read_csv(dataset_path)
    except FileNotFoundError:
        return "Error: Dataset file not found."
    except pd.errors.EmptyDataError:
        return "Error: Dataset file is empty."
    return df.head(n_rows).to_string(index=False)
def generate_plot(question):
    """Ask the LLM for a complete Streamlit/Altair program answering *question*.

    The prompt has three parts: fixed instructions, a sample of
    ``dataset.csv`` from ``read_first_3_rows()``, and the question. It ends
    with a worked example program.

    Args:
        question: Natural-language description of the desired chart.

    Returns:
        The model's raw text response. It is expected, but not guaranteed,
        to be a runnable Python program.
    """
    dataset_first_3_rows = read_first_3_rows()
    prefix = """You are a high skilled visualization assistant that can modify a provided visualization code based on a set of instructions. You MUST return a full program. DO NOT include any preamble text. Do not include explanations or prose.
First 3 rows of the dataset:"""
    suffix = """
Question:
{question}
# comment Example for protein count of different products:
import altair as alt
import pandas as pd
import streamlit as st
# comment Read the dataset
df = pd.read_csv('dataset.csv')
# comment Calculate the protein count of different products
product_protein = df.groupby('name')['protein'].sum().reset_index()
# comment Create the chart
chart = alt.Chart(product_protein).mark_bar().encode(
x=alt.X('name:N', title='Product Name'),
y=alt.Y('protein:Q', title='Protein Count')
)
# comment Display the chart
st.altair_chart(chart, use_container_width=True)
"""
    # The dataset sample goes in as a template *variable*, not as template
    # text. Substituted values are not re-parsed, so "{" or "}" in the CSV
    # cannot be mistaken for prompt placeholders.
    template = prefix + "{dataset}" + suffix
    prompt = PromptTemplate(template=template, input_variables=["dataset", "question"])
    llm_chain = LLMChain(prompt=prompt, llm=llm1)
    return llm_chain.predict(dataset=dataset_first_3_rows, question=question)
def retry_generate_plot(question, error_message, error_code):
    """Ask the LLM to repair plotting code that raised an error.

    Args:
        question: The original chart request (the prompt's "Objective").
        error_message: The error text produced when running *error_code*.
        error_code: The program that failed. Presumably this is the output
            of a previous generate_plot() call; confirm with the caller.

    Returns:
        The model's raw text response. It is expected, but not guaranteed,
        to be a corrected, runnable Python program.
    """
    dataset_first_3_rows = read_first_3_rows()
    prefix = """You are a high skilled visualization assistant that can modify a provided visualization code based on a set of instructions. You MUST return a full program. DO NOT include any preamble text. Do not include explanations or prose.
Current code attempts to create a visualization of dataset.csv to meet the objective. but it has encounted the given error. provide a corrected code. if you are adding comments or explanations they should start with #.
#Example:
import altair as alt
import pandas as pd
import streamlit as st
# Read the dataset
df = pd.read_csv('dataset.csv')
# Calculate the total social media followers for each region
region_followers = df.groupby('Region of Focus')[['X (Twitter) Follower #', 'Facebook Follower #', 'Instagram Follower #', 'Threads Follower #', 'YouTube Subscriber #', 'TikTok Subscriber #']].sum().reset_index()
# Melt the dataframe to convert it into long format
region_followers = region_followers.melt(id_vars='Region of Focus', var_name='Social Media', value_name='Total Followers')
# Create the chart
chart = alt.Chart(region_followers).mark_bar().encode(
x=alt.X('Region of Focus:N', title='Region of Focus'),
y=alt.Y('Total Followers:Q', title='Total Followers'),
color=alt.Color('Social Media:N', title='Social Media')
)
# Display the chart
st.altair_chart(chart, use_container_width=True)
First 3 rows of the dataset:"""
    suffix = """
Objective: {question}
Current Code:
{error_code}
Error Message:
{error_message}
Corrected Code:
"""
    # The dataset sample is a template variable, not template text, so
    # braces in the CSV cannot be parsed as placeholders.
    retry_template = prefix + "{dataset}" + suffix
    # List each placeholder as its own entry so LLMChain's input-key
    # validation matches the kwargs passed to predict() below.
    retry_prompt = PromptTemplate(
        template=retry_template,
        input_variables=["dataset", "question", "error_message", "error_code"],
    )
    llm_chain = LLMChain(prompt=retry_prompt, llm=llm1)
    return llm_chain.predict(
        dataset=dataset_first_3_rows,
        question=question,
        error_message=error_message,
        error_code=error_code,
    )