"""Gradio demo for TITPOP: prompt extension/refinement plus an SDXL preview.

The script loads the TITPOP text model and an SDXL pipeline at import time,
exposes two GPU-decorated generators (``generate`` for prompts and
``generate_image`` for preview images) and, when run as a script, builds and
launches the Gradio UI.
"""

import os
import random
import re
import subprocess
import sys
from time import time

import gradio as gr

## if kgen not exist
try:
    import kgen
except ImportError:
    GH_TOKEN = os.getenv("GITHUB_TOKEN")
    # Only embed credentials when a token is actually configured; formatting
    # ``None`` into the URL would yield "https://None@github.com/...".
    _auth = f"{GH_TOKEN}@" if GH_TOKEN else ""
    git_url = f"https://{_auth}github.com/KohakuBlueleaf/TITPOP-KGen@titpop"

    ## call pip install
    # Argument list (no shell) and the running interpreter's pip, so the
    # package lands in the environment this script is executing in.
    # A failed install is not fatal here: the ``kgen`` imports below will
    # raise a clear ImportError in that case.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"git+{git_url}"],
        check=False,
    )

import torch
from transformers import set_seed

if sys.platform == "win32":
    # dev env in windows, @spaces.GPU will cause problem
    def GPU(func):
        """No-op stand-in for ``spaces.GPU`` used on Windows."""
        return func

else:
    from spaces import GPU

import kgen.models as models
import kgen.executor.titpop as titpop
from kgen.formatter import seperate_tags, apply_format

# Aliased because this module defines its own ``generate`` further down.
from kgen.generate import generate as kgen_generate

from diff import load_model, encode_prompts
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT


sdxl_pipe = load_model()

models.load_model(
    "KBlueLeaf/TITPOP-200M-dev",
    device="cuda",
    subfolder="dan-cc-coyo_epoch2",
)
# Short generation right after loading (presumably a warm-up run).
kgen_generate(max_new_tokens=4)


DEFAULT_TAGS = """
1girl, king halo (umamusume), umamusume, ningen mame, ciloranko, ogipote, misu kasumi, solo, leaning forward, sky, masterpiece, absurdres, sensitive, newest
""".strip()
DEFAULT_NL = """
An illustration of a girl
""".strip()

# Matches the round and square brackets that get backslash-escaped on request.
_BRACKET_PATTERN = re.compile(r"([()\[\]])")

# Kept flush-left on purpose so the Markdown source carries no indentation.
INTRO_MARKDOWN = """
## TITPOP Demo

### What is this
TITPOP is a tool to extend, generate, refine the input prompt for T2I models.
It can work on both Danbooru tags and Natural Language. Which means you can use it on almost all the existing T2I models.
You can take it as "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2)

### How to use this demo
1. Enter your tags(optional): put the desired tags into "Danbooru Tags" box
2. Enter your NL Prompt(optional): put the desired natural language prompt into "Natural Language Prompt" box
3. Enter your black list(optional): put the desired black list into "black list" box
4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ...
5. Click "TITPOP" button: you will see refined prompt on "result" box
6. If you like the result, click "Generate Image From Result" button
    * You will see 2 generated images, left one is based on your prompt, right one is based on refined prompt
    * The backend is diffusers, there are no weighting mechanism, so Escape Brackets is default to False

### Why inference code is private? When will it be open sourced?
1. This model/tool is still under development, currently is early Alpha version.
2. I'm doing some research and projects based on this.
3. The model is released under CC-BY-NC-ND License currently. If you have interest, you can implement inference by yourself.
4. Once the project/research are done, I will open source all these models/codes with Apache2 license.

### Notification
**TITPOP is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
The generated image is come from [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
"""


def _rate(count, seconds):
    """Return ``count / seconds``, or 0.0 when ``seconds`` is not positive."""
    return count / seconds if seconds > 0 else 0.0


def _escape_brackets(text):
    """Prefix every ``(``, ``)``, ``[`` and ``]`` in *text* with a backslash."""
    return _BRACKET_PATTERN.sub(r"\\\1", text)


def format_time(timing):
    """Render a timing dict as a Markdown report.

    Args:
        timing: Mapping that must contain ``"total"`` (seconds) and
            ``"generate_pass"``. When ``"generated_tokens"`` is present,
            ``"input_tokens"`` is read as well. When ``"total_sampling"`` is
            also present, ``"prompt_process"`` and ``"total_eval"`` are read;
            these three are divided by 1000 and reported as seconds.

    Returns:
        A Markdown string with a process-time table and, when token counts
        are available, a processed-token summary. Rates whose duration is
        zero are reported as 0.00 instead of raising ``ZeroDivisionError``.
    """
    total = timing["total"]
    generate_pass = timing["generate_pass"]
    has_tokens = "generated_tokens" in timing

    result = f"""
### Process Time
| Total | {total:5.2f} sec / {generate_pass:5} Passes | {_rate(generate_pass, total):7.2f} Passes Per Second|
|-|-|-|
"""
    if has_tokens:
        total_generated_tokens = timing["generated_tokens"]
        total_input_tokens = timing["input_tokens"]

    if has_tokens and "total_sampling" in timing:
        sampling_time = timing["total_sampling"] / 1000
        process_time = timing["prompt_process"] / 1000
        model_time = timing["total_eval"] / 1000
        result += f"""| Process | {process_time:5.2f} sec / {total_input_tokens:5} Tokens | {_rate(total_input_tokens, process_time):7.2f} Tokens Per Second|
| Sampling | {sampling_time:5.2f} sec / {total_generated_tokens:5} Tokens | {_rate(total_generated_tokens, sampling_time):7.2f} Tokens Per Second|
| Eval | {model_time:5.2f} sec / {total_generated_tokens:5} Tokens | {_rate(total_generated_tokens, model_time):7.2f} Tokens Per Second|
"""

    if has_tokens:
        result += f"""
### Processed Tokens:
* {total_input_tokens:} Input Tokens
* {total_generated_tokens:} Output Tokens
"""
    return result


@GPU
@torch.no_grad()
def generate(
    tags,
    nl_prompt,
    black_list,
    temp,
    output_format,
    target_length,
    top_p,
    min_p,
    top_k,
    seed,
    escape_brackets,
):
    """Stream refined prompts produced by the TITPOP runner.

    Args:
        tags: Comma-separated tag string.
        nl_prompt: Natural-language prompt; may be empty.
        black_list: Comma-separated entries assigned to ``titpop.BAN_TAGS``
            (blank entries are dropped).
        temp: Sampling temperature.
        output_format: Key into ``DEFAULT_FORMAT`` selecting the template.
        target_length: Forwarded as ``tag_length_target``.
        top_p: Forwarded to the runner.
        min_p: Forwarded to the runner.
        top_k: Forwarded to the runner.
        seed: Forwarded to the runner.
        escape_brackets: When truthy, backslash-escape ``()[]`` in both the
            formatted input prompt and every result.

    Yields:
        ``(result, input_prompt, timing_markdown)`` for each item the runner
        produces.
    """
    default_format = DEFAULT_FORMAT[output_format]
    # NOTE: module-level state on ``titpop``; shared by every request.
    titpop.BAN_TAGS = [t.strip() for t in black_list.split(",") if t.strip()]
    generation_setting = {
        "seed": seed,
        "temperature": temp,
        "top_p": top_p,
        "min_p": min_p,
        "top_k": top_k,
    }

    # Formatted view of what the user typed; used later as the "original"
    # prompt for the comparison image.
    inputs = seperate_tags(tags.split(","))
    if nl_prompt:
        if "<|extended|>" in default_format:
            inputs["extended"] = nl_prompt
        elif "<|generated|>" in default_format:
            inputs["generated"] = nl_prompt
    input_prompt = apply_format(inputs, default_format)
    if escape_brackets:
        input_prompt = _escape_brackets(input_prompt)

    # Tags are parsed a second time so the request is built from a fresh
    # structure rather than the ``inputs`` dict modified above.
    meta, operations, general, nl_prompt = titpop.parse_titpop_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target=target_length,
        generate_extra_nl_prompt="<|generated|>" in default_format or not nl_prompt,
    )

    t0 = time()
    for result, timing in titpop.titpop_runner_generator(
        meta, operations, general, nl_prompt, **generation_setting
    ):
        result = apply_format(result, default_format)
        if escape_brackets:
            result = _escape_brackets(result)
        timing["total"] = time() - t0
        yield result, input_prompt, format_time(timing)


def _render_image(prompt, seed):
    """Encode *prompt* and run the SDXL pipeline once, returning the image.

    The RNG is seeded immediately before sampling so that every call with
    the same seed starts from the same random state.
    """
    prompt_embeds, negative_prompt_embeds, pooled_embeds, neg_pooled_embeds = (
        encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
    )
    set_seed(seed)
    return sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds,
        negative_pooled_prompt_embeds=neg_pooled_embeds,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]


@GPU
@torch.no_grad()
def generate_image(
    seed,
    prompt,
    prompt2,
):
    """Render comparison images for two prompts with the same seed.

    Args:
        seed: RNG seed; cast to ``int`` because a numeric UI field may
            deliver a float.
        prompt: Prompt for the second image.
        prompt2: Prompt for the first image.

    Yields:
        ``(image_for_prompt2, None)`` as soon as the first image is ready,
        then ``(image_for_prompt2, image_for_prompt)``.
    """
    seed = int(seed)

    torch.cuda.empty_cache()
    result2 = _render_image(prompt2, seed)
    yield result2, None

    result = _render_image(prompt, seed)
    torch.cuda.empty_cache()
    yield result2, result


if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(INTRO_MARKDOWN)

        with gr.Row():
            # Left half: inputs and sampling settings.
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=3):
                        tags_input = gr.TextArea(
                            label="Danbooru Tags",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_TAGS,
                            placeholder="Enter danbooru tags here",
                        )
                        nl_prompt_input = gr.Textbox(
                            label="Natural Language Prompt",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_NL,
                            placeholder="Enter Natural Language Prompt here",
                        )
                        black_list = gr.TextArea(
                            label="Black List (separated by comma)",
                            lines=4,
                            interactive=True,
                            value="monochrome",
                            placeholder="Enter tag/nl black list here",
                        )
                    with gr.Column(scale=2):
                        output_format = gr.Dropdown(
                            label="Output Format",
                            choices=list(DEFAULT_FORMAT.keys()),
                            value="Both, tag first (recommend)",
                        )
                        target_length = gr.Dropdown(
                            label="Target Length",
                            choices=["very_short", "short", "long", "very_long"],
                            value="long",
                        )
                        temp = gr.Slider(
                            label="Temp",
                            minimum=0.0,
                            maximum=1.5,
                            value=0.5,
                            step=0.05,
                        )
                        top_p = gr.Slider(
                            label="Top P",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                        )
                        min_p = gr.Slider(
                            label="Min P",
                            minimum=0.0,
                            maximum=0.2,
                            value=0.05,
                            step=0.01,
                        )
                        top_k = gr.Slider(
                            label="Top K", minimum=0, maximum=120, value=60, step=1
                        )
                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        value=20090220,
                        step=1,
                    )
                    escape_brackets = gr.Checkbox(
                        label="Escape Brackets", value=False
                    )
                submit = gr.Button("TITPOP!", variant="primary")
                with gr.Accordion("Speed statistics", open=False):
                    cost_time = gr.Markdown()

            # Right half: refined prompt and the image comparison.
            with gr.Column(scale=5):
                result = gr.TextArea(
                    label="Result", lines=8, show_copy_button=True, interactive=False
                )
                input_prompt = gr.Textbox(
                    label="Input Prompt", lines=1, interactive=False, visible=False
                )
                gen_img = gr.Button(
                    "Generate Image from Result",
                    variant="primary",
                    interactive=False,
                )
                with gr.Row():
                    with gr.Column():
                        img1 = gr.Image(label="Original Prompt", interactive=False)
                    with gr.Column():
                        img2 = gr.Image(label="Generated Prompt", interactive=False)

        def generate_wrapper(*args):
            """Stream ``generate`` output; keep ``gen_img`` disabled until done."""
            yield "", "", "", gr.update(interactive=False)
            # Default keeps the final yield valid even if nothing is produced.
            last = ("", "", "")
            for last in generate(*args):
                yield *last, gr.update(interactive=False)
            yield *last, gr.update(interactive=True)

        submit.click(
            generate_wrapper,
            [
                tags_input,
                nl_prompt_input,
                black_list,
                temp,
                output_format,
                target_length,
                top_p,
                min_p,
                top_k,
                seed,
                escape_brackets,
            ],
            [
                result,
                input_prompt,
                cost_time,
                gen_img,
            ],
            queue=True,
        )

        def generate_image_wrapper(seed, result, input_prompt):
            """Stream ``generate_image`` output; re-enable ``submit`` when done."""
            # Distinct local names avoid shadowing the img1/img2 components,
            # and the defaults keep the final yield valid in every case.
            first_image, second_image = None, None
            for first_image, second_image in generate_image(
                seed, result, input_prompt
            ):
                yield first_image, second_image, gr.update(interactive=False)
            yield first_image, second_image, gr.update(interactive=True)

        gen_img.click(
            generate_image_wrapper,
            [seed, result, input_prompt],
            [img1, img2, submit],
            queue=True,
        )
        # Unqueued so the submit button is disabled as soon as image
        # generation is requested.
        gen_img.click(
            lambda *args: gr.update(interactive=False),
            None,
            [submit],
            queue=False,
        )

    demo.launch()