|
import streamlit as st |
|
import pandas as pd |
|
import plotly.express as px |
|
import torch.nn.functional as F |
|
from sentence_transformers import SentenceTransformer |
|
import numpy as np |
|
from sklearn.manifold import TSNE |
|
import plotly.express as px |
|
import torch |
|
import plotly.io as pio |
|
# Use Plotly's built-in "plotly" template as the default for every figure in the app.
pio.templates.default = "plotly"

# Wide layout so the results table and the t-SNE plot (rendered in two columns
# further down) have room side by side.
st.set_page_config(layout="wide")

# Markdown emphasis only renders when the underscores hug the text, i.e.
# "_:green[...]_" rather than "_ :green[...] _".
st.header("Explore the Russian Dolls :nesting_dolls: - _:green[Nomic Embed 1.5]_", divider='violet')

# GFM only auto-links a bare URL preceded by whitespace (not by ':'), so keep a
# space before the link to make it clickable.
st.write("Matryoshka Representation Learning: to learn more, see https://aniketrege.github.io/blog/2024/mrl/")
|
|
|
|
|
@st.cache_data
def get_df():
    """Read the sample product catalogue from ``./sample_products.csv``.

    Decorated with ``st.cache_data`` so the parsed DataFrame is memoised instead of
    re-reading the CSV on every Streamlit rerun.

    Returns:
        The catalogue as a pandas DataFrame.
    """
    csv_path = "./sample_products.csv"
    return pd.read_csv(csv_path)
|
|
|
@st.cache_resource
def get_nomicModel():
    """Return the ``nomic-ai/nomic-embed-text-v1.5`` SentenceTransformer.

    ``st.cache_resource`` keeps a single loaded instance alive across reruns, so the
    weights are only loaded once per server process.

    NOTE(review): ``trust_remote_code=True`` is presumably required because the model
    repository ships custom modelling code; it executes code fetched from the Hub.
    """
    return SentenceTransformer(
        "nomic-ai/nomic-embed-text-v1.5",
        trust_remote_code=True,
    )
|
|
|
def get_searchQueryEmbedding(query):
    """Embed a search query with the module-level Nomic ``model``.

    The text is prefixed with ``"search_query: "`` and encoded. The resulting tensor is
    then layer-normalised over its last axis; Matryoshka truncation and L2
    normalisation happen later, in ``get_normEmbed``.

    Args:
        query: the raw query text.

    Returns:
        A 2-D tensor (presumably ``(1, dim)``) on the CPU.
    """
    embeddings = model.encode(["search_query: " + query], convert_to_tensor=True)
    embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
    # encode() returns the tensor on the model's device, which is a GPU when one is
    # available. The document embeddings are built from a NumPy array on the CPU, so
    # move the query there to keep the later matmul and .numpy() calls device-compatible.
    return embeddings.cpu()
|
|
|
def get_normEmbed(query_embedding, loaded_embed, matryoshka_dim):
    """Keep a Matryoshka prefix of the query and document embeddings and L2-normalise it.

    Args:
        query_embedding: 2-D tensor (indexed as ``[:, :matryoshka_dim]``); only its
            first ``matryoshka_dim`` columns are kept.
        loaded_embed: 2-D tensor of document embeddings, truncated the same way.
        matryoshka_dim: number of leading dimensions to keep.

    Returns:
        ``(query_embedNorm, loaded_embedNorm)``: the truncated tensors with every row
        rescaled to unit L2 norm, so their dot products are cosine similarities.
    """
    def _prefix_unit_rows(matrix):
        # Slice off the leading columns, then rescale each row to length 1.
        return F.normalize(matrix[:, :matryoshka_dim], p=2, dim=1)

    return _prefix_unit_rows(query_embedding), _prefix_unit_rows(loaded_embed)
|
|
|
def insert_line_breaks(text, interval=30):
    """Wrap ``text`` for Plotly hover labels by inserting ``<br>`` tags.

    The text is split on single spaces and words are accumulated onto the current
    line. Each word counts its length plus one separating space. Once the running
    length reaches ``interval``, the line is closed and a new one begins. Lines are
    joined with ``<br>``, which Plotly hover text renders as a line break.

    Args:
        text: text to wrap. ``None`` or a float NaN (how pandas represents a missing
            CSV cell) yields ``""``; any other non-string is converted with ``str``.
        interval: approximate number of characters per line.

    Returns:
        The wrapped string. It has no leading/trailing whitespace and never ends
        with a dangling ``<br>``.
    """
    import math

    if text is None or (isinstance(text, float) and math.isnan(text)):
        return ""
    text = str(text)

    lines = []
    current_words = []
    line_length = 0
    for word in text.split(' '):
        current_words.append(word)
        line_length += len(word) + 1
        if line_length >= interval:
            lines.append(' '.join(current_words))
            current_words = []
            line_length = 0
    # Flush the last, partially filled line. Joining (instead of appending '<br>'
    # after every full line) avoids an empty trailing line in the hover box.
    if current_words:
        lines.append(' '.join(current_words))

    return '<br>'.join(lines).strip()
|
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_docEmbedding(path="./prodBigDollEmbeddings.npy"):
    """Load the precomputed full-width ("big doll") document embeddings as a float tensor.

    ``st.cache_resource`` keeps the tensor in memory so the ``.npy`` file is read once,
    not on every Streamlit rerun. The returned tensor is shared, and callers only
    slice it and build new tensors from it; they never modify it in place.

    NOTE(review): the query is layer-normalised before truncation but these vectors
    are only truncated and L2-normalised at search time. Confirm they were
    layer-normed before being saved.
    """
    return torch.Tensor(np.load(path))


model = get_nomicModel()

# NOTE(review): despite its name this is the "Description" text column, not an
# embedding. The search indexes it with row indices of docEmbedding, so the CSV row
# order must match the .npy row order.
bigDollEmbedding = get_df()["Description"]

docEmbedding = _load_docEmbedding()
|
|
|
|
|
|
|
|
|
# Declared outside the form (unlike the widgets below); it chooses which query
# widget the form shows.
toggle = st.toggle('sample queries')

with st.form("my_form"):
    if toggle:
        query_input = st.selectbox(
            'select a query:',
            (
                'Pack of two assorted boxers, has two pockets, an elasticated waistbandDisclaimer: The final product delivered might vary in colour and prints from the display here.',
                'Beige self design shoulder bag, has a zip closure1 main compartment, 3 inner pocketsTwo Handles',
                'Set Content: 1 photo frameColour: Black and whiteFrame Pattern: SolidShape: SquareMaterial: Acrylic',
                'A pair of dark grey solid boxers, has a slip-on closure with an elasticated waistband and drawstring, two pocket',
                'Red & Black solid sweatshirt, has a hood, two pockets, long sleeves, zip closure, straight hem',
            ),
        )
    else:
        # Streamlit warns about empty widget labels (accessibility). Give the box a
        # real label and hide it so the layout stays unlabelled as before.
        query_input = st.text_input("Search query", label_visibility="collapsed")

    # Number of leading embedding dimensions used for search: 64..768, default 64.
    Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)

    submitted = st.form_submit_button("Submit")
|
|
|
|
|
|
|
if submitted and not query_input.strip():
    st.warning("Please enter a query before submitting.")
elif submitted:
    # Embed the query, then truncate the query and the stored document embeddings to
    # the first `Matry_dim` dimensions and L2-normalise their rows.
    queryEmbedding = get_searchQueryEmbedding(query_input)
    query_embedNorm, loaded_embedNorm = get_normEmbed(queryEmbedding, docEmbedding, Matry_dim)

    # Rows are unit vectors, so the dot product is the cosine similarity.
    similarity_scores = torch.matmul(query_embedNorm, loaded_embedNorm.T)
    # topk raises if k exceeds the number of documents, so clamp it.
    top_k = min(10, similarity_scores.shape[1])
    top_values, top_indices = torch.topk(similarity_scores, top_k, dim=1)

    # Convert the description Series to a list once, not once per result.
    descriptions = bigDollEmbedding.tolist()
    top_items_per_query = [descriptions[index] for index in top_indices[0].tolist()]

    df = pd.DataFrame({"Product": top_items_per_query, "Score": top_values[0].tolist()})
    df["Product"] = df["Product"].str.replace("search_document:", "", regex=False)

    # Project the query together with every document into 2-D: row 0 is the query,
    # rows 1.. are the documents in catalogue order.
    allEmbedd = torch.concat([query_embedNorm, loaded_embedNorm])
    # t-SNE requires perplexity < n_samples; keep sklearn's default of 30 when possible.
    tsne = TSNE(n_components=2, random_state=0, perplexity=min(30, allEmbedd.shape[0] - 1))
    projections = tsne.fit_transform(allEmbedd.numpy())

    listHover = [insert_line_breaks(hover_text, 30) for hover_text in descriptions]

    fig = px.scatter(
        projections,
        x=0,
        y=1,
        hover_name=[query_input] + listHover,
        # One label per plotted point: the query first, then every document. The
        # document count comes from the data rather than a hard-coded size.
        color=["search_query"] + ["search_document"] * len(listHover),
    )

    col1, col2 = st.columns([2, 2])
    col2.plotly_chart(fig, use_container_width=True)
    col1.dataframe(df)

# Dataset attribution, shown on every run.
st.caption("Dataset Credit : kaggle")
|
|
|
|
|
|
|
|
|
|
|
|