PrabakaranC's picture
Update app.py
eb5bdea verified
raw
history blame contribute delete
No virus
4.42 kB
import streamlit as st
import pandas as pd
import plotly.express as px
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.manifold import TSNE
import plotly.express as px
import torch
import plotly.io as pio
# Default template applied to Plotly figures created in this app.
pio.templates.default = "plotly"
# Page-level configuration is issued before any other st.* command.
st.set_page_config(layout="wide")
# Markdown emphasis only opens/closes when the underscores touch the text;
# the previous "_ ... _" (with inner spaces) rendered literal underscores.
st.header("Explore the Russian Dolls :nesting_dolls: - _:green[Nomic Embed 1.5]_", divider='violet')
st.write("Matryoshka Representation Learning : to learn more :https://aniketrege.github.io/blog/2024/mrl/")
@st.cache_data
def get_df():
    """Return the product catalogue read from ``./sample_products.csv``.

    Decorated with ``st.cache_data`` so Streamlit reuses the parsed frame
    across reruns instead of re-reading the CSV each time.
    """
    return pd.read_csv("./sample_products.csv")
@st.cache_resource
def get_nomicModel():
    """Return the ``nomic-ai/nomic-embed-text-v1.5`` SentenceTransformer.

    ``st.cache_resource`` keeps a single loaded instance for the app rather
    than reloading the weights on every rerun. ``trust_remote_code=True``
    lets the model repository's own code run during loading.
    """
    return SentenceTransformer(
        "nomic-ai/nomic-embed-text-v1.5",
        trust_remote_code=True,
    )
def get_searchQueryEmbedding(query):
    """Embed a single search query with the module-level ``model``.

    The text is prefixed with ``"search_query: "`` before encoding, and the
    resulting tensor (one row, since one string is encoded) is layer-normed
    over its embedding dimension. No truncation or L2 normalisation is done
    here.
    """
    prefixed_batch = ["search_query: " + query]
    raw = model.encode(prefixed_batch, convert_to_tensor=True)
    embedding_width = raw.shape[1]
    return F.layer_norm(raw, normalized_shape=(embedding_width,))
def get_normEmbed(query_embedding, loaded_embed, matryoshka_dim):
    """Truncate both embedding matrices to ``matryoshka_dim`` columns and L2-normalise rows.

    Args:
        query_embedding: 2-D tensor of query embeddings.
        loaded_embed: 2-D tensor of document embeddings.
        matryoshka_dim: Number of leading columns to keep.

    Returns:
        ``(query_embedNorm, loaded_embedNorm)`` with every row of unit L2 norm.
    """
    def _truncate_unit(matrix):
        # Keep the leading Matryoshka prefix, then rescale each row to length 1.
        return F.normalize(matrix[:, :matryoshka_dim], p=2, dim=1)

    return _truncate_unit(query_embedding), _truncate_unit(loaded_embed)
def insert_line_breaks(text, interval=30):
    """Wrap ``text`` for Plotly hover labels by inserting ``<br>`` tags.

    Words come from splitting ``text`` on single spaces. A line is closed
    once its running length (each word counted with one trailing space)
    reaches ``interval``; a single long word is never split.

    Args:
        text: Text to wrap. Non-string values (e.g. NaN from a missing cell)
            are converted with ``str()`` instead of raising on ``.split``.
        interval: Approximate line width in characters.

    Returns:
        The lines joined with ``<br>`` and stripped of surrounding
        whitespace. No ``<br>`` is emitted after the final line, so the
        hover label never ends with an empty line.
    """
    lines = []
    current_words = []
    line_length = 0
    for word in str(text).split(' '):
        current_words.append(word)
        line_length += len(word) + 1  # +1 for the separating space
        if line_length >= interval:
            lines.append(' '.join(current_words))
            current_words = []
            line_length = 0
    if current_words:
        lines.append(' '.join(current_words))
    return '<br>'.join(lines).strip()
# --- One-time setup: model, product descriptions, document embeddings ---
# `model` must stay a module-level name: get_searchQueryEmbedding reads it.
model = get_nomicModel()
# NOTE: despite its name, bigDollEmbedding holds the "Description" column
# (text), which feeds the results table and the plot hover labels.
productFrame = get_df()
bigDollEmbedding = productFrame["Description"]
# Precomputed document embeddings; presumably one row per CSV row in the
# same order as bigDollEmbedding -- TODO confirm against the export script.
rawDocEmbedding = np.load("./prodBigDollEmbeddings.npy")
docEmbedding = torch.Tensor(rawDocEmbedding)
# The toggle sits outside the form, so flipping it reruns the app and swaps
# the input widget between canned sample queries and free text.
toggle = st.toggle('sample queries')
with st.form("my_form"):
    if toggle:
        query_input = st.selectbox(
            'select a query:',
            (
                'Pack of two assorted boxers, has two pockets, an elasticated waistbandDisclaimer: The final product delivered might vary in colour and prints from the display here.',
                'Beige self design shoulder bag, has a zip closure1 main compartment, 3 inner pocketsTwo Handles',
                'Set Content: 1 photo frameColour: Black and whiteFrame Pattern: SolidShape: SquareMaterial: Acrylic',
                'A pair of dark grey solid boxers, has a slip-on closure with an elasticated waistband and drawstring, two pocket',
                'Red & Black solid sweatshirt, has a hood, two pockets, long sleeves, zip closure, straight hem',
            ),
        )
    else:
        # Streamlit warns on empty widget labels (accessibility); give a real
        # label and collapse it so the form looks as it did before.
        query_input = st.text_input("Search query", label_visibility="collapsed")
    # Number of leading embedding dimensions kept for search (64..768).
    Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)
    submitted = st.form_submit_button("Submit")
if submitted and not (query_input or "").strip():
    # An empty query would only be compared against the bare
    # "search_query: " prefix, giving meaningless neighbours.
    st.warning("Please enter a query before submitting.")
elif submitted:
    # The model may run on a GPU while the .npy embeddings were loaded on
    # the CPU; put the query on the same device as the document matrix.
    queryEmbedding = get_searchQueryEmbedding(query_input).to(docEmbedding.device)
    # Keep the first Matry_dim dimensions of both sides, L2-normalised.
    query_embedNorm, loaded_embedNorm = get_normEmbed(queryEmbedding, docEmbedding, Matry_dim)

    # Rows are unit length, so the dot product is cosine similarity.
    similarity_scores = torch.matmul(query_embedNorm, loaded_embedNorm.T)
    # torch.topk raises when k exceeds the number of documents.
    top_k = min(10, similarity_scores.shape[1])
    top_values, top_indices = torch.topk(similarity_scores, top_k, dim=1)

    # Convert the Series to a list once, not once per retrieved index.
    descriptions = bigDollEmbedding.tolist()
    top_items_per_query = [descriptions[index] for index in top_indices[0].tolist()]
    # Plain Python floats for the Score column instead of a torch tensor.
    df = pd.DataFrame({"Product": top_items_per_query, "Score": top_values[0].tolist()})
    df["Product"] = df["Product"].str.replace("search_document:", "", regex=False).str.strip()

    # Project the query and every document into 2-D for the scatter plot.
    allEmbedd = torch.concat([query_embedNorm, loaded_embedNorm]).detach().cpu().numpy()
    # t-SNE needs perplexity < n_samples; 30.0 is sklearn's default.
    tsne = TSNE(n_components=2, random_state=0, perplexity=min(30.0, allEmbedd.shape[0] - 1))
    projections = tsne.fit_transform(allEmbedd)

    listHover = [insert_line_breaks(hover_text, 30) for hover_text in descriptions]
    fig = px.scatter(
        projections, x=0, y=1,
        hover_name=[query_input] + listHover,
        # Point 0 is the query, the rest are documents; sized from the data
        # rather than a hard-coded catalogue size.
        color=["search_query"] + ["search_document"] * len(listHover),
    )

    col1, col2 = st.columns([2, 2])
    col2.plotly_chart(fig, use_container_width=True)
    col1.dataframe(df)
st.caption("Dataset Credit : kaggle")