import gradio as gr
from PIL import Image
import requests
from diffusers import StableDiffusionPipeline

# Load models using diffusers.
# NOTE(review): StableDiffusionPipeline is a text-to-image model — its __call__
# takes a text prompt and generates images. It cannot turn an uploaded picture
# into tags/captions, so the describe_image_* placeholders below hand it an
# image where it expects a prompt. TODO: swap in an image-to-text / tagger model.
general_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
anime_model = StableDiffusionPipeline.from_pretrained("hakurei/waifu-diffusion")

# (width, height) in pixels that uploaded images are resized to before tagging.
IMAGE_SIZE = (256, 256)


# Placeholder functions for the actual implementations
def check_anime_image(image):
    """Decide whether *image* is anime art and look up its source (placeholder).

    Intended to query SauceNAO or a similar reverse-image service and fetch
    similar images and their tags. Currently always reports "not anime".

    Returns:
        A tuple ``(is_anime, similar_images, original_tags)``.
    """
    return False, [], []


def describe_image_general(image):
    """Describe *image* with the general-purpose model (placeholder).

    See the NOTE(review) on the model loading above: the loaded pipeline is
    text-to-image, so this call does not yet produce tags.
    """
    description = general_model(image)
    return description


def describe_image_anime(image):
    """Describe *image* with the anime-specific model (placeholder).

    Same caveat as ``describe_image_general``: the loaded pipeline is
    text-to-image and does not yet produce tags.
    """
    description = anime_model(image)
    return description


def merge_tags(tags1, tags2):
    """Merge two tag lists, removing duplicates and blank entries.

    Tags are whitespace-stripped. The result keeps first-seen order: tags from
    *tags1* first, then any new tags from *tags2*. A plain ``set`` would give a
    different order from run to run.
    """
    stripped = (tag.strip() for tag in [*tags1, *tags2])
    return list(dict.fromkeys(tag for tag in stripped if tag))


def _to_text(value):
    """Render a tag/description result as newline-separated TextArea text.

    ``None`` becomes ``""``. A string is returned unchanged, because joining a
    string would insert a newline between every character. Any other iterable
    is joined one item per line.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    return "\n".join(str(item) for item in value)


# Gradio app functions
def process_image(image, mode):
    """Tag *image* using the model selected by *mode*.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded yet.
        mode: ``"Anime"`` to use the anime path; any other value (including
            no selection) uses the general model.

    Returns:
        ``(tags, original_tags)``: the model's description, plus the tags
        reported by the reverse-image lookup in Anime mode (empty otherwise).
    """
    if image is None:
        # Describe was clicked before an image was uploaded.
        return ["No image provided"], []

    # Convert the image to a format suitable for the models
    image = image.resize(IMAGE_SIZE)

    if mode != "Anime":
        return describe_image_general(image), []

    is_anime, _similar_images, original_tags = check_anime_image(image)
    if not is_anime:
        return ["Not an anime image"], []
    return describe_image_anime(image), original_tags


def describe(image, mode):
    """Click handler for "Describe": fill the described/original tag boxes."""
    tags, original_tags = process_image(image, mode)
    return _to_text(tags), _to_text(original_tags)


def merge(tags, original_tags):
    """Click handler for "Merge Tags": merge two newline-separated tag lists.

    ``splitlines`` is used so that empty boxes contribute no blank tag and
    ``\\r\\n`` line endings from pasted text are handled.
    """
    merged_tags = merge_tags((tags or "").splitlines(), (original_tags or "").splitlines())
    return "\n".join(merged_tags)


# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        # tool="editor" was dropped: it is Gradio 3's default for uploads, and
        # the `tool` parameter no longer exists in Gradio 4.
        image_input = gr.Image(type="pil", label="Upload/Paste Image")
        mode = gr.Dropdown(choices=["Anime", "General"], value="General", label="Mode")
    describe_button = gr.Button("Describe")
    merge_button = gr.Button("Merge Tags")
    # gr.TabGroup does not exist in Gradio; gr.Tabs is the tab container.
    with gr.Tabs() as tab_group:
        with gr.TabItem("Described Tags"):
            described_tags = gr.TextArea(label="Described Tags")
        with gr.TabItem("Original Tags"):
            original_tags = gr.TextArea(label="Original Tags")
    merged_tags = gr.TextArea(label="Merged Tags")

    describe_button.click(
        describe,
        inputs=[image_input, mode],
        outputs=[described_tags, original_tags],
    )
    merge_button.click(
        merge,
        inputs=[described_tags, original_tags],
        outputs=merged_tags,
    )

if __name__ == "__main__":
    demo.launch()