"""Streamlit app that analyses a pasted review with Hugging Face models.

For the "Analyse a Review" action the app offers three buttons: sentiment
classification (T5 fine-tuned on IMDB), summarization (BART) and a
question-answering pass that asks what the review is about (RoBERTa/SQuAD2).

Streamlit re-executes this whole script on every widget interaction, so
every model/tokenizer loader below is wrapped in a resource cache; without
that, each button click would reload the models from scratch.
"""
import transformers
import streamlit as st
from transformers import AutoTokenizer, AutoModelWithLMHead
from transformers import pipeline

SENTIMENT_MODEL_NAME = "mrm8488/t5-base-finetuned-imdb-sentiment"
SUMMARIZER_MODEL_NAME = "facebook/bart-large-cnn"
QA_MODEL_NAME = "deepset/roberta-base-squad2"
TEXT_GEN_MODEL_NAME = "gpt2-medium"


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when the installed Streamlit provides it and
    falls back to the legacy ``st.cache(allow_output_mutation=True)``
    otherwise. ``allow_output_mutation=True`` stops the legacy cache from
    hashing the returned model object on every call to detect mutation.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


# NOTE(review): AutoModelWithLMHead is deprecated in recent transformers
# releases in favour of task-specific Auto classes; it is kept here so the
# file's existing import keeps working -- confirm against the pinned version.
@_cache_resource
def load_model(model_name):
    """Load and return the LM-head model identified by *model_name*."""
    return AutoModelWithLMHead.from_pretrained(model_name)


@_cache_resource
def load_text_gen_model():
    """Return a text-generation pipeline backed by gpt2-medium."""
    return pipeline("text-generation", model=TEXT_GEN_MODEL_NAME)


@_cache_resource
def get_sentiment_tokenizer():
    """Return the tokenizer that matches the sentiment model."""
    return AutoTokenizer.from_pretrained(SENTIMENT_MODEL_NAME)


@_cache_resource
def get_sentiment_model():
    """Return the T5 model fine-tuned for IMDB sentiment classification."""
    return AutoModelWithLMHead.from_pretrained(SENTIMENT_MODEL_NAME)


def get_sentiment(text):
    """Return the sentiment label the model generates for *text*.

    Relies on the module-level ``sentiment_tokenizer`` and
    ``sentiment_model``.
    """
    input_ids = sentiment_tokenizer.encode(text, return_tensors="pt")
    # max_length=2 is the generation limit the original code used; only the
    # first (and only) sequence of the batch is decoded.
    output = sentiment_model.generate(input_ids=input_ids, max_length=2)
    # Skip special tokens so markers such as "<pad>" do not leak into the
    # label shown to the user.
    return sentiment_tokenizer.decode(output[0], skip_special_tokens=True).strip()


@_cache_resource
def get_summarizer():
    """Return a summarization pipeline backed by facebook/bart-large-cnn."""
    return pipeline("summarization", model=SUMMARIZER_MODEL_NAME)


@_cache_resource
def get_qa_model():
    """Return a question-answering pipeline backed by roberta-base-squad2."""
    return pipeline(
        'question-answering',
        model=QA_MODEL_NAME,
        tokenizer=QA_MODEL_NAME,
    )


# Module-level handles used by the UI below. The loaders are cached, so on
# reruns these assignments are cheap lookups rather than fresh model loads.
sentiment_tokenizer = get_sentiment_tokenizer()
sentiment_model = get_sentiment_model()
summarizer = get_summarizer()
answer_generator = get_qa_model()
# Original (misspelled) name, kept as an alias in case other code uses it.
answer_geerator = answer_generator
# Text generation is currently disabled; load_text_gen_model() is kept so it
# can be switched back on without re-adding the loader.

action = st.sidebar.selectbox(
    "Pick an Action",
    ["Analyse a Review", "Generate an Article", "Create an Image"],
)

if action == "Analyse a Review":
    review = st.text_area("Paste the review here..")
    # Only offer the analysis buttons once some text has been entered.
    if review:
        if st.button("Get the Sentiment of the Review"):
            sentiment = get_sentiment(review)
            st.write(sentiment)

        if st.button("Summarize the review"):
            summary = summarizer(
                review,
                max_length=130,
                min_length=30,
                do_sample=False,
            )
            st.write(summary)

        if st.button("Find the key topic"):
            QA_input = {
                'question': 'what is the review about?',
                'context': review,
            }
            answer = answer_generator(QA_input)
            st.write(answer)