File size: 1,829 Bytes
b58c818
1256a85
 
b58c818
ab5688d
 
cd3df30
9fec945
ab5688d
b58c818
 
b4ace98
b58c818
ab5688d
 
cd3df30
ab5688d
 
bd84a02
 
 
9fec945
bd84a02
 
 
 
e7f3263
bd84a02
 
87d90ac
d79cd77
4f75d26
 
 
bd84a02
9fec945
4f75d26
ab5688d
 
bd84a02
d2f7d16
bd84a02
 
 
 
 
 
 
4f75d26
 
 
 
 
 
2b90304
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import transformers
import streamlit as st

from transformers import AutoTokenizer, AutoModelWithLMHead
from transformers import pipeline

# Tokenizer for the T5 IMDB sentiment checkpoint. Loaded once at module level
# and read as a global by get_sentiment().
sentiment_tokenizer = AutoTokenizer.from_pretrained(
    "mrm8488/t5-base-finetuned-imdb-sentiment"
)

@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load a language-model-head checkpoint and cache it across reruns.

    ``allow_output_mutation=True`` tells Streamlit not to hash the returned
    object to detect mutation; hashing a full model on every cache hit is
    slow and can fail for objects Streamlit does not know how to hash.

    Args:
        model_name: Identifier or path passed straight to
            ``AutoModelWithLMHead.from_pretrained``.

    Returns:
        The model object returned by ``from_pretrained``.
    """
    return AutoModelWithLMHead.from_pretrained(model_name)
    
def load_text_gen_model():
    """Build the text-generation pipeline backed by ``gpt2-medium``.

    No cache decorator is applied, so each call constructs a new pipeline.

    Returns:
        The object returned by ``pipeline("text-generation", ...)``.
    """
    return pipeline("text-generation", model="gpt2-medium")
    
@st.cache(allow_output_mutation=True)
def get_sentiment_model():
    """Load the T5 IMDB sentiment checkpoint and cache it across reruns.

    ``allow_output_mutation=True`` tells Streamlit not to hash the returned
    object to detect mutation; hashing a full model on every cache hit is
    slow and can fail for objects Streamlit does not know how to hash.

    Returns:
        The model object returned by ``AutoModelWithLMHead.from_pretrained``
        for ``mrm8488/t5-base-finetuned-imdb-sentiment``.
    """
    return AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-imdb-sentiment")
    
def get_sentiment(text):
    """Return the sentiment label the model generates for *text*.

    Reads the module-level ``sentiment_tokenizer`` and ``sentiment_model``.

    Args:
        text: The review text to classify.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # An explicit '</s>' end-of-sequence marker is appended before encoding.
    input_ids = sentiment_tokenizer.encode(text + '</s>', return_tensors='pt')
    # Generation is capped at two token ids (presumably a start token plus
    # the single label token -- confirm against the model card).
    output = sentiment_model.generate(input_ids=input_ids, max_length=2)
    # Only the first sequence is needed. Skip special tokens so markers such
    # as padding do not end up in the label shown to the user.
    return sentiment_tokenizer.decode(output[0], skip_special_tokens=True)
    
# NOTE: no cache decorator is applied here, so every call builds a new pipeline.
def get_summarizer():
    """Build the summarization pipeline backed by ``facebook/bart-large-cnn``.

    Returns:
        The object returned by ``pipeline("summarization", ...)``.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")
  
# Load every model at module level, before any widget is drawn.
# get_sentiment() reads sentiment_model as a global, so it must be bound here.
sentiment_model = get_sentiment_model()
summarizer = get_summarizer()
# NOTE(review): text_generator is not used by any action below; presumably it
# is intended for the "Generate an Article" action -- confirm before removing.
text_generator = load_text_gen_model()

action = st.sidebar.selectbox(
    "Pick an Action",
    ["Analyse a Review", "Generate an Article", "Create an Image"],
)

# Only the review-analysis action has a handler; the other two options
# currently render nothing.
if action == "Analyse a Review":
    review = st.text_area("Paste the review here..")
    if review:
        sentiment = get_sentiment(review)
        st.write(sentiment)

        if st.button("Summarize the review"):
            # Summarize the text the user pasted into the text area.
            summary = summarizer(review, max_length=130, min_length=30, do_sample=False)
            st.write(summary)