import gradio as gr
import torch
import numpy as np  # FIX: preprocess_image uses np.* but numpy was never imported
from transformers import AutoProcessor, AutoModelForCausalLM  # FIX: was "AutoModelForCasualLM" (typo) while the code below calls AutoModelForCausalLM
from diffusers import DiffusionPipeline
import requests
from PIL import Image
from io import BytesIO
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# Initialize models:
# - anime_model: SmilingWolf WD ConvNeXt tagger (ONNX), used for "Anime" images
# - photo_model/processor: Florence-2 captioner, used for "Photo/Other" images
anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
anime_model = ort.InferenceSession(anime_model_path)
photo_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

# Load labels for the anime model.
# selected_tags.csv: header row, then one tag per line; the tag name is the
# first CSV column.
labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
with open(labels_path, 'r') as f:
    anime_labels = [line.strip().split(',')[0] for line in f.readlines()[1:]]  # Skip header


def preprocess_image(image):
    """Convert a PIL image to the float32 batch tensor fed to the ONNX tagger.

    Steps: force RGB, resize to 448x448, swap to BGR, transpose HWC->CHW,
    scale to [0, 1], and add a leading batch axis.

    NOTE(review): WD taggers are commonly fed NHWC input without the CHW
    transpose — confirm this layout/scaling matches the v3 model's input
    signature before relying on the tag probabilities.
    """
    image = image.convert('RGB')
    image = image.resize((448, 448), Image.LANCZOS)
    image = np.array(image).astype(np.float32)
    image = image[:, :, ::-1]  # RGB -> BGR
    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    image = image / 255.0
    return image[np.newaxis, ...]  # add batch dimension
def transcribe_image(image, image_type, transcriber, booru_tags=None):
    """Transcribe *image* into a tag string or caption.

    "Anime" images go through the WD ONNX tagger and return the top-50 tags
    joined with ", ".  Anything else is captioned by Florence-2 and the
    caption is returned as-is.  *transcriber* and *booru_tags* are accepted
    for compatibility with the UI callbacks but are not currently used.
    """
    if image_type == "Anime":
        input_image = preprocess_image(image)
        input_name = anime_model.get_inputs()[0].name
        output_name = anime_model.get_outputs()[0].name
        probs = anime_model.run([output_name], {input_name: input_image})[0]
        # Top 50 tags, most confident first
        top_indices = probs[0].argsort()[-50:][::-1]
        tags = [anime_labels[i] for i in top_indices]
        return ", ".join(tags)
    else:
        prompt = ""
        inputs = processor(images=image, text=prompt, return_tensors="pt")
        generated_ids = photo_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # BUG FIX: the original assigned the caption string to `tags` and then
        # did ", ".join(tags), which joined it CHARACTER by character
        # ("a, , c, a, t, ..."). Return the caption directly instead.
        return generated_text


def get_booru_image(booru, image_id):
    """Fetch a post's image and tag list from a booru by numeric post ID.

    Returns (PIL.Image, list[str]).  Raises ValueError for unknown boorus and
    requests.HTTPError for failed HTTP calls.
    """
    if booru == "Gelbooru":
        url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    elif booru == "Danbooru":
        url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    elif booru == "rule34.xxx":
        url = f"https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    else:
        raise ValueError("Unsupported booru")
    response = requests.get(url)
    response.raise_for_status()  # fail loudly instead of crashing on bad JSON
    data = response.json()
    # Normalize the per-booru response shapes to a single post dict:
    # - Gelbooru wraps results as {"post": [...]}  (the original indexed the
    #   outer dict directly, which raised KeyError)
    # - rule34.xxx returns a bare list of posts
    # - Danbooru returns a single post object
    if isinstance(data, dict) and 'post' in data:
        data = data['post']
    post = data[0] if isinstance(data, list) else data
    image_url = post['file_url']
    # Danbooru calls the tag field 'tag_string'; Gelbooru/rule34 use 'tags'.
    tags = (post.get('tags') or post.get('tag_string') or '').split()
    img_response = requests.get(image_url)
    img_response.raise_for_status()
    img = Image.open(BytesIO(img_response.content))
    return img, tags


def update_image(image_type, booru, image_id, uploaded_image):
    """Resolve the image to transcribe from either a booru fetch or an upload.

    Returns (image, transcribe-button visibility update, booru tags or None).
    """
    if image_type == "Anime" and booru != "Upload":
        image, booru_tags = get_booru_image(booru, image_id)
        return image, gr.update(visible=True), booru_tags
    elif uploaded_image is not None:
        return uploaded_image, gr.update(visible=True), None
    else:
        # Nothing to show yet: clear the image and hide the transcribe button.
        return None, gr.update(visible=False), None


def on_image_type_change(image_type):
    """Show/hide the booru controls and reorder the transcriber choices so the
    matching transcriber is listed first."""
    if image_type == "Anime":
        return gr.update(visible=True), gr.update(visible=True), gr.update(choices=["Anime", "Photo/Other"])
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(choices=["Photo/Other", "Anime"])


with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        with gr.Column(visible=False) as anime_options:
            # CONSISTENCY FIX: get_booru_image() already supports rule34.xxx,
            # but the dropdown never offered it.
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "rule34.xxx", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        image_display = gr.Image(label="Image to transcribe", visible=False)
        booru_tags = gr.State(None)  # tags fetched alongside a booru image
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])
    get_image_btn.click(update_image,
                        inputs=[image_type, booru, image_id, upload_btn],
                        outputs=[image_display, transcribe_btn, booru_tags])
    upload_btn.upload(update_image,
                      inputs=[image_type, booru, image_id, upload_btn],
                      outputs=[image_display, transcribe_btn, booru_tags])

    def transcribe_and_update(image, image_type, transcriber, booru_tags):
        """Transcribe and mirror the image into the Step 2 tab."""
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    transcribe_btn.click(transcribe_and_update,
                         inputs=[image_display, image_type, transcriber, booru_tags],
                         outputs=[transcribe_image_display, tags_output])
    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])
    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

app.launch()