# Source: gokaygokay's Hugging Face Space ("Random Prompt Generator"),
# commit 00229d2 ("buttons"), raw app file (~6.87 kB).
import gradio as gr
from llm_inference import LLMInferenceNode
import random
# HTML banner rendered at the top of the Gradio page (author links + tagline).
title = """<h1 align="center">Random Prompt Generator</h1>
<p><center>
<a href="https://x.com/gokayfem" target="_blank">[X gokaygokay]</a>
<a href="https://github.com/gokayfem" target="_blank">[Github gokayfem]</a>
<p align="center">Generate random prompts using powerful LLMs from Hugging Face, Groq, and SambaNova.</p>
</center></p>
"""
# Module-level mirror of the "Prompt Type" dropdown's current value.
# NOTE(review): a global is shared across all concurrent sessions, so it is
# only a best-effort mirror — per-request handlers should rely on the value
# Gradio passes them rather than on this variable.
selected_prompt_type = "Long" # Default value
def create_interface():
    """Build and return the Gradio Blocks app for the Random Prompt Generator.

    Layout: a custom-prompt textbox and prompt-type dropdown on the left,
    LLM options (compression, provider, model) on the right, and a single
    button that (1) generates a random base prompt and (2) expands it with
    the selected LLM.

    Returns:
        gr.Blocks: the assembled (un-launched) Gradio application.
    """
    llm_node = LLMInferenceNode()

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)

        with gr.Row():
            with gr.Column(scale=2):
                custom = gr.Textbox(label="Custom Input Prompt (optional)", lines=3)
                prompt_type = gr.Dropdown(
                    choices=["Long", "Short", "Medium", "OnlyObjects", "NoFigure", "Landscape", "Fantasy"],
                    label="Prompt Type",
                    value="Long",
                    interactive=True
                )

                # Keep the module-level global in sync with the dropdown for
                # any external reader.  NOTE: the click handler below
                # deliberately uses its own `prompt_type` argument instead of
                # this global — a global is shared across concurrent sessions
                # and is not a safe source of per-request state.
                def update_prompt_type(value):
                    global selected_prompt_type
                    selected_prompt_type = value
                    print(f"Updated prompt type: {selected_prompt_type}")
                    return value

                prompt_type.change(update_prompt_type, inputs=[prompt_type], outputs=[prompt_type])

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard"
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)

                # LLM Provider Selection
                llm_provider = gr.Dropdown(
                    choices=["Hugging Face", "Groq", "SambaNova"],
                    label="LLM Provider",
                    value="Groq"
                )
                api_key = gr.Textbox(label="API Key", type="password", visible=False)
                model = gr.Dropdown(
                    label="Model",
                    choices=["llama-3.1-70b-versatile", "mixtral-8x7b-32768", "gemma2-9b-it"],
                    value="llama-3.1-70b-versatile"
                )

        with gr.Row():
            # Single button drives both prompt generation and LLM expansion.
            generate_button = gr.Button("Generate Random Prompt with LLM")

        with gr.Row():
            text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)

        def update_model_choices(provider):
            """Return an updated model Dropdown for the chosen provider."""
            provider_models = {
                "Hugging Face": [
                    "Qwen/Qwen2.5-72B-Instruct",
                    "meta-llama/Meta-Llama-3.1-70B-Instruct",
                    "mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "mistralai/Mistral-7B-Instruct-v0.3"
                ],
                "Groq": [
                    "llama-3.1-70b-versatile",
                    "mixtral-8x7b-32768",
                    "gemma2-9b-it"
                ],
                "SambaNova": [
                    "Meta-Llama-3.1-70B-Instruct",
                    "Meta-Llama-3.1-405B-Instruct",
                    "Meta-Llama-3.1-8B-Instruct"
                ],
            }
            models = provider_models.get(provider, [])
            return gr.Dropdown(choices=models, value=models[0] if models else "")

        def update_api_key_visibility(provider):
            # None of the currently offered providers need a user-supplied key.
            return gr.update(visible=False)

        llm_provider.change(
            update_model_choices,
            inputs=[llm_provider],
            outputs=[model]
        )
        llm_provider.change(
            update_api_key_visibility,
            inputs=[llm_provider],
            outputs=[api_key]
        )

        def generate_random_prompt_with_llm(custom_input, prompt_type, long_talk, compress, compression_level, custom_base_prompt, provider, api_key, model_selected):
            """Generate a random base prompt, then expand it with the chosen LLM.

            Returns the LLM's text, or an error string if anything fails.

            BUGFIX: the original passed the module global `selected_prompt_type`
            to `llm_node.generate` and reset it to "Long" afterwards, so a
            second click without re-touching the dropdown silently fell back
            to "Long", and concurrent sessions clobbered each other.  The
            dropdown's value is already delivered as the `prompt_type`
            argument, so that is used throughout.
            """
            try:
                # Step 1: generate a base prompt (fresh seed per click).
                dynamic_seed = random.randint(0, 1000000)
                if custom_input and custom_input.strip():
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)
                    print("Using Custom Input Prompt.")
                else:
                    # No user text: ask for a random prompt of the selected type.
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, f"Create a random prompt based on the '{prompt_type}' type.")
                    print(f"No Custom Input Prompt provided. Generated prompt based on prompt_type: {prompt_type}")
                print(f"Generated Prompt: {prompt}")

                # Step 2: expand the prompt with the selected LLM provider/model.
                poster = False  # Not exposed in the UI; kept for generate()'s signature.
                result = llm_node.generate(
                    input_text=prompt,
                    long_talk=long_talk,
                    compress=compress,
                    compression_level=compression_level,
                    poster=poster,
                    prompt_type=prompt_type,
                    custom_base_prompt=custom_base_prompt,
                    provider=provider,
                    api_key=api_key,
                    model=model_selected
                )
                print(f"Generated Text: {result}")
                return result
            except Exception as e:
                # Surface the failure in the output textbox instead of crashing.
                print(f"An error occurred: {e}")
                return f"Error occurred while processing the request: {str(e)}"

        generate_button.click(
            generate_random_prompt_with_llm,
            inputs=[custom, prompt_type, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model],
            outputs=[text_output],
            api_name="generate_random_prompt_with_llm"
        )

    return demo
if __name__ == "__main__":
    # Build the UI and serve it with a public share link.
    create_interface().launch(share=True)