describe-test / app.py
aolko's picture
Update app.py
3cc1e25 verified
raw
history blame
No virus
2.57 kB
import gradio as gr
from PIL import Image
import requests
from diffusers import StableDiffusionPipeline
# Load both diffusers pipelines once, at import time, so every request reuses
# the same in-memory models. The general pipeline is loaded first, then the
# anime one.
# NOTE(review): both checkpoints are loaded as StableDiffusionPipeline, which
# is a text-to-image pipeline; the describe_* helpers below pass an image to
# them. Confirm this is intended or swap in a captioning/tagging model.
general_model, anime_model = (
    StableDiffusionPipeline.from_pretrained(repo_id)
    for repo_id in ("CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion")
)
# Placeholder functions for the actual implementations
def check_anime_image(image):
    """Placeholder anime check that always reports "not anime".

    The real implementation is meant to query SauceNAO (or a similar
    service) to decide whether `image` is anime and to fetch similar images
    and their tags. For now `image` is ignored.

    Returns:
        A 3-tuple ``(is_anime, similar_images, tags)``; currently always
        ``(False, [], [])`` with fresh empty lists on each call.
    """
    is_anime = False
    similar_images, tags = [], []
    return is_anime, similar_images, tags
def describe_image_general(image):
    """Run the module-level ``general_model`` on `image` and return its output.

    The result is passed through untouched; no post-processing happens here.

    NOTE(review): ``general_model`` is created with
    ``StableDiffusionPipeline.from_pretrained``, i.e. a text-to-image
    pipeline, so calling it with an image is unlikely to produce a
    description. Confirm against the diffusers API.
    """
    return general_model(image)
def describe_image_anime(image):
    """Run the module-level ``anime_model`` on `image` and return its output.

    The result is passed through untouched; no post-processing happens here.

    NOTE(review): ``anime_model`` is created with
    ``StableDiffusionPipeline.from_pretrained``, i.e. a text-to-image
    pipeline, so calling it with an image is unlikely to produce tags.
    Confirm against the diffusers API.
    """
    return anime_model(image)
def merge_tags(tags1, tags2):
    """Merge two tag collections into a single list without duplicates.

    Args:
        tags1: Iterable of hashable tags; these come first in the result.
        tags2: Iterable of hashable tags; only tags not already seen in
            `tags1` are appended.

    Returns:
        A new list holding each distinct tag exactly once, in first-seen
        order.
    """
    # dict.fromkeys dedupes while preserving insertion order, so the merged
    # output is stable between runs (a plain set() gives arbitrary ordering).
    # Unpacking instead of `+` lets callers mix lists, tuples and generators.
    return list(dict.fromkeys([*tags1, *tags2]))
# Gradio app functions
def process_image(image, mode):
    """Describe `image` with the model selected by `mode`.

    Args:
        image: A PIL image, or None when nothing was uploaded.
        mode: "Anime" selects the anime path; any other value (including
            None) selects the general path.

    Returns:
        A pair ``(tags, original_tags)``. ``tags`` is the selected model's
        output, or a one-element list holding a message when the image is
        missing or is not recognised as anime. ``original_tags`` are the tags
        reported by ``check_anime_image`` on the anime path and an empty list
        otherwise.
    """
    # Without this guard a click with no uploaded image fails on `.resize`.
    if image is None:
        return ["No image provided"], []

    # Resize to a fixed 256x256 input for the models.
    resized = image.resize((256, 256))

    if mode != "Anime":
        return describe_image_general(resized), []

    is_anime, _similar_images, original_tags = check_anime_image(resized)
    if not is_anime:
        return ["Not an anime image"], []
    return describe_image_anime(resized), original_tags
def describe(image, mode):
    """Gradio callback for the "Describe" button.

    Delegates to ``process_image`` and turns each returned tag list into
    newline-separated text.

    Returns:
        Two ``gr.update`` objects: the first carries the described tags, the
        second the original tags.
    """
    described, originals = process_image(image, mode)
    described_text = "\n".join(described)
    originals_text = "\n".join(originals)
    return gr.update(value=described_text), gr.update(value=originals_text)
def merge(tags, original_tags):
    """Gradio callback for the "Merge Tags" button.

    Args:
        tags: Newline-separated described tags (text of the first box).
        original_tags: Newline-separated original tags (text of the second
            box).

    Returns:
        The merged, de-duplicated tags as newline-separated text; an empty
        string when both inputs hold no tags.
    """

    def _tag_lines(text):
        # An empty box would otherwise split into [""] and inject an empty
        # tag (a stray blank line) into the merged output. `text or ""` also
        # tolerates a None value - presumably possible for an untouched box;
        # verify against the Gradio version in use.
        return [line.strip() for line in (text or "").splitlines() if line.strip()]

    return "\n".join(merge_tags(_tag_lines(tags), _tag_lines(original_tags)))
# Gradio interface
# Build the UI: inputs and action buttons on top, tag boxes below.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation - confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        # `tool="editor"` is intentionally not passed: Gradio 4 rejects the
        # argument, and in Gradio 3 the editor is already the default tool
        # for uploaded images.
        image_input = gr.Image(type="pil", label="Upload/Paste Image")
        mode = gr.Dropdown(choices=["Anime", "General"], label="Mode")
        describe_button = gr.Button("Describe")
        merge_button = gr.Button("Merge Tags")

    # gr.Tabs is Gradio's tab container; there is no `gr.TabGroup`, so the
    # previous code failed with AttributeError before the app could start.
    with gr.Tabs() as tab_group:
        with gr.TabItem("Described Tags"):
            described_tags = gr.TextArea(label="Described Tags")
        with gr.TabItem("Original Tags"):
            original_tags = gr.TextArea(label="Original Tags")

    merged_tags = gr.TextArea(label="Merged Tags")

    # "Describe" fills both tag boxes; "Merge Tags" combines them.
    describe_button.click(
        describe,
        inputs=[image_input, mode],
        outputs=[described_tags, original_tags],
    )
    merge_button.click(
        merge,
        inputs=[described_tags, original_tags],
        outputs=merged_tags,
    )

demo.launch()