abokbot commited on
Commit
36fdc32
1 Parent(s): 62f35a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import T5TokenizerFast, T5ForConditionalGeneration
3
+ import nltk
4
+ import math
5
+ import torch
6
+
7
+ model_name = "abokbot/t5-end2end-questions-generation"
8
+
9
+ st.header("Generate questions for short Wikipedia-like articles")
10
+
11
+ st_model_load = st.text('Loading question generator model...')
12
+
13
+ @st.cache(allow_output_mutation=True)
14
+ def load_model():
15
+ print("Loading model...")
16
+ tokenizer = T5TokenizerFast.from_pretrained("t5-base")
17
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
18
+ nltk.download('punkt')
19
+ print("Model loaded!")
20
+ return tokenizer, model
21
+
22
+ tokenizer, model = load_model()
23
+ st.success('Model loaded!')
24
+ st_model_load.text("")
25
+
26
+ if 'text' not in st.session_state:
27
+ st.session_state.text = ""
28
+ st_text_area = st.text_area('Text to generate the questions for', value=st.session_state.text, height=500)
29
+
30
+ def generate_questions():
31
+ st.session_state.text = st_text_area
32
+
33
+ generator_args = {
34
+ "max_length": 256,
35
+ "num_beams": 4,
36
+ "length_penalty": 1.5,
37
+ "no_repeat_ngram_size": 3,
38
+ "early_stopping": True,
39
+ }
40
+ input_string = "generate questions: " + st_text_area + " </s>"
41
+ input_ids = tokenizer.encode(input_string, return_tensors="pt")
42
+ res = model.generate(input_ids, **generator_args)
43
+ output = tokenizer.batch_decode(res, skip_special_tokens=True)
44
+ output = [question.strip() + "?" for question in output[0].split("?") if question != ""]
45
+
46
+ st.session_state.questions = output
47
+
48
+ # generate title button
49
+ st_generate_button = st.button('Generate questions', on_click=generate_questions)
50
+
51
+ # title generation labels
52
+ if 'questions' not in st.session_state:
53
+ st.session_state.questions = []
54
+
55
+ if len(st.session_state.questions) > 0:
56
+ with st.container():
57
+ st.subheader("Generated questions")
58
+ for title in st.session_state.questions:
59
+ st.markdown("__" + title + "__")