"""Gradio demo for searching a New Look product catalogue indexed in Marqo.

The app lets the user combine up to two text queries and one image query,
each with its own relevance weight, plus a "best seller" score modifier.
Results are shown in a gallery; selecting a result shows every image that
shares its colour code together with its description and product link.
"""

import html
import os
import sys
import time
import urllib.request
import webbrowser
from datetime import datetime
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import requests
from gradio.themes import GoogleFont, Size
from marqo import Client
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
# processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)

# Marqo endpoint; override with the MARQO_URL environment variable.
MARQO_URL = os.environ.get(
    "MARQO_URL",
    "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882",
)

# Public prefix under which files saved by `load` are handed to the search
# as image URLs; override with the PUBLIC_FILE_BASE_URL environment variable.
PUBLIC_FILE_BASE_URL = os.environ.get(
    "PUBLIC_FILE_BASE_URL",
    "https://minderalabs-newlook.hf.space/file=",
)

# Dropdown label -> Marqo index name.
DATASET_INDEXES = {
    "New Look Dresses": "new_look_expanded_dresses",
    "New Look All": "new_look_expanded_all",
}


class Client_Settings():
    """Mutable holder for the Marqo client, index name and device in use."""

    def __init__(self):
        self.client = Client()
        self.index_name = "new_look_expanded_dresses"
        self.device = "cpu"

    def conn_to_local(self):
        """Replace the client with one built from marqo's default settings."""
        self.client = Client()

    def conn_to_server(self, url):
        """Replace the client with one pointing at ``url``."""
        self.client = Client(url)

    def set_index_name(self, new_index_name):
        """Set the index used by every subsequent search."""
        self.index_name = new_index_name

    def set_device(self, new_device):
        """Set the device string passed to every subsequent search."""
        self.device = new_device


client_obj = Client_Settings()
# client_obj.conn_to_local()
client_obj.conn_to_server(MARQO_URL)
client_obj.set_index_name("new_look_expanded_dresses")
client_obj.set_device("cuda")

# Create custom Color objects for our primary, secondary, and neutral colors
primary_color = gr.themes.colors.slate
secondary_color = gr.themes.colors.rose
neutral_color = gr.themes.colors.stone  # Assuming black for text

# Set the sizes
spacing_size = gr.themes.sizes.spacing_md
radius_size = gr.themes.sizes.radius_md
text_size = gr.themes.sizes.text_md

# Set the fonts
font = GoogleFont("Source Sans Pro")
font_mono = GoogleFont("IBM Plex Mono")

# Create the theme
theme = gr.themes.Base(
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
    font=font,
    font_mono=font_mono
)


def filter_by_column(dataset, search_term, column_name) -> pd.DataFrame:
    """Return the rows whose ``column_name`` string contains ``search_term``."""
    return dataset[dataset[column_name].str.contains(search_term)]


def dedup_by(dataset, column_name) -> pd.DataFrame:
    """Return ``dataset`` without rows duplicated on ``column_name``."""
    return dataset.drop_duplicates(subset=[column_name])


def drop_secondary_images(dataset) -> pd.DataFrame:
    """Keep one row per primary image and show that image for each row.

    The ``image`` column is overwritten with ``primary_image`` on a copy, so
    the caller's frame is left untouched.
    """
    with_primary = dataset.assign(image=dataset['primary_image'])
    return with_primary.drop_duplicates(subset=['primary_image'])


def dataset_to_gallery(dataset: pd.DataFrame, _score=None) -> list:
    """Convert result rows to ``(image, caption)`` tuples for a gallery.

    The caption packs several fields separated by ``@@`` so that selection
    handlers can recover them with ``split("@@")``:

    * without a score: ``name@@colour_code@@image@@_id``
    * with a score:    ``score@@name@@colour_code@@image@@_id``

    Args:
        dataset: frame with ``_id``, ``image``, ``name`` and ``colour_code``
            columns.
        _score: optional Series (aligned with ``dataset``) of scores; when
            given, each score is formatted to four decimals and prepended.

    Returns:
        A list of ``(image, caption)`` tuples; empty for an empty frame.
    """
    if dataset.empty:
        return []
    new_df = dataset[['_id', 'image', 'name', 'colour_code']].copy()
    caption = (
        new_df['name']
        + '@@' + new_df['colour_code'].astype(str)
        + '@@' + new_df['image'].astype(str)
        + '@@' + new_df['_id'].astype(str)
    )
    if isinstance(_score, pd.Series):
        caption = _score.map('{:,.4f}'.format).astype(str) + '@@' + caption
    new_df['name_code_combined'] = caption
    final_df = new_df[['image', 'name_code_combined']]
    return final_df.to_records(index=False).tolist()


def get_items_from_dataset(start_index=0, end_index=50, dataset=None) -> pd.DataFrame:
    """Return rows ``start_index:end_index`` of ``dataset``, best sellers first.

    Rows are sorted by ``best_seller_score`` in descending order before
    slicing. With no dataset an empty frame is returned.
    """
    if dataset is None:
        return pd.DataFrame()
    df = dataset.sort_values(by=['best_seller_score'], ascending=False)
    return df[start_index:end_index]


def start_page(num_results=50):
    """Return the gallery items shown before the user has searched.

    Searches the current index for "Dress" with the best-seller score added
    at weight 5 and returns at most ``num_results`` hits as gallery tuples.
    """
    result = client_obj.client.index(client_obj.index_name).search(
        "Dress",
        score_modifiers={
            "add_to_score": [{"field_name": "best_seller_score", "weight": 5}],
        },
        searchable_attributes=['image'],
        device=client_obj.device,
        limit=num_results,
    )
    return return_results_page(list(result["hits"]))


def return_results_page(results_list: list):
    """Turn a list of search hits into scored gallery tuples.

    Hits are de-duplicated on their primary image. An empty hit list yields
    an empty gallery instead of failing on the missing columns.
    """
    if not results_list:
        return []
    df = pd.DataFrame(results_list)
    df_unique = drop_secondary_images(df)
    return dataset_to_gallery(df_unique, df_unique['_score'])


def return_item(combined) -> tuple:
    """Fetch every image that shares the selected result's colour code.

    Args:
        combined: caption of a scored gallery item; the colour code is its
            third ``@@``-separated field.

    Returns:
        ``(gallery_items, description, url)`` where the description and URL
        come from the first hit, or ``([], "", "")`` when nothing matches.
    """
    colour_code = combined.split("@@")[2]
    result = client_obj.client.index(client_obj.index_name).search(
        "",
        filter_string="colour_code:" + str(colour_code),
        searchable_attributes=['image'],
        device=client_obj.device,
    )
    imgs = list(result["hits"])
    if not imgs:
        return [], "", ""
    df = pd.DataFrame(imgs)
    return dataset_to_gallery(df), imgs[0]["description_total"], imgs[0]["url"]


def return_specific_item(combined):
    """Return the image of the document selected in the detail gallery.

    Args:
        combined: caption of an unscored gallery item; the document ``_id``
            is its fourth ``@@``-separated field.

    Returns:
        The ``image`` value of the first matching hit, or ``None`` when the
        document cannot be found.
    """
    _id = combined.split("@@")[3]
    result = client_obj.client.index(client_obj.index_name).search(
        "",
        filter_string="_id:" + str(_id),
        searchable_attributes=['image'],
        device=client_obj.device,
    )
    imgs = list(result["hits"])
    print(imgs)
    if not imgs:
        return None
    df = pd.DataFrame(imgs)
    return dataset_to_gallery(df)[0][0]


### Load local
def load_image(image_input):
    """Save ``image_input`` locally and copy it into the ``marqo`` container.

    Only used when running against a local Marqo Docker container.
    """
    image_input.save("../../../Documents/images/img_path.jpg")
    os.system('docker cp "../../../Documents/images/img_path.jpg" marqo:"/images/images/"')


def search_images(query, best_seller_score_weight):
    """Run ``query`` against the current index and return up to 40 hits.

    Args:
        query: a single query string, or a dict mapping queries to weights.
        best_seller_score_weight: slider value; divided by 1000 before being
            used as the weight of the ``best_seller_score`` modifier.
    """
    result = client_obj.client.index(client_obj.index_name).search(
        query,
        score_modifiers={
            "add_to_score": [{
                "field_name": "best_seller_score",
                "weight": best_seller_score_weight / 1000,
            }],
        },
        searchable_attributes=['image'],
        device=client_obj.device,
        limit=40,
    )
    return list(result["hits"])


# def get_labels_probs(labels, image):
#     inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
#     outputs = model(**inputs)
#     logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
#     probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
#     return probs.tolist()[0]


def get_bar_plot(labels, probs):
    """Return a matplotlib figure with one bar per label, y-axis fixed to 0-1."""
    fig, ax = plt.subplots()
    bar_container = ax.bar(labels, probs)
    ax.set(ylabel='frequency', title='Labels probabilities\n', ylim=(0, 1))
    ax.bar_label(bar_container, fmt='{:,.4f}')
    return fig


css = """
.gradio-container {background-color: beige}
button.gallery-item {background-color: grey}
"""
# .label {background-color: grey; width: 80px}
# h1 {background-color: grey; width: 180px}

with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
    gr.Markdown(
        """
""")
    with gr.Tab(label="Search for images"):
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Text(label="Search with text:")
                text_relevance = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1)
                text_input_1 = gr.Text(label="Search with text:", visible=False)
                text_relevance_1 = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1, visible=False)
                more_text_search = gr.Button(value="More text fields")
                text_expanded = gr.State(value=False)
            with gr.Column(scale=3):
                best_seller_score_weight = gr.Slider(label="Best seller relevance", minimum=-1, maximum=1, value=0, step=0.01)
                search_button = gr.Button(value="Search")
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Search with image")
                image_path = gr.State(visible=False)
                image_relevance = gr.Slider(label="Image search relevance", minimum=-5, maximum=5, value=1, step=1)
        with gr.Row():
            with gr.Column(scale=3):
                images_gallery = gr.Gallery(value=start_page(), columns=4, allow_preview=False, show_label=False, object_fit="contain")
            with gr.Column():
                detail_gallery = gr.Gallery(value=[], columns=2, allow_preview=False, show_label=False, rows=1, height="400", object_fit="contain")
                image_description = gr.Text(label="Description")
                product_link = gr.State()
                page = gr.HTML()

        def on_new_text_box(more_text_search):
            """Toggle the second text query box; the button label is the state."""
            if more_text_search == "More text fields":
                return (
                    gr.update(visible=True, interactive=True),
                    gr.update(visible=True, interactive=True),
                    gr.update(value="Hide extra text box"),
                )
            # Hiding also clears the extra query so it no longer affects search.
            return (
                gr.update(value="", visible=False, interactive=False),
                gr.update(visible=False, interactive=False),
                gr.update(value="More text fields"),
            )

        def on_focus(evt: gr.SelectData):
            """Show the details of the result selected in the main gallery."""
            # SelectData is a subclass of EventData
            gallery_items, description, url = return_item(evt.value)
            if url:
                # Escape the URL: it comes from index data, not from this app.
                link_html = (
                    '<a href="' + html.escape(str(url), quote=True) + '"'
                    ' target="_blank" rel="noopener noreferrer">'
                    ' Go to product page </a>'
                )
            else:
                link_html = ""
            return gallery_items, description, url, gr.update(value=link_html)

        def on_new_image_to_search(images, evt: gr.SelectData):
            """Use the image selected in the detail gallery as the image query."""
            # SelectData is a subclass of EventData
            return return_specific_item(evt.value)

        more_text_search.click(on_new_text_box, more_text_search, [text_input_1, text_relevance_1, more_text_search])
        images_gallery.select(on_focus, None, [detail_gallery, image_description, product_link, page])
        detail_gallery.select(on_new_image_to_search, detail_gallery, image_input)

    # with gr.Tab(label="Search for images"):
    #     labels_input = gr.Text(label="List of labels")
    #     gr.Examples(
    #         ["shirt, dress, shoe",
    #          "short_sleeve, long_sleeve, three_quarter_sleeve, sleeveless, bell_sleeve"],
    #         labels_input)
    #     with gr.Row():
    #         image_labels_input = gr.Image(type="pil", label="Image to compute")
    #         bar_plot = gr.Plot()
    #     with gr.Row():
    #         gr.Examples(
    #             ["https://media2.newlookassets.com/i/newlook/869030934/womens/clothing/dresses/khaki-utility-mini-shirt-dress.jpg?strip=true&qlt=50&w=1400",
    #              "https://media3.newlookassets.com/i/newlook/872692409/womens/clothing/dresses/black-floral-lace-trim-mini-dress.jpg?strip=true&qlt=50&w=1400"],
    #             image_labels_input)
    #         gr.Markdown()
    #     compute_button = gr.Button(value="Compute")
    #     response_labels = gr.Text()

    with gr.Tab(label="Choose dataset"):
        gr.Markdown("# Choose Dataset")
        with gr.Row():
            list_datasets = gr.Dropdown(list(DATASET_INDEXES), label="Available datasets", value="New Look Dresses")
            gr.Markdown()
            gr.Markdown()
        with gr.Row():
            select_dataset_button = gr.Button("Select")
            gr.Markdown()
            gr.Markdown()

        def on_change_dataset(choice):
            """Point subsequent searches at the index behind ``choice``."""
            index_name = DATASET_INDEXES.get(choice)
            if index_name is None:
                # Unknown label: keep the current index rather than blanking it.
                print("Dataset selected: unknown choice, index unchanged")
                return choice
            print("Dataset selected: " + index_name)
            client_obj.set_index_name(index_name)
            time.sleep(0.5)
            return choice

        select_dataset_button.click(on_change_dataset, list_datasets, list_datasets)

    def load(image_input):
        """Save the uploaded image under ``static/`` and return its public URL.

        Returns an empty string when no image was provided, which `search`
        treats as "no image query".
        """
        if image_input is None:
            return ""
        file_path = "static/" + "image_to_search.jpg"
        print(file_path)
        image_input.save(file_path)
        return PUBLIC_FILE_BASE_URL + file_path

    def search(text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight):
        """Search with whichever of the two text queries and image are set.

        * no query set   -> empty gallery
        * one query set  -> that query alone (its relevance is not used)
        * several set    -> dict mapping each query to its relevance

        ``image_input`` is unused here; the image is referenced through
        ``image_path``, the URL produced by `load`.

        NOTE: to run against a local Marqo container instead, call
        ``load_image(image_input)`` and use "/images/images/img_path.jpg" as
        the image query in place of ``image_path``.
        """
        all_queries = [text_input, text_input_1, image_path]
        print(all_queries)
        all_queries_relevance = [text_relevance, text_relevance_1, image_relevance]
        print(all_queries_relevance)
        query_is_none = [q is None or q == "" for q in all_queries]
        print(query_is_none)

        active = [
            (q, weight)
            for q, weight, is_empty in zip(all_queries, all_queries_relevance, query_is_none)
            if not is_empty
        ]
        if not active:
            return []
        if len(active) == 1:
            query = active[0][0]
        else:
            query = dict(active)

        response = search_images(query, best_seller_score_weight)
        return return_results_page(response)

    # def get_labels(labels_input, image_labels_input):
    #     labels_probs = get_labels_probs(labels_input.split(","), image_labels_input)
    #     bar_plot = get_bar_plot(labels_input.split(","), labels_probs)
    #     return bar_plot, labels_probs

    # Save the uploaded image first so that `search` receives its URL.
    search_button.click(
        load, image_input, image_path
    ).then(
        search,
        [text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight],
        [images_gallery]
    )

    # compute_button.click(
    #     get_labels, [labels_input, image_labels_input], [bar_plot, response_labels]
    # )

demo.queue()
demo.launch()