Adapters
cyn
File size: 458 Bytes
35b487a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import functools

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

@functools.lru_cache(maxsize=4)
def _load(model_name):
    """Load the tokenizer/model pair for ``model_name`` once and memoize it.

    ``from_pretrained`` reads weights from disk (or downloads them), which is
    far more expensive than a single forward pass, so it must not run per call.
    The cache is bounded because each entry keeps a full model in memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


def predict(text, model_name="username/model_name"):
    """Classify ``text`` with a sequence-classification model.

    Args:
        text: A single string, or a list of strings to classify as one batch.
        model_name: Hub id or local path of the model/tokenizer to use.
            Defaults to the placeholder id the original code hard-coded.

    Returns:
        For a single string, the predicted class index as an ``int``.
        For a list of strings, a list of predicted class indices, one per input.
    """
    tokenizer, model = _load(model_name)
    is_single = isinstance(text, str)
    # truncation keeps over-long inputs within the model's max length instead of
    # failing in the forward pass. Padding is only requested for batches: it is
    # needed there to stack variable-length rows, but some tokenizers (no pad
    # token) raise if padding is requested at all.
    inputs = tokenizer(
        text,
        truncation=True,
        padding=not is_single,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Reduce over the class axis only; a bare argmax() flattens the whole
    # tensor and returns a meaningless flat index for multi-row input.
    predicted = logits.argmax(dim=-1)
    if is_single:
        return predicted.item()
    return predicted.tolist()