File size: 2,569 Bytes
9d051b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc1e25
9d051b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from PIL import Image
import requests
from diffusers import StableDiffusionPipeline

# Hugging Face Hub repository ids of the two pipelines loaded below.
GENERAL_MODEL_ID = "CompVis/stable-diffusion-v1-4"
ANIME_MODEL_ID = "hakurei/waifu-diffusion"

# Both pipelines are loaded eagerly when this module is imported.
# NOTE(review): StableDiffusionPipeline is a text-to-image model (it takes a
# text prompt and produces images), yet the describe_* helpers feed it an
# image and expect tags back. An image-captioning/tagging model is probably
# what is intended here; confirm before relying on these objects.
general_model = StableDiffusionPipeline.from_pretrained(GENERAL_MODEL_ID)
anime_model = StableDiffusionPipeline.from_pretrained(ANIME_MODEL_ID)

# Placeholder functions for the actual implementations
def check_anime_image(image):
    """Decide whether *image* is anime artwork and gather related data.

    Placeholder: the intended implementation would query SauceNAO or a
    similar reverse-image service for matches and their tags. For now
    *image* is ignored.

    Returns:
        A 3-tuple ``(is_anime, similar_images, tags)``. It is currently
        always ``(False, [], [])``.
    """
    # TODO: look the image up on SauceNAO (or similar) and fill these in.
    is_anime = False
    similar_images = []
    tags = []
    return is_anime, similar_images, tags

def describe_image_general(image):
    """Run the module-level ``general_model`` on *image* and return its output.

    The result is returned unchanged; callers treat it as a sequence of tags.

    NOTE(review): per the diffusers docs, ``StableDiffusionPipeline`` is
    text-to-image and expects a text ``prompt`` as its first argument.
    Passing a PIL image here will not yield a description. Confirm which
    captioning/tagging model was intended.
    """
    return general_model(image)

def describe_image_anime(image):
    """Run the module-level ``anime_model`` on *image* and return its output.

    The result is returned unchanged; callers treat it as a sequence of tags.

    NOTE(review): per the diffusers docs, ``StableDiffusionPipeline`` is
    text-to-image and expects a text ``prompt`` as its first argument.
    Passing a PIL image here will not yield a description. Confirm which
    anime tagging model was intended.
    """
    return anime_model(image)

def merge_tags(tags1, tags2):
    """Merge two tag sequences into one list without duplicates.

    The result order is deterministic. Tags from *tags1* come first in
    their original order, followed by any tags from *tags2* not already
    seen. A plain ``set`` would return an arbitrary order that can change
    between runs because of string hash randomization.

    Args:
        tags1: First iterable of hashable tags.
        tags2: Second iterable of hashable tags.

    Returns:
        A new list containing each distinct tag exactly once.
    """
    # dict keeps insertion order (Python 3.7+), so fromkeys is an ordered de-dup.
    return list(dict.fromkeys([*tags1, *tags2]))

# Gradio app functions
def process_image(image, mode):
    """Produce described tags and source tags for *image*.

    Args:
        image: PIL image from the Gradio ``Image`` component, or ``None``
            when nothing has been uploaded.
        mode: ``"Anime"`` selects the anime path. Any other value
            (including ``None``) uses the general model.

    Returns:
        A 2-tuple ``(tags, original_tags)``. ``original_tags`` is non-empty
        only in "Anime" mode when :func:`check_anime_image` reports an
        anime image; in every other case it is ``[]``.
    """
    # The Gradio Image component yields None when empty; report that instead
    # of raising AttributeError on .resize below.
    if image is None:
        return ["No image provided"], []

    # Downscale before inference.
    # NOTE(review): a fixed 256x256 resize ignores the aspect ratio. Confirm
    # this is what the models expect.
    image = image.resize((256, 256))

    if mode != "Anime":
        return describe_image_general(image), []

    is_anime, _similar_images, original_tags = check_anime_image(image)
    if not is_anime:
        return ["Not an anime image"], []
    return describe_image_anime(image), original_tags

def _as_lines(tags):
    """Render *tags* as newline-separated text for a Gradio text area.

    ``None`` becomes ``""``. A bare string is returned as-is, because
    joining a string would put a newline between every character. Any
    other iterable has each item converted with ``str`` before joining.
    """
    if tags is None:
        return ""
    if isinstance(tags, str):
        return tags
    return "\n".join(str(tag) for tag in tags)


def describe(image, mode):
    """Gradio callback for the "Describe" button.

    Args:
        image: The uploaded image (or ``None``), forwarded to ``process_image``.
        mode: The selected dropdown value, forwarded to ``process_image``.

    Returns:
        Two ``gr.update`` objects carrying newline-separated text for the
        described-tags and original-tags outputs, in that order.
    """
    tags, original_tags = process_image(image, mode)
    return gr.update(value=_as_lines(tags)), gr.update(value=_as_lines(original_tags))

def merge(tags, original_tags):
    """Gradio callback for the "Merge Tags" button.

    Args:
        tags: Newline-separated text from the described-tags area.
        original_tags: Newline-separated text from the original-tags area.

    Returns:
        The de-duplicated union of both tag lists, one tag per line.
        Blank lines and surrounding whitespace are ignored.
    """
    def _split(text):
        # splitlines() also handles "\r\n". Stripping and dropping blanks stops
        # an empty text area from contributing an empty "" tag to the merge.
        return [line.strip() for line in (text or "").splitlines() if line.strip()]

    merged_tags = merge_tags(_split(tags), _split(original_tags))
    return "\n".join(merged_tags)

# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        # type="pil" hands callbacks a PIL.Image, which process_image resizes.
        # The former tool="editor" argument is dropped: "editor" is already the
        # Gradio 3 default for uploads, and Gradio 4 no longer accepts `tool`.
        image_input = gr.Image(type="pil", label="Upload/Paste Image")
        # Default to "General" so the callback never receives None as the mode.
        mode = gr.Dropdown(choices=["Anime", "General"], value="General", label="Mode")

    describe_button = gr.Button("Describe")
    merge_button = gr.Button("Merge Tags")

    # gr.Tabs is Gradio's tab container (there is no gr.TabGroup).
    with gr.Tabs():
        with gr.TabItem("Described Tags"):
            described_tags = gr.TextArea(label="Described Tags")
        with gr.TabItem("Original Tags"):
            original_tags = gr.TextArea(label="Original Tags")

    merged_tags = gr.TextArea(label="Merged Tags")

    describe_button.click(
        describe,
        inputs=[image_input, mode],
        outputs=[described_tags, original_tags],
    )
    merge_button.click(
        merge,
        inputs=[described_tags, original_tags],
        outputs=merged_tags,
    )

demo.launch()