offenseval / app.py
import os
import time
import streamlit as st
from dotenv import load_dotenv
from transformers import pipeline
# Load environment variables (e.g. HF_READ_KEY) from a local .env file
load_dotenv()

# Model to load
MODEL_TO_LOAD = "swastik-kapture/offenseval-xlmr-v3"
TOKENIZER = "xlm-roberta-base"

# Create the text-classification pipeline, authenticating with the HF read token
trained_model = pipeline("text-classification", model=MODEL_TO_LOAD, tokenizer=TOKENIZER, token=os.environ.get("HF_READ_KEY"))
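# Note (assumption): for a single input string, a text-classification pipeline
# returns a list with one dict per input, e.g. [{'label': 'offensive', 'score': 0.97}];
# the exact label strings depend on this model's fine-tuned config (id2label).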
# Streamlit App
def main():
    # Create a session-state entry for the conversation history
    if 'conversation_history' not in st.session_state:
        st.session_state.conversation_history = []

    # Streamlit title
    st.title("OffensEval: Profanity Detection")

    # User message
    user_message = st.chat_input("Say something")

    # If user input is present, classify it
    if user_message:
        # Append the user message to the history
        st.session_state.conversation_history.append(('user', user_message, time.time()))

        # Get the predicted output
        output = trained_model(user_message)

        # Get the predicted label and score
        label = output[0]['label']
        score = output[0]['score']

        # Default color
        color = "white"

        # Pick the color based on the label
        if label == "not offensive":
            color = "green"
        elif label == "offensive":
            color = "red"

        st.session_state.conversation_history.append(('assistant', f"<div style='background-color: {color}; width: auto; height: 50px;'>Label: {label}; Score: {score:.2f}</div>", time.time()))

    # Display chat history
    for sender, message, timestamp in st.session_state.conversation_history:
        with st.chat_message(sender):
            st.write(message, unsafe_allow_html=True)


if __name__ == "__main__":
    main()
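
# To run locally (assuming Streamlit is installed and HF_READ_KEY is available
# via the environment or a .env file):
#   streamlit run app.py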