# BrAD / app.py
# Copied page header: author aarbelle, commit 03be69d ("add min len to slider"), 3.08 kB.
import pickle
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
import gradio as gr
from PIL import Image
# DomainNet domains handled by the demo (query domains and result rows).
domains = ['real', 'painting', 'clipart', 'sketch']

# Image paths stored in the feature pickles are joined onto this root URL.
data_root = 'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud/DomainNet'
# Local directory holding the precomputed feature pickles.
feat_dir = 'brad_feats'
# Suffix used in the pickle file names (f'..._{shots}.pkl'); kept as a string.
shots = '-1'

# Retrieval settings.
search_domain = 'all'  # 'all' or a single entry of `domains`
num_results_per_domain = 5  # neighbours returned per search domain
num_nn = 20  # NOTE(review): not referenced in the visible code
# Build one k-NN index per searchable domain.
# NOTE: the naming is swapped relative to the files: the *search* (gallery)
# set comes from the 'dst_*' pickles but is stored under `src_*` names, while
# the query set (below) comes from the 'src_*' pickles.
# Each pickle is used as a sequence: [0] relative image paths (joined onto
# data_root in query()), [1] the feature matrix the index is fit on.
src_data_dict = {}
_index_domains = domains if search_domain == 'all' else [search_domain]
for d in _index_domains:
    with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
        # NOTE(review): pickle.load can execute arbitrary code; these files
        # must come from a trusted source.
        src_data = pickle.load(fp)
    # Fit after the file is closed; the index only needs the in-memory data.
    src_nn_fit = NearestNeighbors(
        n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1
    ).fit(src_data[1])
    src_data_dict[d] = (src_data, src_nn_fit)
# Load the query-side data ('src_*' pickles, stored under `dst_*` names) for
# every domain. [0] holds the relative image paths, [1] the features.
dst_data_dict = {}
for d in domains:
    with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        dst_data_dict[d] = pickle.load(fp)

# Number of query images in the smallest domain. A single slider index must be
# valid for whichever domain is selected, so the UI bounds it by this count.
# (Counts the path entries; comparing against the sequence itself is wrong.)
min_len = min(len(data[0]) for data in dst_data_dict.values())
def _class_from_path(path):
    """Return the class name, i.e. the parent directory component of *path*."""
    return path.split('/')[-2]


def query(query_index, query_domain):
    """Retrieve nearest neighbours of one query image across the search domains.

    Args:
        query_index: Position of the query image within ``query_domain``'s
            query set. A Gradio slider may deliver this as a float, so it is
            cast to ``int`` before indexing.
        query_domain: Key into ``dst_data_dict`` (one of ``domains``).

    Returns:
        A flat tuple laid out to match the UI ``outputs`` list:
        the query image URL, then the neighbour URLs for every entry of
        ``src_data_dict`` (in dict order), followed by the captions
        ('Query: <Class>' first, then one class name per neighbour).
    """
    # data_root is a URL, so always join with '/' (os.path.join would use '\\'
    # on Windows). posixpath.join matches os.path.join on POSIX.
    import posixpath

    query_index = int(query_index)
    dst_data = dst_data_dict[query_domain]

    query_path = posixpath.join(data_root, dst_data[0][query_index])
    img_paths = [query_path]
    captions = [f'Query: {_class_from_path(query_path)}'.title()]

    # Slice rather than index so kneighbors receives a one-row batch.
    query_feat = dst_data[1][query_index:query_index + 1]
    for gallery, nn_index in src_data_dict.values():
        _, neighbour_ids = nn_index.kneighbors(query_feat)
        paths = [posixpath.join(data_root, gallery[0][ix]) for ix in neighbour_ids[0]]
        img_paths.extend(paths)
        captions.extend(_class_from_path(p).title() for p in paths)

    return tuple(img_paths) + tuple(captions)
# Gradio UI: choose a query domain and index, then show the query image
# followed by the top matches from every indexed search domain.
demo = gr.Blocks()
with demo:
    gr.Markdown('## Select Query Domain: ')
    # Default to the first domain so pressing "Run" before choosing never
    # passes None to query() (dst_data_dict has no None key).
    domain_drop = gr.Dropdown(domains, value=domains[0])
    # Integer steps over valid indices only: min_len entries -> 0..min_len-1.
    slider = gr.Slider(0, min_len - 1, value=0, step=1)
    image_button = gr.Button("Run")

    with gr.Row():
        gr.Markdown('# Query Image: \t\t\t\t ')
        gr.Markdown('\t')
        gr.Markdown('\t')
        gr.Markdown('\t')
        with gr.Column():
            src_cap = gr.Label()
            src_img = gr.Image()

    out_images = []
    out_captions = []
    # One result row per indexed search domain, in the same order query()
    # iterates src_data_dict, so outputs line up with its return tuple.
    for d in src_data_dict:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())

    image_button.click(
        query,
        inputs=[slider, domain_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )

demo.launch(share=True)