Cross-Encoder for Natural Language Inference

This model was trained using the SentenceTransformers CrossEncoder class. It is based on microsoft/deberta-v3-xsmall.

Training Data

The model was trained on the SNLI and MultiNLI datasets. For each sentence pair, it outputs three scores, one per label, in this order: contradiction, entailment, neutral.
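The card does not describe the three scores as probabilities. If they are unnormalized logits (an assumption, not stated here), a softmax turns each row into per-label probabilities that sum to 1. A minimal NumPy sketch, using made-up logits rather than real model output:

```python
import numpy as np

# Label order as stated above.
LABEL_MAPPING = ["contradiction", "entailment", "neutral"]

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical logits for one sentence pair (illustrative values, assumed to be logits).
logits = np.array([[-1.0, 2.0, 0.0]])
probs = softmax(logits)
for label, p in zip(LABEL_MAPPING, probs[0]):
    print(f"{label}: {p:.3f}")
```

The softmax does not change which label scores highest, so it is only needed when a confidence value per label is useful.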

Performance

  • Accuracy on the SNLI test set: 91.64%
  • Accuracy on the MNLI mismatched set: 87.77%
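Accuracy here is presumably the standard definition: the percentage of sentence pairs whose highest-scoring label equals the gold label. A minimal sketch of that arithmetic with hypothetical labels (no real evaluation data is used; accuracy is an illustrative helper name):

```python
def accuracy(predicted: list[str], gold: list[str]) -> float:
    """Percentage of positions where the predicted label equals the gold label."""
    if len(predicted) != len(gold) or not gold:
        raise ValueError("need two non-empty lists of equal length")
    correct = sum(p == g for p, g in zip(predicted, gold))
    return 100.0 * correct / len(gold)

# Hypothetical labels for four pairs; three of the four predictions match.
pred = ["entailment", "contradiction", "neutral", "entailment"]
gold = ["entailment", "contradiction", "entailment", "entailment"]
print(accuracy(pred, gold))  # 75.0
```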

For further evaluation results, see SBERT.net - Pretrained Cross-Encoder.

Usage

The pre-trained model can be used like this:

from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/nli-deberta-v3-xsmall')

# Each input is a (premise, hypothesis) pair; predict returns one row of three scores per pair
scores = model.predict([
    ('A man is eating pizza', 'A man eats something'),
    ('A black race car starts up in front of a crowd of people.', 'A man is driving down a lonely road.'),
])

# Convert scores to labels
label_mapping = ['contradiction', 'entailment', 'neutral']
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
print(labels)
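The argmax(axis=1) call above expects a 2-D array with one row per pair. When a single (premise, hypothesis) tuple is passed instead of a list, sentence-transformers may return a 1-D array of three scores, and argmax(axis=1) then raises an error. A defensive sketch that accepts either shape; to_labels is a hypothetical helper name and the score values are made up:

```python
import numpy as np

label_mapping = ['contradiction', 'entailment', 'neutral']

def to_labels(scores) -> list:
    """Return one label per pair for scores shaped (3,) or (n_pairs, 3)."""
    # np.atleast_2d turns a single row of three scores into shape (1, 3).
    scores = np.atleast_2d(np.asarray(scores))
    return [label_mapping[i] for i in scores.argmax(axis=1)]

print(to_labels([0.1, 2.5, -1.0]))                      # single pair
print(to_labels([[0.1, 2.5, -1.0], [3.0, -2.0, 0.5]]))  # two pairs
```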

Usage with Transformers AutoModel

You can also use the model directly with the Transformers library (without the SentenceTransformers library):

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/nli-deberta-v3-xsmall')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/nli-deberta-v3-xsmall')

# First list: premises; second list: hypotheses (paired by position)
features = tokenizer(
    ['A man is eating pizza', 'A black race car starts up in front of a crowd of people.'],
    ['A man eats something', 'A man is driving down a lonely road.'],
    padding=True, truncation=True, return_tensors="pt",
)

model.eval()
with torch.no_grad():
    scores = model(**features).logits  # shape: (number of pairs, 3)

label_mapping = ['contradiction', 'entailment', 'neutral']
labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1).tolist()]
print(labels)
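The tokenizer call above takes premises and hypotheses as two parallel lists, whereas CrossEncoder.predict takes a list of (premise, hypothesis) tuples. If your data is already in tuple form, a minimal standard-library sketch for splitting it (pairs is a hypothetical variable holding the same examples as above):

```python
pairs = [
    ('A man is eating pizza', 'A man eats something'),
    ('A black race car starts up in front of a crowd of people.', 'A man is driving down a lonely road.'),
]

# zip(*pairs) regroups the tuples into (all premises) and (all hypotheses).
# Assumes pairs is non-empty; an empty list would make the unpacking fail.
premises, hypotheses = (list(column) for column in zip(*pairs))

# premises[i] and hypotheses[i] still belong to the same pair, so both lists can be
# passed to tokenizer(premises, hypotheses, padding=True, truncation=True, return_tensors="pt").
print(premises)
print(hypotheses)
```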

Zero-Shot Classification

This model can also be used for zero-shot classification:

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-deberta-v3-xsmall')

sent = "Apple just announced the newest iPhone X"
candidate_labels = ["technology", "sports", "politics"]
res = classifier(sent, candidate_labels)
print(res)
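Under the hood, the Transformers zero-shot pipeline turns each candidate label into a hypothesis (by default with the template "This example is {}."), scores every (text, hypothesis) pair with the NLI model, and, when the labels are mutually exclusive, applies a softmax over the entailment logits across candidate labels. A simplified NumPy sketch of that scoring step, assuming entailment sits at index 1 as in the label order above; the logits are made-up numbers, not output from this model:

```python
import numpy as np

candidate_labels = ["technology", "sports", "politics"]
hypothesis_template = "This example is {}."  # the pipeline's default template
hypotheses = [hypothesis_template.format(label) for label in candidate_labels]

# Hypothetical NLI logits, one row per (sentence, hypothesis) pair,
# columns ordered [contradiction, entailment, neutral].
nli_logits = np.array([
    [-2.0, 3.0, 0.5],   # "This example is technology."
    [1.5, -1.0, 0.8],   # "This example is sports."
    [0.9, -0.5, 1.2],   # "This example is politics."
])
ENTAILMENT_ID = 1  # assumed index of the entailment label

# Single-label case: softmax over the entailment logits across candidate labels.
# Subtracting the max only improves numerical stability; the result is unchanged.
entail = nli_logits[:, ENTAILMENT_ID]
exp = np.exp(entail - entail.max())
scores = exp / exp.sum()

ranking = sorted(zip(candidate_labels, scores), key=lambda pair: pair[1], reverse=True)
for label, score in ranking:
    print(f"{label}: {score:.3f}")
```

With multi_label=True, the pipeline instead applies a softmax over contradiction versus entailment for each label independently, so the label scores no longer need to sum to 1.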