import torch from transformers import AutoTokenizer, AutoModel import os class TextExtractor: def __init__(self, model_name, proxy=None): """ Initialize the TextExtractor with a specified model and optional proxy settings. Parameters: - model_name (str): The name of the pre-trained model to load from HuggingFace Hub. - proxy (str, optional): The proxy address to use for HTTP and HTTPS requests. """ if proxy is None: proxy = 'http://localhost:8234' if proxy: os.environ['HTTP_PROXY'] = proxy os.environ['HTTPS_PROXY'] = proxy try: self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name) except: print('try switch on local_files_only') self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True) self.model = AutoModel.from_pretrained(model_name, local_files_only=True) self.model.eval() def extract(self, sentences): """ Extract sentence embeddings for the provided sentences. Parameters: - sentences (list of str): A list of sentences to extract embeddings for. Returns: - torch.Tensor: The normalized sentence embeddings. """ encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') with torch.no_grad(): model_output = self.model(**encoded_input) sentence_embeddings = model_output[0][:, 0] sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) return sentence_embeddings import pandas as pd def get_qas(excel_file = None): defaule_excel_file = 'data/output_fixid.xlsx' if excel_file is None: excel_file = defaule_excel_file # 读取Excel文件 df = pd.read_excel(excel_file) df = df[df["question"].notna()] df = df[df["summary"].notna()] datas = [] # 遍历DataFrame的每一行 for index, row in df.iterrows(): id = row['id'] question = row['question'] short_answer = row['summary'] category = row['category'] texts = [question, short_answer] data_value = { "texts":texts, } data = { "id":id, "value":data_value } datas.append(data) return datas from tqdm import tqdm def extract_embedding(datas, text_extractor): """ Extract embeddings for each item in the provided data. Parameters: - datas (list of dict): A list of dictionaries containing text data. Returns: - list of dict: The input data with added embeddings. """ for data in tqdm(datas): texts = data["value"]["texts"] text = "。".join(texts) embeddings = text_extractor.extract(text) embeddings_list = embeddings.tolist() # Convert tensor to list of lists data["value"]["embedding"] = embeddings_list return datas def save_parquet(datas, file_path): """ Save the provided data to a Parquet file. Parameters: - datas (list of dict): A list of dictionaries containing text data and embeddings. - file_path (str): The path to the output Parquet file. """ # Flatten the data for easier conversion to DataFrame flattened_data = [] for data in datas: id = data["id"] texts = data["value"]["texts"] text = "。".join(texts) embedding = data["value"]["embedding"] flattened_data.append({ "id": id, "text": text, "embedding": embedding }) # Create DataFrame df = pd.DataFrame(flattened_data) # Save DataFrame to Parquet df.to_parquet(file_path, index=False) import pandas as pd import os def get_id2embedding(regen=False, parquet_file='datas/qa_with_embedding.parquet'): """ Get a dictionary mapping IDs to embeddings. Regenerate embeddings if specified. Parameters: - parquet_file (str): The path to the Parquet file. - regen (bool): Whether to regenerate embeddings. Returns: - dict: A dictionary mapping IDs to list of float embeddings. """ if regen or not os.path.exists(parquet_file): print("Regenerating embeddings...") # Example usage: model_name = 'BAAI/bge-small-zh-v1.5' text_extractor = TextExtractor(model_name) datas = get_qas() print("Extracting embeddings for", len(datas), "data items") datas = extract_embedding(datas, text_extractor) save_parquet(datas, parquet_file) df = pd.read_parquet(parquet_file) id2embedding = {} for index, row in df.iterrows(): id = row['id'] embedding = row['embedding'] id2embedding[id] = embedding[0] return id2embedding import torch from sklearn.metrics.pairwise import cosine_similarity import heapq def __get_id2top30map(id2embedding): """ Get a dictionary mapping IDs to their top 30 nearest neighbors based on cosine similarity. Parameters: - id2embedding (dict): A dictionary mapping IDs to list of float embeddings. Returns: - dict: A dictionary mapping each ID to a list of the top 30 nearest neighbor IDs. """ ids = list(id2embedding.keys()) embeddings = torch.tensor([id2embedding[id] for id in ids]) # Compute cosine similarity matrix cos_sim_matrix = cosine_similarity(embeddings) id2top30map = {} for i, id in enumerate(ids): # Get the similarity scores for the current ID sim_scores = cos_sim_matrix[i] # Get the top 30 indices (excluding the current ID itself) top_indices = heapq.nlargest(31, range(len(sim_scores)), key=lambda x: sim_scores[x]) top_indices.remove(i) # Remove the index of the current ID # Map the indices back to IDs top_30_ids = [ids[idx] for idx in top_indices[:30]] id2top30map[id] = top_30_ids return id2top30map import pickle def get_id2top30map( id2embedding = None ): default_save_pkl = "data/id2top30map.pkl" if id2embedding is None: if os.path.exists(default_save_pkl): with open(default_save_pkl, 'rb') as f: id2top30map = pickle.load(f) else: print("No embedding found, generating new one...") id2embedding = get_id2embedding(regen=False) id2top30map = __get_id2top30map(id2embedding) with open(default_save_pkl, 'wb') as f: pickle.dump(id2top30map, f) else: id2top30map = __get_id2top30map(id2embedding) return id2top30map if __name__ == '__main__': if False: # Example usage: model_name = 'BAAI/bge-small-zh-v1.5' sentences = ["样例数据-1", "样例数据-2"] text_extractor = TextExtractor(model_name) embeddings = text_extractor.extract(sentences) print("Sentence embeddings:", embeddings) datas = get_qas() print("extract embedding for ", len(datas), " datas") datas = extract_embedding(datas, text_extractor ) default_parquet_save_name = "data/qa_with_embedding.parquet" save_parquet(datas, default_parquet_save_name) if True: id2embedding = get_id2embedding(regen=False) print(len(id2embedding[4])) id2top30map = get_id2top30map( None ) print("ID to Top 30 Neighbors dictionary:", id2top30map[4]) if True: start_id = 332 visited_ids = [start_id] current_queue = [start_id] expend_num = 5 for iteration in range(10): current_node = current_queue.pop(0) top30 = id2top30map[current_node] current_expend = [] for id in top30: if id not in visited_ids: visited_ids.append(id) current_queue.append(id) current_expend.append(id) if len(current_expend) >= expend_num: break display_text = f"{current_node} | ->" + ",".join([str(i) for i in current_expend]) print(display_text) from get_qa_and_image import get_qa_and_image image_datas = get_qa_and_image() id2index = {} for i, data in enumerate(image_datas): id2index[data['id']] = i indexes = [id2index[i] for i in visited_ids if i in id2index] image_names = [image_datas[index]['value']['image'] for index in indexes] target_copy_folder = "data/asso_collection" import shutil # copy image into target_copy_folder for image_name in image_names: shutil.copy(image_name, target_copy_folder)