SamuelYang committed
Commit ea48bc6
1 Parent(s): 1eaba88

Update README.md

Files changed (1)
  1. README.md +8 -27
README.md CHANGED
@@ -6,7 +6,7 @@ tags:
  - transformers
  ---
 
- ## INF Word-level Sparse Embedding (INF-WSE)
+ ## <u>INF</u> <u>W</u>ord-level <u>S</u>parse <u>E</u>mbedding (INF-WSE)
 
  **INF-WSE** is a series of word-level sparse embedding models developed by [INFLY TECH](https://www.infly.cn/en).
  These models are optimized to generate sparse, high-dimensional text embeddings that excel in capturing the most
@@ -29,7 +29,7 @@ relevant information for search and retrieval, particularly in Chinese text.
 
  ### Transformers
 
- #### Infer Embeddings
+ #### Infer embeddings
  ```python
  import torch
  from transformers import AutoTokenizer, AutoModel
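
The hunk above truncates the `Infer embeddings` snippet right after its imports. For orientation, here is a minimal sketch of how the rest of that snippet plausibly continues, assembled from the loading, tokenization, and forward calls visible in the removed lines of the next hunk; the final inner-product scoring step is an assumption, not something shown in the diff.

```python
import torch
from transformers import AutoTokenizer, AutoModel

queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
documents = [
    '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
    '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
]
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)
model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
model.eval()

max_length = 512
input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")

with torch.no_grad():
    # return_sparse=False yields dense, vocabulary-sized weight vectors.
    embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)

# Assumption: relevance is the inner product between query and document weight vectors.
scores = embeddings[:2] @ embeddings[2:].T
print(scores.tolist())
```

The `return_sparse=False` call appears to return dense, vocabulary-sized weight vectors, which is what the `convert_embeddings_to_weights` helper further down consumes.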
@@ -58,31 +58,10 @@ print(scores.tolist())
 
  #### Convert embeddings to lexical weights
  ```python
- import torch
- from transformers import AutoTokenizer, AutoModel
  from collections import OrderedDict
-
- queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
- documents = [
-     '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
-     '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
- ]
- input_texts = queries + documents
-
- tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)
- model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
- model.eval()
-
- max_length = 512
-
- input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")
-
- with torch.no_grad():
-     embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)
-
  def convert_embeddings_to_weights(embeddings, tokenizer):
      values, indices = torch.sort(embeddings, dim=-1, descending=True)
-
+
      token2weight = []
      for i in range(embeddings.size(0)):
          token2weight.append(OrderedDict())
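
The hunk boundary elides the middle of `convert_embeddings_to_weights`. A self-contained sketch of the whole helper, with the elided loop body filled in under two assumptions (only strictly positive weights are kept, and ids are mapped back to surface tokens with `tokenizer.convert_ids_to_tokens`), could look like this:

```python
import torch
from collections import OrderedDict

def convert_embeddings_to_weights(embeddings, tokenizer):
    # Sort each vocabulary-sized row from the largest weight to the smallest.
    values, indices = torch.sort(embeddings, dim=-1, descending=True)

    token2weight = []
    for i in range(embeddings.size(0)):
        token2weight.append(OrderedDict())
        for value, index in zip(values[i], indices[i]):
            if value <= 0:
                break  # remaining sorted weights are non-positive (assumed cutoff)
            token = tokenizer.convert_ids_to_tokens(int(index))  # assumed id-to-token mapping
            token2weight[i][token] = float(value)
    return token2weight
```

This reproduces the shape of the printed example below: one `OrderedDict` per input text, mapping a handful of word-level tokens to descending weights.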
@@ -97,14 +76,14 @@ def convert_embeddings_to_weights(embeddings, tokenizer):
      return token2weight
 
  token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
- print(token2weight[0])
-
- # OrderedDict([('一体机', 3.3438382148742676), ('由', 2.493837356567383), ('电脑', 2.0291812419891357), ('构成', 1.986171841621399), ('什么', 1.0218793153762817)])
+ print(token2weight[1])
+ # OrderedDict([('掌上', 3.4572525024414062), ('电脑', 2.6253132820129395), ('是', 2.0787220001220703), ('什么', 1.2899624109268188)])
  ```
 
  ## Evaluation
 
  ### C-MTEB Retrieval task
+
  ([Chinese Massive Text Embedding Benchmark](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB))
 
  Metric: nDCG@10
@@ -114,3 +93,5 @@ Metric: nDCG@10
  | [BM25-zh](https://github.com/castorini/pyserini) | - | 25.39 | 13.70 | **86.66** | 13.68 | 11.49 | 15.48 | 6.56 | 29.53 | 25.98 |
  | [bge-m3-sparse](https://huggingface.co/BAAI/bge-m3) | 512 | 29.94 | **24.50** | 76.16 | 22.12 | 17.62 | 27.52 | 9.78 | **37.69** | 24.12 |
  | **inf-wse-v1-base-zh** | 512 | **32.83** | 20.51 | 76.40 | **36.77** | **19.97** | **28.61** | **13.32** | 36.81 | **30.25** |
+
+ All results, except for BM25, are measured by building the sparse index via [Qdrant](https://github.com/qdrant/qdrant).
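
The new closing note attributes the non-BM25 numbers to a sparse index built with Qdrant. As an illustration only (none of this comes from the model card), here is a minimal sketch that indexes toy lexical weights as Qdrant sparse vectors with the `qdrant-client` package; the collection name, vector name, token ids, and weights are all placeholders, and qdrant-client >= 1.7 is assumed for sparse-vector support.

```python
from qdrant_client import QdrantClient, models

# Toy lexical weights (token_id -> weight) standing in for real INF-WSE outputs.
doc_weights = [
    {101: 2.6, 2338: 1.9, 45: 0.7},   # document 0
    {101: 1.1, 987: 2.2, 512: 0.4},   # document 1
]
query_weights = {101: 1.8, 2338: 1.2}

client = QdrantClient(":memory:")  # in-process instance, enough for a demo
client.create_collection(
    collection_name="inf_wse_demo",
    vectors_config={},  # no dense vectors, sparse only
    sparse_vectors_config={"lexical": models.SparseVectorParams()},
)

client.upsert(
    collection_name="inf_wse_demo",
    points=[
        models.PointStruct(
            id=i,
            vector={"lexical": models.SparseVector(
                indices=list(w.keys()), values=list(w.values())
            )},
        )
        for i, w in enumerate(doc_weights)
    ],
)

hits = client.search(
    collection_name="inf_wse_demo",
    query_vector=models.NamedSparseVector(
        name="lexical",
        vector=models.SparseVector(
            indices=list(query_weights.keys()),
            values=list(query_weights.values()),
        ),
    ),
    limit=2,
)
print([(hit.id, round(hit.score, 2)) for hit in hits])
```

Qdrant scores sparse vectors by the dot product over shared token ids, so the ranking reflects the same lexical weights shown in the `convert_embeddings_to_weights` example above.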