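# app.py -- minimal Gradio demo: a prompt textbox plus image upload whose
# "Send" button packs the inputs via add_text(). A Stable Diffusion
# inpainting pipeline is kept further down as a disabled reference block.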
import os
import uuid
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from diffusers import DiffusionPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torch import autocast
from torchvision import transforms

# NOTE: many of these imports are not exercised by the active demo below;
# torch, torchvision, and diffusers are only needed if the disabled
# inpainting pipeline is re-enabled.

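# The Stable Diffusion inpainting pipeline below is kept as a disabled
# reference: it is wrapped in a string literal, so none of it executes.
# To re-enable it, remove the surrounding triple quotes and wire `predict`
# into the UI.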
"""
auth_token = os.environ.get("API_TOKEN") or True

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", dtype=torch.float16, revision="fp16", use_auth_token=auth_token).to(device)

transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
      transforms.Resize((512, 512)),
])


def predict(dict, prompt=""):
    init_image = dict["image"].convert("RGB").resize((512, 512))
    mask = dict["mask"].convert("RGB").resize((512, 512))
    output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)
    return output.images[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

"""

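# add_text packs the user's prompt and optional image for a multimodal model:
# it truncates the text, appends an "<image>" placeholder token when an image
# is attached, and (for now) only logs the packed tuple.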
def add_text(text, image, image_process_mode, request: gr.Request):
    text = text[:1536]  # Hard cut-off on prompt length
    if image is not None:
        print(image)
        text = text[:1200]  # Tighter cut-off when an image is attached
        if "<image>" not in text:
            # text = '<Image><image></Image>' + text
            text = text + "\n<image>"
        # Pack prompt, image, and preprocessing mode; nothing is returned to
        # the Gradio outputs, so this is only logged.
        text = (text, image, image_process_mode)
        print(text)


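# Minimal UI: a hidden preprocessing selector, a prompt textbox, an image
# upload, and a Send button wired to add_text (which declares no outputs).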
image_blocks = gr.Blocks()
with image_blocks:
    image_process_mode = gr.Radio(
        ["Crop", "Resize", "Pad", "Default"],
        value="Default",
        label="Preprocess for non-square image",
        visible=False,
    )
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    imagebox = gr.Image(type="pil")
    # The button must be interactive, otherwise the click handler below can never fire.
    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
    submit_btn.click(
        add_text,
        [textbox, imagebox, image_process_mode],
        [],
    )

image_blocks.launch()