# fashion_classification/data_utils.py
# Author: Ceyda Cinarel. Commit message: "cached demo" (revision c8f27cf).
from datasets import load_dataset
from PIL import Image
import os
import pandas as pd
from transformers import AutoFeatureExtractor,AutoModel
from faiss.contrib.inspect_tools import get_flat_data
import pymde
import numpy as np
def get_embedding(model_name, viz_dat, *, index_dir="./indexes", batch_size=20):
    """Return a 2-D layout of ``viz_dat`` image embeddings plus row metadata.

    Image embeddings produced by ``model_name`` are cached as a FAISS index,
    and the 2-D pymde ``preserve_neighbors`` layout of those embeddings is
    cached as a ``.npy`` file. Both files live under ``index_dir`` and are
    named after the last ``/``-separated component of ``model_name``. Cached
    files are reused when present; otherwise they are computed and written.

    Args:
        model_name: Model id or path passed to ``from_pretrained``,
            e.g. ``"org/model"``.
        viz_dat: Dataset with ``image``, ``gender``, ``masterCategory`` and
            ``subCategory`` columns (all four are read below).
        index_dir: Directory holding the cached files; created if missing.
        batch_size: Batch size used when computing embeddings.

    Returns:
        pandas.DataFrame with columns ``x``, ``y``, ``image``, ``gender``,
        ``masterCategory`` and ``subCategory``.

    NOTE(review): cache files are keyed by model name only. If the contents
    of ``viz_dat`` change, stale files in ``index_dir`` must be removed by
    hand or the old embeddings/layout will be reused.
    """
    # rsplit keeps "org/model" -> "model" exactly as before, and no longer
    # raises IndexError for bare model names without an "org/" prefix.
    short_name = model_name.rsplit("/", 1)[-1]
    # save_faiss_index / np.save below fail if the directory does not exist.
    os.makedirs(index_dir, exist_ok=True)

    index_file = f"{index_dir}/{short_name}.faiss"
    if os.path.exists(index_file):
        viz_dat.load_faiss_index("embeddings", index_file)
    else:
        import torch  # only needed when embeddings have to be computed

        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        # NOTE: for GPU inference, move both the model and the inputs to the
        # device, e.g. model.to("cuda:0") and inputs.to("cuda:0").

        def embed(batch):
            inputs = feature_extractor(images=batch["image"], return_tensors="pt")
            # Inference only: skip autograd bookkeeping, and do not request
            # per-layer hidden states since only pooler_output is used.
            with torch.no_grad():
                outputs = model(**inputs)
            # This line depends on the model you are using; models without a
            # pooling head may not provide ``pooler_output``.
            batch["embeddings"] = outputs.pooler_output.cpu().numpy()
            return batch

        # Add embeddings to the dataset, index them, and cache the index.
        viz_dat = viz_dat.map(embed, batched=True, batch_size=batch_size)
        viz_dat.add_faiss_index(column="embeddings")
        viz_dat.save_faiss_index("embeddings", index_file)

    embedding_file = f"{index_dir}/{short_name}.npy"
    if os.path.exists(embedding_file):
        coords = np.load(embedding_file)
    else:
        index = viz_dat.get_index("embeddings").faiss_index
        # Recover the raw vectors stored in the FAISS index
        # (get_flat_data expects a flat index).
        vectors = get_flat_data(index)
        coords = pymde.preserve_neighbors(vectors, verbose=True).embed().numpy()
        np.save(embedding_file, coords)

    embedding = pd.DataFrame(coords, columns=["x", "y"])
    for column in ("image", "gender", "masterCategory", "subCategory"):
        embedding[column] = viz_dat[column]
    return embedding