---
language:
- zh
base_model: junnyu/roformer_chinese_base
tags:
- transformers
---

## <u>INF</u> <u>W</u>ord-level <u>S</u>parse <u>E</u>mbedding (INF-WSE)

**INF-WSE** is a series of word-level sparse embedding models developed by [INFLY TECH](https://www.infly.cn/en). These models are optimized to generate sparse, high-dimensional text embeddings that excel at capturing the most relevant information for search and retrieval, particularly for Chinese text.

### Key Features

- **Optimized for Retrieval**: INF-WSE is designed with retrieval tasks in mind. Its sparse embeddings enable efficient matching between queries and documents, making it well suited to semantic search, ranking, and other information retrieval scenarios where both speed and accuracy are critical.
- **Word-level Sparse Embeddings**: The model produces sparse representations at the word level, capturing the semantic details that most affect the relevance of search results. This is particularly useful for Chinese retrieval, where word segmentation can significantly affect performance.
- **Sparse Representation for Efficiency**: Dense embeddings use a small, fixed number of dimensions that are typically all non-zero. INF-WSE instead produces embeddings whose dimensionality equals the vocabulary size, with most dimensions set to zero so that only the most significant terms carry weight. Because only the non-zero terms need to be stored and compared, this sparsity reduces the computational load and enables faster retrieval without compromising precision.
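
The efficiency argument above can be illustrated with a small, dependency-free sketch. It assumes each embedding is stored as a mapping from token to weight that holds only the non-zero dimensions (an illustrative storage choice, not the model's output format). The dot product then only visits tokens that both vectors share, rather than every vocabulary dimension. The function `sparse_dot` and the weights below are hypothetical.

```python
def sparse_dot(query: dict[str, float], doc: dict[str, float]) -> float:
    """Dot product of two sparse vectors stored as {token: weight} dicts.

    Tokens missing from a dict are implicitly zero, so only shared tokens contribute.
    """
    # Iterate over the smaller mapping to minimise dictionary lookups
    if len(query) > len(doc):
        query, doc = doc, query
    return sum(weight * doc[token] for token, weight in query.items() if token in doc)


# Hypothetical weights, for illustration only (not real model outputs)
query = {"掌上": 3.0, "电脑": 2.5, "什么": 1.0}
doc = {"掌上": 2.0, "电脑": 1.5, "设备": 0.8}
print(sparse_dot(query, doc))  # 9.75  (3.0*2.0 + 2.5*1.5)
```

The cost grows with the number of non-zero terms in the smaller vector, not with the vocabulary size.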

## Usage

### Transformers

#### Infer embeddings

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Queries: "What does an all-in-one PC consist of?", "What is a handheld PC?"
queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
# Documents describing an all-in-one PC and a handheld PC, respectively
documents = [
    '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
    '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
]
input_texts = queries + documents

# The fast tokenizer is not supported yet, so use_fast=False is required
tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)
model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
model.eval()

max_length = 512

input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    # return_sparse=True returns a sparse tensor; return_sparse=False returns a dense tensor
    # with one vocabulary-sized row per input text
    embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)

# Relevance scores are dot products between query and document embeddings
scores = embeddings[:2] @ embeddings[2:].T
print(scores.tolist())
# [[21.224790573120117, 4.520412921905518], [10.290857315063477, 19.359437942504883]]
```
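
Since the scores above are plain dot products over vocabulary-sized vectors, documents can be served from an inverted index that maps each token to the documents in which it has a non-zero weight; a query then only touches the postings lists of its own tokens. The following is a minimal in-memory sketch of that idea in plain Python. It is not how Qdrant or this model's code implements indexing, and `build_inverted_index`, `search`, and the toy weights are hypothetical.

```python
from collections import defaultdict


def build_inverted_index(docs: list[dict[str, float]]) -> dict[str, list[tuple[int, float]]]:
    """Map each token to a postings list of (doc_id, weight) pairs."""
    index: dict[str, list[tuple[int, float]]] = defaultdict(list)
    for doc_id, weights in enumerate(docs):
        for token, weight in weights.items():
            if weight != 0:
                index[token].append((doc_id, weight))
    return index


def search(index, query: dict[str, float], top_k: int = 10) -> list[tuple[int, float]]:
    """Score documents by dot product, visiting only the query tokens' postings lists."""
    scores: dict[int, float] = defaultdict(float)
    for token, q_weight in query.items():
        for doc_id, d_weight in index.get(token, []):
            scores[doc_id] += q_weight * d_weight
    # Highest score first; ties broken by doc_id so the order is deterministic
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))[:top_k]


# Hypothetical token weights, for illustration only
docs = [
    {"电脑": 1.0, "一体机": 2.0, "显示器": 1.0},
    {"掌上": 2.0, "电脑": 1.5, "设备": 0.5},
]
index = build_inverted_index(docs)
print(search(index, {"掌上": 3.0, "电脑": 2.0}))  # [(1, 9.0), (0, 2.0)]
```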

#### Convert embeddings to lexical weights

```python
from collections import OrderedDict

import torch


def convert_embeddings_to_weights(embeddings, tokenizer):
    # Sort each row's vocabulary weights in descending order
    values, indices = torch.sort(embeddings, dim=-1, descending=True)

    token2weight = []
    for i in range(embeddings.size(0)):
        token2weight.append(OrderedDict())

        # Keep only the non-zero dimensions and map their ids back to tokens
        non_zero_mask = values[i] != 0
        tokens = tokenizer.convert_ids_to_tokens(indices[i][non_zero_mask].tolist())
        weights = values[i][non_zero_mask].tolist()

        for token, weight in zip(tokens, weights):
            token2weight[i][token] = weight

    return token2weight


token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
print(token2weight[1])
# OrderedDict([('掌上', 3.4572525024414062), ('电脑', 2.6253132820129395), ('是', 2.0787220001220703), ('什么', 1.2899624109268188)])
```
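
Because zero dimensions contribute nothing to a dot product, the token-to-weight mappings of a query and a document give the same score as their dense vectors (up to floating-point rounding). The dependency-free sketch below mirrors `convert_embeddings_to_weights` for a single vector, using a toy vocabulary list in place of the tokenizer, and checks that equivalence. `vector_to_weights`, the vocabulary, and the vectors are hypothetical.

```python
from collections import OrderedDict


def vector_to_weights(vector: list[float], vocab: list[str]) -> "OrderedDict[str, float]":
    """Map the non-zero dimensions of a vocabulary-sized vector to {token: weight},
    sorted by descending weight (a plain-Python analogue of the torch version)."""
    nonzero = [(vocab[i], w) for i, w in enumerate(vector) if w != 0]
    return OrderedDict(sorted(nonzero, key=lambda item: item[1], reverse=True))


# Toy vocabulary and vectors (hypothetical, not real model outputs)
vocab = ["[PAD]", "什么", "是", "掌上", "电脑", "设备"]
query_vec = [0.0, 1.0, 0.5, 3.0, 2.0, 0.0]
doc_vec = [0.0, 0.0, 0.0, 2.0, 1.5, 0.5]

q_weights = vector_to_weights(query_vec, vocab)
d_weights = vector_to_weights(doc_vec, vocab)
print(q_weights)

# The dot product over shared tokens equals the dense dot product
dense_score = sum(q * d for q, d in zip(query_vec, doc_vec))
sparse_score = sum(w * d_weights[t] for t, w in q_weights.items() if t in d_weights)
print(dense_score, sparse_score)  # 9.0 9.0
```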

## Evaluation

### C-MTEB Retrieval task

([Chinese Massive Text Embedding Benchmark](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB))

Metric: nDCG@10
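
nDCG@10 divides the discounted cumulative gain (DCG) of the top 10 retrieved documents by the DCG of an ideal ranking, so a perfect ranking scores 1. The sketch below uses the common linear-gain form, DCG@k = sum of rel_i / log2(i + 1) over ranks i = 1..k. The exact gain and tie-handling conventions of the toolkit that produced the numbers below are an assumption here, so this illustrates the metric rather than reproducing the benchmark's scorer. `ndcg_at_k` and the judgements are hypothetical.

```python
import math


def ndcg_at_k(ranked_ids: list[str], qrels: dict[str, int], k: int = 10) -> float:
    """nDCG@k with linear gains: DCG@k = sum(rel_i / log2(i + 1)) for ranks i = 1..k."""
    def dcg(relevances: list[int]) -> float:
        return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))

    # Gains of the retrieved documents in ranked order (unjudged documents count as 0)
    actual = dcg([qrels.get(doc_id, 0) for doc_id in ranked_ids])
    # Ideal DCG: all judged relevance labels sorted from most to least relevant
    ideal = dcg(sorted(qrels.values(), reverse=True))
    return actual / ideal if ideal > 0 else 0.0


# Hypothetical ranking and relevance judgements for a single query
qrels = {"d1": 1, "d3": 1}
print(round(ndcg_at_k(["d1", "d2", "d3"], qrels), 4))  # 0.9197
```

A benchmark score such as those in the table is the mean of this per-query value over all queries in a dataset.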

| Model Name | Max Length | Average | Cmedqa | Covid | Du | Ecom | Medical | MMarco | T2 | Video |
|:---------------------------------------------------:|:----------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| [BM25-zh](https://github.com/castorini/pyserini) | - | 25.39 | 13.70 | **86.66** | 13.68 | 11.49 | 15.48 | 6.56 | 29.53 | 25.98 |
| [bge-m3-sparse](https://huggingface.co/BAAI/bge-m3) | 512 | 29.94 | **24.50** | 76.16 | 22.12 | 17.62 | 27.52 | 9.78 | **37.69** | 24.12 |
| **inf-wse-v1-base-zh** | 512 | **32.83** | 20.51 | 76.40 | **36.77** | **19.97** | **28.61** | **13.32** | 36.81 | **30.25** |

All results except BM25 are measured by building the sparse index with [Qdrant](https://github.com/qdrant/qdrant).
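
Sparse indexes such as Qdrant's store each vector as parallel lists of non-zero indices and their values (in the Python client this is the `models.SparseVector(indices=..., values=...)` type). The sketch below shows only that conversion step for one dense, vocabulary-sized row such as a row of the `return_sparse=False` output. It assumes the row has already been turned into a plain Python list (for example with `.tolist()`) and does not call any Qdrant API; `to_sparse_pairs` and the sample row are hypothetical.

```python
def to_sparse_pairs(row: list[float]) -> tuple[list[int], list[float]]:
    """Return (indices, values) for the non-zero entries of a dense vocabulary-sized row.

    Indices are token ids in ascending order, matching the row's positions.
    """
    indices = [i for i, value in enumerate(row) if value != 0]
    values = [row[i] for i in indices]
    return indices, values


# Hypothetical dense row over a 6-token vocabulary (illustration only)
row = [0.0, 1.25, 0.0, 3.5, 0.0, 0.75]
indices, values = to_sparse_pairs(row)
print(indices, values)  # [1, 3, 5] [1.25, 3.5, 0.75]
```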