Vrk commited on
Commit
78919c7
1 Parent(s): 7c251b7
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import spacy
4
+ # from spacy.lang.en import English
5
+ # from utils import spacy_function, make_predictions, example_input
6
+
7
+ import sys
8
+ sys.path.insert(1, 'PyTorch')
9
+ from Dataset import SkimlitDataset
10
+ from Embeddings import get_embeddings
11
+ from Model import SkimlitModel
12
+ from Tokenizer import Tokenizer
13
+ from LabelEncoder import LabelEncoder
14
+ from MakePredictions import make_skimlit_predictions
15
+ from RandomAbstract import Choose_Random_text
16
+
17
+ MODEL_PATH = 'PyTorch/utils/skimlit-model-final-1.pt'
18
+ TOKENIZER_PATH = 'PyTorch/utils/tokenizer.json'
19
+ LABEL_ENOCDER_PATH = "PyTorch/utils/label_encoder.json"
20
+ EMBEDDING_FILE_PATH = 'PyTorch/utils/glove.6B.300d.txt'
21
+
22
+ @st.cache()
23
+ def create_utils(model_path, tokenizer_path, label_encoder_path, embedding_file_path):
24
+ tokenizer = Tokenizer.load(fp=tokenizer_path)
25
+ label_encoder = LabelEncoder.load(fp=label_encoder_path)
26
+ embedding_matrix = get_embeddings(embedding_file_path, tokenizer, 300)
27
+ model = SkimlitModel(embedding_dim=300, vocab_size=len(tokenizer), hidden_dim=128, n_layers=3, linear_output=128, num_classes=len(label_encoder), pretrained_embeddings=embedding_matrix)
28
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
29
+ print(model)
30
+ return model, tokenizer, label_encoder
31
+
32
+ def model_prediction(abstract, model, tokenizer, label_encoder):
33
+ objective = ''
34
+ background = ''
35
+ method = ''
36
+ conclusion = ''
37
+ result = ''
38
+
39
+ lines, pred = make_skimlit_predictions(abstract, model, tokenizer, label_encoder)
40
+ # pred, lines = make_predictions(abstract)
41
+
42
+ for i, line in enumerate(lines):
43
+ if pred[i] == 'OBJECTIVE':
44
+ objective = objective + line
45
+
46
+ elif pred[i] == 'BACKGROUND':
47
+ background = background + line
48
+
49
+ elif pred[i] == 'METHODS':
50
+ method = method + line
51
+
52
+ elif pred[i] == 'RESULTS':
53
+ result = result + line
54
+
55
+ elif pred[i] == 'CONCLUSIONS':
56
+ conclusion = conclusion + line
57
+
58
+ return objective, background, method, conclusion, result
59
+
60
+
61
+
62
+ def main():
63
+
64
+ st.set_page_config(
65
+ page_title="SkimLit",
66
+ page_icon="📄",
67
+ layout="wide",
68
+ initial_sidebar_state="expanded"
69
+ )
70
+
71
+ st.title('SkimLit📄🔥')
72
+ st.caption('An NLP model to classify abstract sentences into the role they play (e.g. objective, methods, results, etc..) to enable researchers to skim through the literature and dive deeper when necessary.')
73
+
74
+ # creating model, tokenizer and labelEncoder
75
+ # if PREP_MODEL:
76
+ # skimlit_model, tokenizer, label_encoder = create_utils(MODEL_PATH, TOKENIZER_PATH, LABEL_ENOCDER_PATH, EMBEDDING_FILE_PATH)
77
+ # PREP_MODEL = False
78
+
79
+ col1, col2 = st.columns(2)
80
+
81
+ with col1:
82
+ st.write('#### Entre Abstract Here !!')
83
+ abstract = st.text_area(label='', height=200)
84
+
85
+ agree = st.checkbox('Show Example Abstract')
86
+ predict = st.button('Extract !')
87
+
88
+ if agree:
89
+ example_input = Choose_Random_text()
90
+ st.info(example_input)
91
+
92
+ # make prediction button logic
93
+ if predict:
94
+ with col2:
95
+ with st.spinner('Wait for prediction....'):
96
+ skimlit_model, tokenizer, label_encoder = create_utils(MODEL_PATH, TOKENIZER_PATH, LABEL_ENOCDER_PATH, EMBEDDING_FILE_PATH)
97
+ objective, background, methods, conclusion, result = model_prediction(abstract, skimlit_model, tokenizer, label_encoder)
98
+
99
+ st.markdown(f'### Objective : ')
100
+ st.info(objective)
101
+ # st.write(f'{objective}')
102
+ st.markdown(f'### Background : ')
103
+ st.info(background)
104
+ # st.write(f'{background}')
105
+ st.markdown(f'### Methods : ')
106
+ st.info(methods)
107
+ # st.write(f'{methods}')
108
+ st.markdown(f'### Result : ')
109
+ st.info(result)
110
+ # st.write(f'{result}')
111
+ st.markdown(f'### Conclusion : ')
112
+ st.info(conclusion)
113
+ # st.write(f'{conclusion}')
114
+
115
+
116
+
117
+ if __name__=='__main__':
118
+ main()