File size: 2,995 Bytes
84c806e
 
 
 
 
6525b03
 
 
84c806e
 
 
6525b03
 
84c806e
6525b03
 
84c806e
7b207f0
 
 
84c806e
 
 
 
 
 
 
48a1fa8
6525b03
 
84c806e
 
2e45025
 
 
 
 
 
 
6525b03
 
84c806e
 
 
6525b03
 
 
84c806e
 
 
 
48a1fa8
 
 
 
 
 
 
 
 
 
84c806e
 
 
 
 
48a1fa8
 
 
 
 
 
 
 
 
84c806e
48a1fa8
84c806e
48a1fa8
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os

import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def split_image(im, num_rows=3, num_cols=3):
    """Cut an image into a row-major list of ``num_rows * num_cols`` tiles.

    Args:
        im: image as a PIL image or anything ``np.array`` accepts.
        num_rows: number of grid rows.
        num_cols: number of grid columns.

    Returns:
        List of numpy arrays, one per tile, ordered left-to-right then
        top-to-bottom. Trailing pixels that do not divide evenly into the
        grid are dropped.
    """
    pixels = np.array(im)
    tile_h = pixels.shape[0] // num_rows
    tile_w = pixels.shape[1] // num_cols
    tiles = []
    for r in range(num_rows):
        top = r * tile_h
        for c in range(num_cols):
            left = c * tile_w
            tiles.append(pixels[top : top + tile_h, left : left + tile_w])
    return tiles


def app(model_name):
    """Streamlit page: patch-based relevance retrieval with KoCLIP.

    Loads the requested checkpoint, lets the user supply an image (URL or
    file upload) plus a Korean/English text query, splits the image into a
    user-chosen grid of tiles, and shows the tiles ranked by their CLIP
    relevance score against the query.

    Args:
        model_name: checkpoint suffix; resolved as ``koclip/{model_name}``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Patch-based Relevance Retrieval")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that best explains the text.
        To try it out, you can

        1. Upload an image
        2. Explain a part of the image in text

        which will yield the most relevant image tile from a grid of the image. You can specify how
        granular you want to be with your search by specifying the number of rows and columns that
        make up the image grid.
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )

    # `st.beta_columns` was deprecated and removed in Streamlit >= 1.0.
    # Prefer `st.columns` when present, but keep the beta fallback so the
    # app still runs on older Streamlit releases.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    col1, col2 = make_columns(2)
    with col1:
        num_rows = st.slider(
            "Number of rows", min_value=1, max_value=5, value=3, step=1
        )
    with col2:
        num_cols = st.slider(
            "Number of columns", min_value=1, max_value=5, value=3, step=1
        )

    if st.button("질문 (Query)"):
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            st.markdown("""---""")
            with st.spinner("Computing..."):
                # An uploaded file takes precedence over the URL field.
                image_data = (
                    query2
                    if query2 is not None
                    else requests.get(query1, stream=True).raw
                )
                # Force 3-channel RGB so palette/RGBA PNGs and grayscale
                # JPEGs do not break the CLIP processor downstream.
                image = Image.open(image_data).convert("RGB")
                st.image(image)

                images = split_image(image, num_rows, num_cols)

                inputs = processor(
                    text=captions, images=images, return_tensors="jax", padding=True
                )
                # Flax CLIP expects channels-last (NHWC) pixel values,
                # while the processor emits NCHW.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                outputs = model(**inputs)
                # Softmax over the tile axis: each tile's relevance to the
                # single text query.
                probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
                # Rank tiles highest-probability first; cast to float so
                # the sort key is a scalar rather than a device array.
                ranked = sorted(
                    enumerate(probs), key=lambda pair: float(pair[1]), reverse=True
                )
                for idx, prob in ranked:
                    st.text(f"Score: {prob[0]:.3f}")
                    st.image(images[idx])