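"""Image Transcription App.

A two-step Gradio interface for tagging images: Step 1 fetches an image from
a booru by post ID (or accepts an upload); Step 2 transcribes it, using
SmilingWolf's wd-convnext tagger (run through ONNX Runtime) for anime images
and Microsoft's Florence-2 for photos.
"""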
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
import onnxruntime as ort
from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import hf_hub_download

# Initialize models:
# - anime tagger: SmilingWolf's wd-convnext tagger, run through ONNX Runtime
# - photo model: Microsoft's Florence-2 (requires trust_remote_code)
anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
anime_model = ort.InferenceSession(anime_model_path)
photo_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

# Load labels for the anime model. The columns in selected_tags.csv are
# tag_id,name,category,count, so the tag name is the second field.
labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
with open(labels_path, 'r') as f:
    anime_labels = [line.strip().split(',')[1] for line in f.readlines()[1:]]  # Skip header

def preprocess_image(image):
    # The wd tagger ONNX models take a (1, 448, 448, 3) float32 batch in BGR
    # channel order, with pixel values left in the 0-255 range (no CHW
    # transpose and no division by 255).
    image = image.convert('RGB')
    image = image.resize((448, 448), Image.LANCZOS)
    image = np.array(image).astype(np.float32)
    image = image[:, :, ::-1]  # RGB -> BGR
    return image[np.newaxis, ...]
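# e.g. preprocess_image(Image.open("sample.png")).shape -> (1, 448, 448, 3)
# ("sample.png" is an illustrative filename, not a file shipped with the app)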

def transcribe_image(image, image_type, transcriber, booru_tags=None):
    # Prefer the explicit Step 2 transcriber choice when set; fall back to
    # the image type selected in Step 1.
    mode = transcriber or image_type
    if mode == "Anime":
        input_image = preprocess_image(image)
        input_name = anime_model.get_inputs()[0].name
        output_name = anime_model.get_outputs()[0].name
        probs = anime_model.run([output_name], {input_name: input_image})[0]

        # Keep the 50 highest-scoring tags. A fixed score threshold (commonly
        # around 0.35 for the wd taggers) would be a reasonable alternative.
        top_indices = probs[0].argsort()[-50:][::-1]
        tags = [anime_labels[i] for i in top_indices]
    else:
        # Run Florence-2 object detection and use the detected labels as
        # tags; the prompt must match the task passed to post-processing.
        prompt = "<OD>"
        inputs = processor(text=prompt, images=image, return_tensors="pt")

        generated_ids = photo_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))

        # post_process_generation returns {"<OD>": {"bboxes": [...], "labels": [...]}}
        tags = parsed_answer["<OD>"]["labels"]
    
    if booru_tags:
        tags = list(set(tags + booru_tags))
    
    return ", ".join(tags)


def get_booru_image(booru, image_id):
    if booru == "Gelbooru":
        url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    elif booru == "Danbooru":
        url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    elif booru == "rule34.xxx":
        url = f"https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    else:
        raise ValueError("Unsupported booru")

    response = requests.get(url)
    response.raise_for_status()
    data = response.json()

    # Response shapes differ per booru: Gelbooru wraps posts in a "post" key,
    # rule34.xxx returns a bare list, and Danbooru returns a single post dict
    # whose tags live under "tag_string" rather than "tags".
    if isinstance(data, dict) and "post" in data:
        data = data["post"]
    post = data[0] if isinstance(data, list) else data
    image_url = post["file_url"]
    tags = post.get("tags", post.get("tag_string", "")).split()

    img_response = requests.get(image_url)
    img = Image.open(BytesIO(img_response.content))

    return img, tags
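# Example (illustrative post ID, not guaranteed to exist):
#   img, tags = get_booru_image("Danbooru", "1234")  # -> (PIL.Image, ["tag1", ...])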

def update_image(image_type, booru, image_id, uploaded_image):
    # Returns updates for: image_display, transcribe_btn,
    # transcribe_with_tags_btn, booru_tags. The booru-tag button is only
    # shown when booru tags were actually fetched.
    if image_type == "Anime" and booru != "Upload":
        image, booru_tags = get_booru_image(booru, image_id)
        return (gr.update(value=image, visible=True), gr.update(visible=True),
                gr.update(visible=True), booru_tags)
    elif uploaded_image is not None:
        return (gr.update(value=uploaded_image, visible=True), gr.update(visible=True),
                gr.update(visible=False), None)
    else:
        return (gr.update(visible=False), gr.update(visible=False),
                gr.update(visible=False), None)

def on_image_type_change(image_type):
    # Reveal the booru controls for anime images and default the Step 2
    # transcriber to match the selected image type.
    if image_type == "Anime":
        return gr.update(visible=True), gr.update(visible=True), gr.update(choices=["Anime", "Photo/Other"], value="Anime")
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(choices=["Anime", "Photo/Other"], value="Photo/Other")

with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")
    
    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "rule34.xxx", "Upload"], label="Booru")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        
        image_display = gr.Image(label="Image to transcribe", visible=False, type="pil")
        booru_tags = gr.State(None)
        
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)
    
    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe", type="pil")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")
    
    image_type.change(on_image_type_change, inputs=[image_type], 
                      outputs=[anime_options, upload_btn, transcriber])
    
    get_image_btn.click(update_image,
                        inputs=[image_type, booru, image_id, upload_btn],
                        outputs=[image_display, transcribe_btn, transcribe_with_tags_btn, booru_tags])

    upload_btn.upload(update_image,
                      inputs=[image_type, booru, image_id, upload_btn],
                      outputs=[image_display, transcribe_btn, transcribe_with_tags_btn, booru_tags])
    
    def transcribe_and_update(image, image_type, transcriber, booru_tags=None):
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    # The plain button transcribes the image alone; the second button merges
    # in the tags fetched from the booru.
    transcribe_btn.click(transcribe_and_update,
                         inputs=[image_display, image_type, transcriber],
                         outputs=[transcribe_image_display, tags_output])

    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])
    
    transcribe_btn_final.click(transcribe_image, 
                               inputs=[transcribe_image_display, image_type, transcriber], 
                               outputs=[tags_output])

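# Heavy model calls can block Gradio's event loop; enabling the built-in
# request queue (app.queue().launch()) is one common mitigation.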
app.launch()