abokbot committed on
Commit
9504094
1 Parent(s): 62b45a1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import T5TokenizerFast, T5ForConditionalGeneration
3
+ import nltk
4
+ import math
5
+ import torch
6
+
7
# Hugging Face Hub identifier of the fine-tuned question-generation checkpoint.
model_name = "abokbot/t5-end2end-questions-generation"
# Intended cap on the length of the tokenized model input.
max_input_length = 512

# Page heading.
st.header("Generate questions for short Wikipedia-like articles")

# Temporary status line; the handle is kept so the message can be blanked
# once the model has finished loading.
st_model_load = st.text("Loading question generator model...")
13
+
14
# ``st.cache`` is the deprecated legacy caching API. Newer Streamlit releases
# provide ``st.cache_resource`` for unserializable global resources such as ML
# models; use it when available and fall back to the legacy decorator on
# older installations. The ``or`` short-circuits, so ``st.cache`` is never
# touched when the modern API exists.
_cache_model = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_model
def load_model():
    """Load the tokenizer and the question-generation model once per process.

    The result is cached by Streamlit, so script reruns triggered by widget
    interaction reuse the already loaded objects instead of reloading them.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``t5-base`` fast tokenizer and the
        ``T5ForConditionalGeneration`` checkpoint named by ``model_name``.
    """
    print("Loading model...")
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # Fetch the NLTK "punkt" data as part of the one-time setup.
    nltk.download('punkt')
    print("Model loaded!")
    return tokenizer, model
22
+
23
# Obtain the (cached) tokenizer/model pair, then swap the "loading" status
# line for a success banner.
tokenizer, model = load_model()
st.success("Model loaded!")
st_model_load.text("")

# Seed the session state on the first run so the text area has an initial
# value to display.
if "text" not in st.session_state:
    st.session_state["text"] = ""

# Input widget; its current content is what the questions are generated from.
st_text_area = st.text_area(
    "Text to generate the questions for",
    value=st.session_state["text"],
    height=500,
)
30
+
31
def _split_questions(generated):
    """Split the decoded model output into individual questions.

    The model returns all questions in a single string, each one ending with
    "?". Fragments that are empty or whitespace-only (for example the
    remainder after the final "?") are dropped so that no bare "?" is
    reported as a question.
    """
    fragments = (fragment.strip() for fragment in generated.split("?"))
    return [fragment + "?" for fragment in fragments if fragment]


def generate_questions():
    """Generate questions for the text-area content and store them.

    Used as the ``on_click`` callback of the "Generate questions" button.
    Reads the module-level ``st_text_area`` value, remembers it in
    ``st.session_state.text`` and writes the resulting list of question
    strings to ``st.session_state.questions``.
    """
    text = st_text_area
    st.session_state.text = text

    # Nothing to ask about: skip the model call entirely.
    if not text.strip():
        st.session_state.questions = []
        return

    generator_args = {
        "max_length": 256,
        "num_beams": 4,
        "length_penalty": 1.5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
    }
    input_string = "generate questions: " + text + " </s>"
    # Truncate over-long articles to the configured input budget instead of
    # feeding an arbitrarily long sequence to the model.
    input_ids = tokenizer.encode(
        input_string,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )
    res = model.generate(input_ids, **generator_args)
    output = tokenizer.batch_decode(res, skip_special_tokens=True)

    # A single input sequence was encoded, so only the first decoded entry
    # is relevant.
    st.session_state.questions = _split_questions(output[0])
48
+
49
# Button that triggers question generation through its on_click callback.
st_generate_button = st.button("Generate questions", on_click=generate_questions)

# Make sure the list of generated questions exists in the session state.
if "questions" not in st.session_state:
    st.session_state["questions"] = []

# Render the generated questions, one bold markdown line each, but only
# when there is something to show.
if st.session_state["questions"]:
    with st.container():
        st.subheader("Generated questions")
        for question in st.session_state["questions"]:
            st.markdown(f"__{question}__")