File size: 2,700 Bytes
61c7634
e15dae8
c9911aa
17bb1f6
 
fae45ed
61c7634
236866f
c34d9ea
4622c8d
 
8465f52
 
3ce1a28
34700d7
78f266b
34700d7
d30d4ce
de07d62
 
921054e
ec11b9a
6f3fb83
62635cf
17bb1f6
5c5bd98
6f3fb83
189516e
de07d62
af0b55f
 
 
f4d85d5
 
6115563
31fec50
34700d7
 
de07d62
925d019
62635cf
189516e
58bc1a3
1377bb8
f4d85d5
 
 
 
 
ed60eed
5c5bd98
8465f52
 
 
f4d85d5
04856c4
f4d85d5
 
04856c4
f4d85d5
 
8465f52
f4d85d5
 
 
ec11b9a
f4d85d5
cf1f1df
5e188c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from turtle import color, onclick
import streamlit as st
from PIL import Image, ImageOps
import glob
import json
import requests
import random
import io

# Fixed seed so any sampling from `random` is reproducible.
# NOTE(review): no active code in this script currently draws from `random`;
# the seed only matters if random sampling of examples is re-enabled.
random.seed(10)

# Streamlit re-executes this script on every interaction, so per-session state
# is seeded only on the first run and then survives later reruns.
#   show        -- whether the correct (target) image is currently highlighted
#   example_idx -- index into the validation descriptions being displayed
for _state_key, _initial_value in (('show', False), ('example_idx', 0)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Wide layout gives the two-column view (enlarged image + thumbnail grid) room.
st.set_page_config(layout="wide")

# Introductory text rendered at the top of the page, one markdown call each.
_intro_paragraphs = (
    "**This is a demo of the *ImageCoDe* benchmark. What is the task? You are given a description and you have to pick the image it describes, out of 10 images total.**",
    "**If you click the Sample button, you will get a new text and images. More details of ImageCoDe can be found in our ACL 2022 paper.**",
)
for _paragraph in _intro_paragraphs:
    st.markdown(_paragraph)

# Two-column layout: col1 holds the buttons, the description and the enlarged
# selected image; col2 holds the image-index picker and the thumbnail grid.
col1, col2 = st.columns(2)

# Base URL of the hosted validation image sets (one sub-directory per set).
prefix = 'https://raw.githubusercontent.com/BennoKrojer/imagecode-val-set/main/image-sets-val/'

# Load the lookup tables with context managers so the file handles are closed
# (json.load(open(...)) leaked them), and decode explicitly as UTF-8 instead of
# relying on the platform's default locale encoding.
#   set2ids      -- image-set name -> image file names belonging to that set
#   descriptions -- list of (image-set name, target index, description) entries
with open('set2ids.json', 'r', encoding='utf-8') as _fh:
    set2ids = json.load(_fh)
with open('valid_list.json', 'r', encoding='utf-8') as _fh:
    descriptions = json.load(_fh)

# Advance to the next validation example. Wrap around at the end of the list
# (previously the index ran past it and raised IndexError), and hide the answer
# again so the newly sampled example can be guessed from scratch.
if col1.button('Sample a description + 10 images from the validation set'):
    st.session_state.example_idx = (st.session_state.example_idx + 1) % len(descriptions)
    st.session_state.show = False

# Each entry is (image-set name, index of the target image, description text).
img_set, true_idx, descr = descriptions[st.session_state.example_idx]
true_idx = int(true_idx)
# `prefix` already ends with '/', so strip it before joining to avoid '//' in URLs.
images = [prefix.rstrip('/') + '/' + img_set + '/' + i for i in set2ids[img_set]]
# Keep the plain URLs: entries of `images` are later replaced by bordered PIL images.
img_urls = images.copy()
index = int(col2.number_input('Image Index from 0 to 9', value=0, min_value=0, max_value=9))

# Flip whether the correct image is highlighted.
if col1.button('Toggle to reveal/hide the correct image (try to guess yourself before giving up!)'):
    st.session_state.show = not st.session_state.show

def _fetch_image(url):
    """Download ``url`` and decode it as a PIL image.

    Raises ``requests.HTTPError`` on a non-2xx response (rather than handing an
    error page to PIL) and ``requests.Timeout`` if the server stops responding.
    """
    response = requests.get(url, timeout=30)  # seconds
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def _with_border(img, divisor, fill):
    """Return ``img`` padded with a ``fill`` border of min(width, height) // divisor px."""
    return ImageOps.expand(img, border=min(img.size) // divisor, fill=fill)


col1.markdown(f'**Description for {img_set}**:')
col1.markdown(f'**{descr}**')

# The enlarged view shows the selected image unframed (straight from its URL);
# its thumbnail gets a thin blue frame to mark the current selection.
big_img = images[index]
selected_img = _fetch_image(big_img)
images[index] = _with_border(selected_img, 18, 'blue')

# One caption per thumbnail, sized from the set instead of a hard-coded 10.
caps = [str(i) for i in range(len(images))]
cap = str(index)

if st.session_state.show:
    # Reuse the already-downloaded image when the target is the selected one.
    target_img = selected_img if true_idx == index else _fetch_image(img_urls[true_idx])
    # A thicker green frame marks the target (overrides the blue selection frame).
    images[true_idx] = _with_border(target_img, 8, 'green')
    caps[true_idx] = f'{true_idx} (TARGET IMAGE)'
    if true_idx == index:
        cap = caps[true_idx]

# NOTE(review): newer Streamlit releases deprecate use_column_width in favour of
# use_container_width; confirm the pinned Streamlit version before switching.
col1.image(big_img, use_column_width=True, caption=cap)
col2.image(images, width=175, caption=caps)
col1.markdown(f'{st.session_state.example_idx}')