Spaces:

File size: 4,295 Bytes
80e7e8e
 
 
 
 
 
 
0bd6138
80e7e8e
171e0c0
 
80e7e8e
 
 
 
 
 
 
f7b4cfc
80e7e8e
 
 
 
 
 
 
 
0bd6138
0381192
80e7e8e
 
 
 
0bd6138
 
 
 
 
 
 
 
80e7e8e
 
 
 
 
 
0bd6138
80e7e8e
 
 
 
 
 
 
 
 
 
 
 
9179591
0bd6138
80e7e8e
 
 
 
 
 
0bd6138
f8f5690
 
80e7e8e
 
0bd6138
 
80e7e8e
0bd6138
80e7e8e
 
 
 
 
 
0bd6138
 
 
 
80e7e8e
 
0bd6138
80e7e8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from io import BytesIO

import gradio as gr
import grpc
from PIL import Image
from cachetools import LRUCache
from gradio.image_utils import crop_scale

from inference_pb2 import GuideAndRescaleRequest, GuideAndRescaleResponse
from inference_pb2_grpc import GuideAndRescaleServiceStub


def get_bytes(img, image_format="JPEG"):
    """Serialize a PIL image into encoded bytes.

    Args:
        img: A ``PIL.Image.Image`` to encode, or ``None``.
        image_format: Pillow format name passed to ``Image.save``. Defaults to
            ``"JPEG"``, which is what the request built in this file sends.

    Returns:
        The encoded image bytes, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # Pillow's JPEG encoder only writes these modes; alpha/palette images such
    # as "RGBA", "LA" or "P" make ``save`` raise OSError, so flatten to RGB.
    jpeg_modes = {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}
    if image_format.upper() == "JPEG" and img.mode not in jpeg_modes:
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format=image_format)
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Decode encoded image bytes (such as a gRPC response payload) into a PIL image.

    Args:
        image: Encoded image data in any format Pillow can open.

    Returns:
        The decoded ``PIL.Image.Image``.
    """
    stream = BytesIO(image)
    return Image.open(stream)


def edit(editor, source_prompt, target_prompt, config, progress=gr.Progress(track_tqdm=True)):
    """Run a Guide-and-Rescale edit on the image held by the ImageEditor.

    The editor's composite image is cropped to a square, resized to 512x512,
    JPEG-encoded and sent to the gRPC service whose address is read from the
    ``SERVER`` environment variable; the image in the response is decoded
    and returned.

    Args:
        editor: Value of the ``gr.ImageEditor`` component (or ``None`` when
            nothing was uploaded); its ``'composite'`` entry holds the PIL
            image to edit.
        source_prompt: Text describing the content of the original image.
        target_prompt: Text describing the desired output image.
        config: Editing preset forwarded to the server
            (``"non-stylisation"`` or ``"stylisation"`` in this UI).
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm bars. Not used directly in this function.

    Returns:
        The edited image as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If the image or either prompt is missing, or the request to
            the editing service fails.
        RuntimeError: If the ``SERVER`` environment variable is not set.
    """
    image = None if editor is None else editor['composite']
    source_prompt = (source_prompt or "").strip()
    target_prompt = (target_prompt or "").strip()

    if image is None or not source_prompt or not target_prompt:
        # gr.Error (unlike a plain ValueError) is displayed to the user in the UI.
        raise gr.Error("Need to upload an image and enter init and edit prompts")

    # The image is sent as JPEG (see get_bytes), which cannot store alpha or
    # palette modes, so normalize to RGB up front.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    width, height = image.size
    if width != height:
        # Crop to a square first so the resize below does not distort the image.
        side = min(width, height)
        image = crop_scale(image, side, side)

    if image.size != (512, 512):
        image = image.resize((512, 512), Image.Resampling.LANCZOS)

    server = os.environ.get('SERVER')
    if not server:
        raise RuntimeError("The SERVER environment variable must point to the inference service")

    request = GuideAndRescaleRequest(image=get_bytes(image), source_prompt=source_prompt,
                                     target_prompt=target_prompt, config=config, use_cache=True)
    try:
        with grpc.insecure_channel(server) as channel:
            stub = GuideAndRescaleServiceStub(channel)
            response: GuideAndRescaleResponse = stub.swap(request)
    except grpc.RpcError as err:
        # Surface a readable message in the UI; the original error is chained.
        raise gr.Error("The editing service request failed, please try again later") from err

    return bytes_to_image(response.image)


def get_demo():
    """Build the Guide-and-Rescale Gradio Blocks UI.

    The left column holds the image editor, the init/edit prompts and the
    editing-type selector; the right column shows the result. Clickable
    examples and a citation section follow. The "Edit image" button calls
    ``edit``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Guide-and-Rescale")
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Official Guide-and-Rescale Gradio demo:</span>'
            '<a href="https://arxiv.org/abs/2409.01322"><img src="https://img.shields.io/badge/arXiv-2409.01322-b31b1b.svg" height=22.5></a>'
            '<a href="https://github.com/AIRI-Institute/Guide-and-Rescale"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://colab.research.google.com/drive/1noKOOcDBBL_m5_UqU15jBBqiM8piLZ1O?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # Square crop, no drawing tools: the editor is used only to
                    # upload and crop the source image.
                    image = gr.ImageEditor(label="Image that you want to edit", type="pil", layers=False,
                                           interactive=True, crop_size="1:1", eraser=False, brush=False,
                                           image_mode='RGB')
                with gr.Row():
                    source_prompt = gr.Textbox(label="Init Prompt",
                                               info="Describes the content of the original image.")
                    target_prompt = gr.Textbox(label="Edit Prompt",
                                               info="Describes what is expected in the output image.")
                    config = gr.Radio(["non-stylisation", "stylisation"], value='non-stylisation',
                                      label="Type of Editing", info="Selects a config for editing.")
                with gr.Row():
                    btn = gr.Button("Edit image")
            with gr.Column():
                with gr.Row():
                    output = gr.Image(label="Result: edited image")

        gr.Examples(examples=[["input/1.png", 'A photo of a tiger', 'A photo of a lion', 'non-stylisation'],
                              ["input/zebra.jpeg", 'A photo of a zebra', 'A photo of a white horse', 'non-stylisation'],
                              ["input/13.png", 'A photo', 'Anime style face', 'stylisation']],
                    inputs=[image, source_prompt, target_prompt, config],
                    outputs=output)

        btn.click(fn=edit, inputs=[image, source_prompt, target_prompt, config], outputs=output)

        gr.Markdown('''To cite the paper by the authors
    ```
        TODO: add cite
    ```
        ''')
    return demo


if __name__ == '__main__':
    demo = get_demo()
    # Defaults bind on all interfaces at port 7860; Gradio's standard
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT variables override them.
    demo.launch(
        server_name=os.environ.get('GRADIO_SERVER_NAME', '0.0.0.0'),
        server_port=int(os.environ.get('GRADIO_SERVER_PORT', '7860')),
    )