"""LLM helpers that generate, and repair, Altair/Streamlit plotting code.

The module builds prompts that contain a preview of ``dataset.csv`` plus a
worked example, sends them to a Groq-hosted chat model, and returns the
model's raw text response (expected to be a complete Python program).
"""

import os
from typing import List

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
#from langchain_community.llms import HuggingFaceHub
from langchain_groq import ChatGroq

# Load variables from a local .env file before anything reads the environment.
load_dotenv()

groq_api_key = os.getenv("GROQ_API_KEY")
llm1 = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")

# Instruction preamble shared by both prompts.
_SYSTEM_INSTRUCTIONS = (
    "You are a high skilled visualization assistant that can modify a provided "
    "visualization code based on a set of instructions. You MUST return a full "
    "program. DO NOT include any preamble text. Do not include explanations or "
    "prose."
)

# NOTE: the dataset preview is injected through the ``{dataset}`` placeholder
# rather than concatenated into the template text, so braces that appear in
# the data are never interpreted as template variables.
_GENERATE_PLOT_TEMPLATE = _SYSTEM_INSTRUCTIONS + """
First 3 rows of the dataset:{dataset}
Question: {question}
# comment Example for protein count of different products:
import altair as alt
import pandas as pd
import streamlit as st

# comment Read the dataset
df = pd.read_csv('dataset.csv')

# comment Calculate the protein count of different products
product_protein = df.groupby('name')['protein'].sum().reset_index()

# comment Create the chart
chart = alt.Chart(product_protein).mark_bar().encode(
    x=alt.X('name:N', title='Product Name'),
    y=alt.Y('protein:Q', title='Protein Count')
)

# comment Display the chart
st.altair_chart(chart, use_container_width=True)
"""

_RETRY_TEMPLATE = _SYSTEM_INSTRUCTIONS + """
Current code attempts to create a visualization of dataset.csv to meet the objective. but it has encounted the given error. provide a corrected code. if you are adding comments or explanations they should start with #.
#Example:
import altair as alt
import pandas as pd
import streamlit as st

# Read the dataset
df = pd.read_csv('dataset.csv')

# Calculate the total social media followers for each region
region_followers = df.groupby('Region of Focus')[['X (Twitter) Follower #', 'Facebook Follower #', 'Instagram Follower #', 'Threads Follower #', 'YouTube Subscriber #', 'TikTok Subscriber #']].sum().reset_index()

# Melt the dataframe to convert it into long format
region_followers = region_followers.melt(id_vars='Region of Focus', var_name='Social Media', value_name='Total Followers')

# Create the chart
chart = alt.Chart(region_followers).mark_bar().encode(
    x=alt.X('Region of Focus:N', title='Region of Focus'),
    y=alt.Y('Total Followers:Q', title='Total Followers'),
    color=alt.Color('Social Media:N', title='Social Media')
)

# Display the chart
st.altair_chart(chart, use_container_width=True)

First 3 rows of the dataset:{dataset}
Objective: {question}
Current Code: {error_code}
Error Message: {error_message}
Corrected Code:
"""


def read_first_3_rows(dataset_path: str = "dataset.csv") -> str:
    """Return the first three rows of a CSV file rendered as plain text.

    Args:
        dataset_path: Path of the CSV file to preview. Defaults to
            ``"dataset.csv"``, the file the prompts refer to.

    Returns:
        The output of ``DataFrame.head(3).to_string(index=False)``, or the
        literal message ``"Error: Dataset file not found."`` when the file
        does not exist. Other read errors propagate to the caller.
    """
    try:
        df = pd.read_csv(dataset_path)
    except FileNotFoundError:
        # Best effort: the message is embedded in the prompt instead of
        # aborting the request.
        return "Error: Dataset file not found."
    return df.head(3).to_string(index=False)


def _run_chain(template: str, **variables: str) -> str:
    """Fill *template* with *variables* and return the LLM's text response."""
    prompt = PromptTemplate(template=template, input_variables=list(variables))
    llm_chain = LLMChain(prompt=prompt, llm=llm1)
    return llm_chain.predict(**variables)


def generate_plot(question: str) -> str:
    """Ask the LLM for a complete plotting program that answers *question*.

    Args:
        question: Natural-language description of the desired visualization.

    Returns:
        The model's raw response text.
    """
    return _run_chain(
        _GENERATE_PLOT_TEMPLATE,
        question=question,
        dataset=read_first_3_rows(),
    )


def retry_generate_plot(question: str, error_message: str, error_code: str) -> str:
    """Ask the LLM to repair previously generated plotting code.

    Args:
        question: The original visualization objective.
        error_message: The error raised when ``error_code`` was executed.
        error_code: The generated program that failed.

    Returns:
        The model's raw response text (the corrected program).
    """
    # Each placeholder is declared as its own input variable; the original
    # list fused "error_message, error_code" into a single bogus name.
    return _run_chain(
        _RETRY_TEMPLATE,
        question=question,
        error_message=error_message,
        error_code=error_code,
        dataset=read_first_3_rows(),
    )