from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from transformers import TextClassificationPipeline
import tensorflow as tf
import gradio as gr

# Load the model and tokenizer
model = TFAutoModelForSequenceClassification.from_pretrained(
    "AiresPucrs/distilbert-base-cased-sentiment-classifier"
)
tokenizer = AutoTokenizer.from_pretrained(
    "AiresPucrs/distilbert-base-cased-sentiment-classifier"
)


def get_gradients(text, model, tokenizer):
    """Score each input token by the gradient of the predicted-class logit.

    The token ids are expanded to a one-hot tensor so that the embedding
    lookup becomes a differentiable matmul; the gradient of the winning
    logit with respect to that one-hot tensor is reduced to one norm per
    token and divided by the largest norm.

    Args:
        text: Raw input string passed to ``tokenizer``.
        model: Sequence-classification model exposing
            ``model.distilbert.embeddings.weights`` and accepting
            ``inputs_embeds`` / ``attention_mask``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        A ``(scores, tokens, predicted_class)`` tuple: ``scores`` is a list
        of floats (one per token, scaled so the maximum is 1), ``tokens`` is
        the list of token strings for the same positions, and
        ``predicted_class`` is the argmax index of the logits.
    """
    # NOTE(review): weights[0] is presumably the word-embedding table
    # (one row per vocabulary entry) -- the matmul below relies on that.
    embedding_matrix = model.distilbert.embeddings.weights[0]
    vocab_size = embedding_matrix.shape[0]

    encoded_tokens = tokenizer(text, return_tensors="tf")
    token_ids = list(encoded_tokens["input_ids"].numpy()[0])

    # Batch of one sequence, expanded to one-hot over the vocabulary.
    token_ids_tensor = tf.constant([token_ids], dtype='int32')
    token_ids_tensor_one_hot = tf.one_hot(token_ids_tensor, vocab_size)

    # Only the one-hot input is watched; model variables are not tracked.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(token_ids_tensor_one_hot)

        # Differentiable replacement for the embedding lookup.
        inputs_embeds = tf.matmul(token_ids_tensor_one_hot, embedding_matrix)

        logits = model(
            {
                "inputs_embeds": inputs_embeds,
                "attention_mask": encoded_tokens["attention_mask"],
            }
        ).logits

        prediction_class = tf.argmax(logits, axis=1).numpy()[0]
        # Indexing happens inside the tape so the gradient can flow from it.
        target_logit = logits[0][prediction_class]

    # Collapse the vocabulary axis: one norm per token position.
    gradient_non_normalized = tf.norm(
        tape.gradient(target_logit, token_ids_tensor_one_hot), axis=2
    )

    # Scale by the largest norm, then drop the batch dimension.
    gradient_tensor = (
        gradient_non_normalized / tf.reduce_max(gradient_non_normalized)
    )[0].numpy().tolist()

    token_words = tokenizer.convert_ids_to_tokens(token_ids)

    return gradient_tensor, token_words, prediction_class


def interpet_DistilBERT(text):
    """Run the classifier on ``text`` and build the two Gradio outputs.

    Args:
        text: Raw input string from the textbox.

    Returns:
        A ``(html, token_scores)`` tuple. ``html`` is a short string naming
        the predicted label. ``token_scores`` maps each token to its share
        of the total gradient score; a token that occurs several times gets
        the sum of the scores of all its occurrences, so the values still
        add up to 1.
    """
    gradient_tensor, token_words, prediction_class = get_gradients(
        text, model, tokenizer
    )

    # Drop the first and last positions (presumably the tokenizer's special
    # start/end tokens) so only the user's own tokens are scored.
    token_words = token_words[1:-1]
    gradient_tensor = gradient_tensor[1:-1]

    total = sum(gradient_tensor)
    if total != 0:
        normalized_gradient_tensor = [score / total for score in gradient_tensor]
    else:
        # Nothing to distribute (or no tokens at all): avoid dividing by zero.
        normalized_gradient_tensor = list(gradient_tensor)

    # Accumulate per token: a plain {token: score} mapping would silently
    # overwrite earlier occurrences of a repeated token.
    token_scores = {}
    for token, score in zip(token_words, normalized_gradient_tensor):
        token_scores[token] = token_scores.get(token, 0.0) + score

    output = (
        f"Predicted Answer:\n{model.config.id2label[int(prediction_class)]}"
        "\n\nGradient Scores:"
    )

    return output, token_scores


# NOTE(review): the UI text below says "integrated gradients", while
# get_gradients performs a single gradient pass -- confirm the wording.
# The strings also look like they once carried HTML markup; verify.
description = (
    "\n"
    "\n\nExplaining DistilBERT with integrated gradients 🏰\n\n"
    "\nThis app was built to provide insight into how a DistilBERT model, fine-tuned for text classification, operates via integrated gradient explanations. Enter a text and see the gradient scores for each word in the input.\n"
    "To learn more, visit this tutorial\n"
    "\n"
)

article = (
    "\n"
    "Return to the castle."
    "\n"
)

# Create the Gradio interface
interface = gr.Interface(
    fn=interpet_DistilBERT,
    inputs=gr.Textbox(placeholder="Enter text here..."),
    outputs=["html", gr.Label(num_top_classes=10)],
    allow_flagging="never",
    description=description,
    article=article,
)

# Launch the Gradio interface
interface.launch()