
bert-base-uncased model fine-tuned on SST-2

This model was created using the nn_pruning python library: the linear layers contain 37% of the original weights.

The model contains 51% of the original weights overall (the embeddings account for a significant part of the model, and they are not pruned by this method).
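As a rough sanity check, the 51% overall figure can be reproduced from the 37% linear-layer density. The parameter counts below are approximate values for bert-base-uncased (my assumption, not figures stated on this card):

```python
# Back-of-the-envelope check of the overall density.
# Assumed approximate bert-base-uncased parameter counts:
embedding_params = 23.8e6   # embeddings: not pruned by this method
encoder_params = 85.0e6     # encoder linear layers: pruned to 37% density
total_original = embedding_params + encoder_params

remaining = embedding_params + 0.37 * encoder_params
overall_density = remaining / total_original
print(f"overall density ~ {overall_density:.0%}")  # roughly 51%
```

The embeddings dominate enough of the parameter budget that even aggressive pruning of the linear layers leaves about half the weights in place.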

In terms of performance, its accuracy is 91.17.

Fine-Pruning details

This model was fine-tuned from the HuggingFace bert-base-uncased checkpoint on the SST-2 task, and distilled from the model textattack/bert-base-uncased-SST-2. This model is case-insensitive: it does not make a difference between english and English.

A side-effect of the block pruning method is that some of the attention heads are removed entirely: 88 of the 144 heads (61.1%) were pruned away. Here is a detailed view of how the remaining heads are distributed in the network after pruning.
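The head-removal ratio follows directly from the bert-base architecture (12 layers of 12 heads each):

```python
# bert-base-uncased: 12 transformer layers, 12 attention heads per layer
total_heads = 12 * 12
removed_heads = 88
removed_fraction = removed_heads / total_heads
print(f"{removed_fraction:.1%} of heads removed")  # 61.1%
```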

Details of the SST-2 dataset

| Dataset | Split | # samples |
|---------|-------|-----------|
| SST-2   | train | 67K       |
| SST-2   | eval  | 872       |

Results

PyTorch model file size: 351MB (original BERT: 420MB)

| Metric   | Value | Original (Table 2) | Variation |
|----------|-------|--------------------|-----------|
| accuracy | 91.17 | 92.7               | -1.53     |

Example Usage

Install nn_pruning: it contains the optimization script, which simply packs the linear layers into smaller ones by removing empty rows/columns.

pip install nn_pruning
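The row/column packing idea can be illustrated in isolation. This is a toy NumPy sketch of the principle, not the nn_pruning implementation; it ignores biases and activations, which the real library accounts for:

```python
import numpy as np

# If row i of W1 is all zeros, output feature i of the first linear layer
# is always zero, so that row (and the matching column of the next layer's
# weight W2) can be dropped without changing the computation.
rng = np.random.default_rng(0)
W1 = rng.standard_normal((4, 3))
W1[1] = 0.0                      # a pruned row: feature 1 is dead
W2 = rng.standard_normal((2, 4))

x = rng.standard_normal(3)
full = W2 @ (W1 @ x)             # original two-layer computation

keep = ~np.all(W1 == 0.0, axis=1)        # surviving rows of W1
packed = W2[:, keep] @ (W1[keep] @ x)    # smaller matrices, same output

print(np.allclose(full, packed))  # True
```

The pruned model stores genuinely smaller dense matrices, which is why the optimized model is both smaller on disk and faster at inference.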

Then you can use the transformers library almost as usual: you just have to call optimize_model once the pipeline has been loaded.

from transformers import pipeline
from nn_pruning.inference_model_patcher import optimize_model

# Load the pruned checkpoint as a standard text-classification pipeline
cls_pipeline = pipeline(
    "text-classification",
    model="echarlaix/bert-base-uncased-sst2-acc91.1-d37-hybrid",
    tokenizer="echarlaix/bert-base-uncased-sst2-acc91.1-d37-hybrid",
)

print(f"Parameters count (includes only head pruning, no feed forward pruning)={int(cls_pipeline.model.num_parameters() / 1E6)}M")

# Pack the sparse linear layers into smaller dense ones
cls_pipeline.model = optimize_model(cls_pipeline.model, "dense")
print(f"Parameters count after optimization={int(cls_pipeline.model.num_parameters() / 1E6)}M")

predictions = cls_pipeline("This restaurant is awesome")
print(predictions)
