import random

import gradio as gr

from llm_inference import LLMInferenceNode

title = """

Random Prompt Generator

[X gokaygokay] [Github gokayfem] [comfyui_dagthomas]

Generate random prompts using powerful LLMs from Hugging Face, Groq, and SambaNova.

"""

# Default prompt type. The per-session selection is tracked with gr.State
# inside create_interface(), so this module-level value is only the initial
# value for each session and is no longer mutated at runtime (a mutable global
# would be shared by every user of the app).
selected_prompt_type = "Long"

# Models offered for each LLM provider; the first entry is the default choice.
# NOTE(review): "another-model-hf" looks like a placeholder model id — confirm.
PROVIDER_MODELS = {
    "Hugging Face": ["meta-llama/Meta-Llama-3.1-70B-Instruct", "another-model-hf"],
    "Groq": ["llama-3.1-70b-versatile", "mixtral-8x7b-32768", "gemma2-9b-it"],
    "SambaNova": [
        "Meta-Llama-3.1-70B-Instruct",
        "Meta-Llama-3.1-405B-Instruct",
        "Meta-Llama-3.1-8B-Instruct",
    ],
}
DEFAULT_PROVIDER = "Hugging Face"


def create_interface():
    """Build the Gradio Blocks UI for the random prompt generator.

    Layout (three columns):
      * prompt options: optional custom input and the prompt type;
      * "Generate Prompt" button and the generated prompt;
      * LLM settings (provider, model, compression, custom base prompt), the
        "Generate Prompt with LLM" button, and the LLM output.

    A single ``LLMInferenceNode`` is created per call and used by all handlers.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    llm_node = LLMInferenceNode()
    default_models = PROVIDER_MODELS[DEFAULT_PROVIDER]

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)

        # Per-session prompt type. gr.State is isolated per browser session and
        # is not exposed as a parameter of the "generate_text" API endpoint.
        prompt_type_state = gr.State(selected_prompt_type)

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("Basic Settings"):
                    custom = gr.Textbox(label="Custom Input Prompt (optional)")
                with gr.Accordion("Prompt Generation Options", open=False):
                    prompt_type = gr.Dropdown(
                        choices=["Long", "Short", "Medium"],
                        label="Prompt Type",
                        value=selected_prompt_type,
                        interactive=True,
                    )

            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Prompt")
                with gr.Accordion("Generated Prompt", open=True):
                    output = gr.Textbox(label="Generated Prompt", lines=4, show_copy_button=True)

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard",
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)

                    # LLM Provider Selection
                    llm_provider = gr.Dropdown(
                        choices=list(PROVIDER_MODELS),
                        label="LLM Provider",
                        value=DEFAULT_PROVIDER,
                    )
                    api_key = gr.Textbox(label="API Key", type="password", visible=False)
                    # Pre-populate with the default provider's models so a model
                    # is selected before the provider is ever changed.
                    model = gr.Dropdown(
                        label="Model",
                        choices=default_models,
                        value=default_models[0],
                    )

                generate_text_button = gr.Button("Generate Prompt with LLM")
                text_output = gr.Textbox(label="Generated Text", lines=10, show_copy_button=True)

        def update_prompt_type(value):
            """Store the newly selected prompt type in this session's state."""
            print(f"Updated prompt type: {value}")
            return value

        # Write to the session state only; writing back into the dropdown that
        # fired the event is unnecessary.
        prompt_type.change(update_prompt_type, inputs=[prompt_type], outputs=[prompt_type_state])

        def update_model_choices(provider):
            """Return a dropdown update listing the models of ``provider``."""
            models = PROVIDER_MODELS.get(provider, [])
            # gr.update works on Gradio 3.x and 4.x; gr.Dropdown.update was
            # removed in Gradio 4.0.
            return gr.update(choices=models, value=models[0] if models else None)

        def update_api_key_visibility(provider):
            """Keep the API key box hidden for every provider."""
            return gr.update(visible=False)  # No API key required for selected providers

        llm_provider.change(update_model_choices, inputs=[llm_provider], outputs=[model])
        llm_provider.change(update_api_key_visibility, inputs=[llm_provider], outputs=[api_key])

        def generate_prompt(prompt_type, custom_input):
            """Generate a prompt with a fresh random seed on every click."""
            dynamic_seed = random.randint(0, 1000000)
            return llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)

        generate_button.click(
            generate_prompt,
            inputs=[prompt_type, custom],
            outputs=[output],
        )

        def generate_text_with_llm(output_prompt, long_talk, compress, compression_level,
                                   custom_base_prompt, provider, api_key, model_selected,
                                   session_prompt_type):
            """Expand ``output_prompt`` with the selected LLM provider/model.

            ``session_prompt_type`` comes from this session's gr.State, so the
            value used always matches what the user selected in the dropdown.
            """
            return llm_node.generate(
                input_text=output_prompt,
                long_talk=long_talk,
                compress=compress,
                compression_level=compression_level,
                poster=False,
                prompt_type=session_prompt_type,
                custom_base_prompt=custom_base_prompt,
                provider=provider,
                api_key=api_key,
                model=model_selected,
            )

        generate_text_button.click(
            generate_text_with_llm,
            inputs=[output, long_talk, compress, compression_level, custom_base_prompt,
                    llm_provider, api_key, model, prompt_type_state],
            outputs=[text_output],
            api_name="generate_text",
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)