"""Votek Retriever."""

import json
import os
import random
from collections import defaultdict
from typing import Optional

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from opencompass.openicl.icl_retriever.icl_topk_retriever import TopkRetriever


class VotekRetriever(TopkRetriever):
    """Vote-k In-context Learning Retriever, subclass of `TopkRetriever`.

    Every candidate in the embedding pool votes for its ``votek_k`` most
    cosine-similar other candidates. Candidates are ranked by the number of
    votes they receive; a candidate whose voters overlap too much with those
    of an already selected candidate is skipped, and any remaining slots are
    filled uniformly at random. The same ``ice_num`` indices are returned for
    every test example.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset: Dataset forwarded to ``TopkRetriever``.
        ice_separator, ice_eos_token, ice_num,
        sentence_transformers_model_name, tokenizer_name, batch_size:
            Forwarded positionally to ``TopkRetriever.__init__``.
        votek_k (int): Number of nearest neighbours each candidate votes for.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 sentence_transformers_model_name: Optional[
                     str] = 'all-mpnet-base-v2',
                 tokenizer_name: Optional[str] = 'gpt2-xl',
                 batch_size: Optional[int] = 1,
                 votek_k: Optional[int] = 3) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num,
                         sentence_transformers_model_name, tokenizer_name,
                         batch_size)
        self.votek_k = votek_k

    def votek_select(self,
                     embeddings=None,
                     select_num=None,
                     k=None,
                     overlap_threshold=None,
                     vote_file=None):
        """Select ``select_num`` highly-voted, mutually diverse indices.

        Args:
            embeddings: 2-D array with one row per candidate (each row is
                reshaped to ``(1, -1)`` before scoring).
            select_num (int): Number of indices to return.
            k (int): Number of nearest neighbours (excluding itself) that each
                candidate votes for.
            overlap_threshold (float): A candidate is skipped when, for some
                already selected candidate, ``|voters & selected_voters| >=
                overlap_threshold * |voters|``. With ``1`` a candidate is
                skipped when its voters are a subset of a selected one's.
            vote_file (str, optional): JSON cache for the vote statistics.
                Loaded when the file exists, otherwise written after voting.

        Returns:
            list[int]: ``select_num`` distinct candidate indices.

        Raises:
            ValueError: If there are fewer candidates than ``select_num``.
        """
        n = len(embeddings)
        vote_stat = self._load_or_collect_votes(embeddings, k, vote_file)
        selected_indices = self._select_diverse(vote_stat, select_num,
                                                overlap_threshold)
        if len(selected_indices) < select_num:
            selected_indices += self._sample_remaining(
                n, selected_indices, select_num - len(selected_indices))
        return selected_indices

    @staticmethod
    def _load_or_collect_votes(embeddings, k, vote_file):
        """Return ``{candidate: [voter, ...]}``, using ``vote_file`` as cache."""
        if vote_file is not None and os.path.isfile(vote_file):
            # JSON object keys come back as strings; callers apply int().
            with open(vote_file, encoding='utf-8') as f:
                return json.load(f)

        vote_stat = defaultdict(list)
        for i in range(len(embeddings)):
            cur_emb = embeddings[i].reshape(1, -1)
            cur_scores = cosine_similarity(embeddings, cur_emb).ravel()
            order = np.argsort(cur_scores)
            # Remove the voter itself by index rather than assuming it is the
            # top-scoring row: with duplicate embeddings another row may tie
            # or outrank it, which would cost this voter one neighbour.
            neighbours = order[order != i]
            start = max(len(neighbours) - k, 0)
            # tolist() yields plain ints, which json.dump accepts as keys.
            for idx in neighbours[start:].tolist():
                vote_stat[idx].append(i)

        if vote_file is not None:
            with open(vote_file, 'w', encoding='utf-8') as f:
                json.dump(vote_stat, f)
        return vote_stat

    @staticmethod
    def _select_diverse(vote_stat, select_num, overlap_threshold):
        """Greedily take the most-voted candidates, skipping redundant ones."""
        # sorted() is stable even with reverse=True, so equally-voted
        # candidates keep the insertion order of ``vote_stat``.
        votes = sorted(vote_stat.items(),
                       key=lambda x: len(x[1]),
                       reverse=True)
        selected_indices = []
        selected_voter_sets = []
        for idx, voters in votes:
            if len(selected_indices) >= select_num:
                break
            candidate_set = set(voters)
            # Compare only against candidates that were actually selected:
            # overlap with a skipped candidate says nothing about the
            # diversity of the final selection.
            redundant = any(
                len(candidate_set & chosen) >= overlap_threshold *
                len(candidate_set) for chosen in selected_voter_sets)
            if redundant:
                continue
            selected_indices.append(int(idx))
            selected_voter_sets.append(candidate_set)
        return selected_indices

    @staticmethod
    def _sample_remaining(n, selected_indices, num):
        """Uniformly sample ``num`` indices from ``range(n)`` not yet taken."""
        taken = set(selected_indices)
        unselected_indices = [i for i in range(n) if i not in taken]
        if num > len(unselected_indices):
            raise ValueError(
                f'Vote-k cannot select {len(selected_indices) + num} '
                f'in-context examples from only {n} candidates.')
        return random.sample(unselected_indices, num)

    def vote_k_search(self):
        """Run vote-k once and reuse the selection for every test example."""
        # NOTE(review): ``embed_list`` / ``test_ds`` / ``ice_num`` come from
        # the parent retriever; presumably ``embed_list`` holds the index-set
        # embeddings — confirm against TopkRetriever.
        vote_k_idxs = self.votek_select(embeddings=self.embed_list,
                                        select_num=self.ice_num,
                                        k=self.votek_k,
                                        overlap_threshold=1)
        # One independent list per test example rather than aliasing a
        # single list object.
        return [list(vote_k_idxs) for _ in range(len(self.test_ds))]

    def retrieve(self):
        """Return the selected in-context example indices per test example."""
        return self.vote_k_search()