import json
from functools import partial

# NOTE(review): the star import presumably supplies `gr`, `enable_buttons`,
# `disable_buttons`, `clear_history`, `get_model_description_md` and
# `acknowledgment_md`, none of which are imported explicitly below — confirm.
from .utils import *
from .vote_utils import (
    upvote_last_response_t2s as upvote_last_response,
    downvote_last_response_t2s as downvote_last_response,
    flag_last_response_t2s as flag_last_response,
)
from .inference import (
    sample_prompt,
    generate_t2s,
)
from constants import TEXT_PROMPT_PATH

# Load the prompt pool once at import time; it is handed to the sampling and
# generation callbacks wired up in build_single_model_ui. JSON text is UTF-8
# by specification, so decode explicitly instead of using the locale default.
with open(TEXT_PROMPT_PATH, "r", encoding="utf-8") as f:
    prompt_list = json.load(f)


def build_single_model_ui(models, promotion=""):
    """Build the single-model tab of the arena UI and wire its events.

    Expected to be called inside an active ``gr.Blocks`` context, since it
    instantiates components and layout rows directly. It creates a model
    dropdown, a prompt box with Sample/Send buttons, "Normal" and "RGB"
    image panes, Clear/Regenerate buttons and three rows of vote buttons.

    Args:
        models: Model registry. The code here uses its ``get_t2s_models()``,
            ``inference_parallel`` and ``render_parallel`` members.
        promotion: Markdown substituted for the ``{promotion}`` placeholder
            in the page header. Defaults to an empty string.
    """
    notice_markdown = """
# 🏔️ Play with Image Generation Models
{promotion}

## 🤖 Choose any model to generate

"""

    model_list = models.get_t2s_models()
    gen_func = partial(generate_t2s, models.inference_parallel, models.render_parallel)

    # The template was previously never formatted, so "{promotion}" leaked
    # into the rendered page verbatim.
    gr.Markdown(notice_markdown.format(promotion=promotion), elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=model_list,
            value=model_list[0] if model_list else "",
            interactive=True,
            show_label=False,
        )

    with gr.Row():
        with gr.Accordion("🔍 Expand to see all Arena players", open=False):
            model_description_md = get_model_description_md(model_list)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
            container=True,
            elem_id="input_box",
        )
        sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
        send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)

    with gr.Row():
        normal = gr.Image(width=512, label="Normal", show_copy_button=True)
        rgb = gr.Image(width=512, label="RGB", show_copy_button=True)

    # Both start disabled; they are enabled after the first generation.
    with gr.Row():
        clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)

    with gr.Row(elem_id="Geometry Quality"):
        geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)

    with gr.Row(elem_id="Texture Quality"):
        text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)

    with gr.Row(elem_id="Alignment Quality"):
        align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    state = gr.State()
    # Gradio event inputs must be components, not plain Python objects, so the
    # prompt list is wrapped in a State; callbacks still receive the list itself.
    prompt_state = gr.State(prompt_list)

    geo_btn_list = [geo_upvote_btn, geo_downvote_btn, geo_flag_btn]
    text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
    align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
    # Every button that is enabled after generating and disabled after clearing.
    toggled_btns = geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]

    # NOTE(review): all three vote rows call the same handlers with the same
    # inputs, so the handler cannot tell a geometry vote from a texture or
    # alignment vote — confirm whether vote_utils is meant to distinguish them.
    for btn_list in (geo_btn_list, text_btn_list, align_btn_list):
        upvote_btn, downvote_btn, flag_btn = btn_list
        upvote_btn.click(upvote_last_response, [state, model_selector], [textbox] + btn_list)
        downvote_btn.click(downvote_last_response, [state, model_selector], [textbox] + btn_list)
        flag_btn.click(flag_last_response, [state, model_selector], [textbox] + btn_list)

    # The outputs were previously written as `state + [textbox]`, which raises
    # TypeError (a State cannot be added to a list) while the UI is built.
    sample_btn.click(
        sample_prompt,
        [state, model_selector, prompt_state],
        [state, textbox],
        api_name="sample_btn_single",
    )

    gen_inputs = [state, textbox, model_selector, prompt_state]
    gen_outputs = [state, normal, rgb]

    def _wire_generate(event, api_name):
        # Register a generation run on `event`, then enable the toggled buttons.
        event(
            gen_func,
            gen_inputs,
            gen_outputs,
            api_name=api_name,
            show_progress="full",
        ).then(enable_buttons, None, toggled_btns)

    # Registration order is kept identical to the original (it fixes the
    # dependency indices Gradio assigns).
    _wire_generate(textbox.submit, "submit_btn_single")
    _wire_generate(send_btn.click, "send_btn_single")

    clear_btn.click(
        clear_history,
        None,
        [state, textbox, normal, rgb],
        api_name="clear_history_single",
        show_progress="full",
    ).then(disable_buttons, None, toggled_btns)

    _wire_generate(regenerate_btn.click, "regenerate_btn_single")