Spaces:

File size: 3,735 Bytes
80e7e8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
from io import BytesIO

import gradio as gr
import grpc
from PIL import Image
from cachetools import LRUCache
import hashlib

from protos.inference_pb2 import GuideAndRescaleRequest, GuideAndRescaleResponse
from protos.inference_pb2_grpc import GuideAndRescaleServiceStub


def get_bytes(img):
    """Serialize an image to JPEG-encoded bytes.

    Args:
        img: A PIL image (any object exposing ``mode``, ``convert`` and
            ``save`` works), or ``None``.

    Returns:
        The JPEG-encoded bytes, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # Pillow's JPEG writer only accepts these modes; anything else (RGBA, LA,
    # P, ... as produced by PNG uploads with transparency or a palette) makes
    # ``save`` raise OSError, so such images are converted to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Open encoded image bytes as a PIL image.

    Args:
        image: Raw bytes of an encoded image file.

    Returns:
        The image object produced by ``Image.open`` over an in-memory stream.
    """
    stream = BytesIO(image)
    return Image.open(stream)


def resize(img):
    """Return ``img`` scaled to 512x512 using LANCZOS resampling.

    An image that is already 512x512 is returned unchanged (same object).
    """
    target_size = (512, 512)
    if img.size == target_size:
        return img

    return img.resize(target_size, Image.Resampling.LANCZOS)


def edit(image, source_prompt, target_prompt, config, progress=gr.Progress(track_tqdm=True)):
    """Request an edited image from the Guide-and-Rescale gRPC service.

    Args:
        image: Image to edit; serialized with ``get_bytes`` before sending.
        source_prompt: Prompt describing the content of the original image.
        target_prompt: Prompt describing what is expected in the output image.
        config: Name of the editing config forwarded to the service.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        The image decoded from the ``image`` field of the service response.

    Raises:
        ValueError: If the image or either prompt is missing.
    """
    if image is None or not source_prompt or not target_prompt:
        raise ValueError("Need to upload an image and enter init and edit prompts")

    image_bytes = get_bytes(image)

    # Honour a SERVER address already configured in the environment instead of
    # overwriting it on every call; when it is absent, fall back to the
    # previous hard-coded address (and still publish it in os.environ).
    server_address = os.environ.setdefault('SERVER', "0.0.0.0:50052")

    with grpc.insecure_channel(server_address) as channel:
        stub = GuideAndRescaleServiceStub(channel)

        request = GuideAndRescaleRequest(
            image=image_bytes,
            source_prompt=source_prompt,
            target_prompt=target_prompt,
            config=config,
            use_cache=True,
        )
        response: GuideAndRescaleResponse = stub.swap(request)

    return bytes_to_image(response.image)


def get_demo():
    """Build the Gradio Blocks interface for the Guide-and-Rescale demo.

    The layout has an input column (image, init/edit prompts, editing type,
    button) and an output column (edited image), followed by clickable
    examples and a citation placeholder.

    Returns:
        The constructed ``gr.Blocks`` app; launching it is left to the caller.
    """
    # Each row: [image path, init prompt, edit prompt, editing type].
    example_rows = [
        ["input/1.png", 'A photo of a tiger', 'A photo of a lion', 'non-stylisation'],
        ["input/zebra.jpeg", 'A photo of a zebra', 'A photo of a white horse', 'non-stylisation'],
        ["input/13.png", 'A photo', 'Anime style face', 'stylisation'],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("## Guide-and-Rescale")
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Official Guide-and-Rescale Gradio demo:</span>'
            '<a href="https://github.com/AIRI-Institute/Guide-and-Rescale"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://colab.research.google.com/drive/1noKOOcDBBL_m5_UqU15jBBqiM8piLZ1O?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image = gr.Image(label="Image that you want to edit", type="pil")
                with gr.Row():
                    # Fixed the user-facing typo "Describs" in both help texts.
                    source_prompt = gr.Textbox(label="Init Prompt", info="Describes the content on the original image.")
                    target_prompt = gr.Textbox(label="Edit Prompt", info="Describes what is expected in the output image.")
                    config = gr.Radio(["non-stylisation", "stylisation"], value='non-stylisation',
                                      label="Type of Editing", info="Selects a config for editing.")
                with gr.Row():
                    btn = gr.Button("Edit image")
            with gr.Column():
                with gr.Row():
                    output = gr.Image(label="Result: edited image")

        gr.Examples(examples=example_rows, inputs=[image, source_prompt, target_prompt, config],
                    outputs=output)

        # Normalize uploaded images to the 512x512 size produced by ``resize``.
        image.upload(fn=resize, inputs=[image], outputs=image)

        btn.click(fn=edit, inputs=[image, source_prompt, target_prompt, config], outputs=output)

        gr.Markdown('''To cite the paper by the authors
    ```
        TODO: add cite
    ```
        ''')
    return demo


if __name__ == '__main__':
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    # NOTE(review): align_cache is not read anywhere in the visible code —
    # confirm whether it is still needed.
    align_cache = LRUCache(maxsize=10)
    demo = get_demo()
    demo.launch(server_port=7860, server_name="0.0.0.0")