Tahsin-Mayeesha commited on
Commit
6ac74f2
1 Parent(s): ff0dcf7
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -1,12 +1,20 @@
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- model_name = "Tahsin-Mayeesha/squad-bn-mt5-base2"
4
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
5
- tokenizer = AutoTokenizer.from_pretrained(model_name)
6
 
 
 
 
 
 
 
 
7
 
8
- import gradio as gr
9
- def generate__questions(context,answer):
 
 
 
10
  text='answer: '+answer + ' context: ' + context
11
  text_encoding = tokenizer.encode_plus(
12
  text,return_tensors="pt"
@@ -22,7 +30,8 @@ def generate__questions(context,answer):
22
 
23
  return tokenizer.decode(generated_ids[0],skip_special_tokens=True,clean_up_tokenization_spaces=True).replace('question: ',' ')
24
 
25
- demo = gr.Interface(fn=generate__questions, inputs=[gr.Textbox(label='Context'),
 
26
  gr.Textbox(label='Answer')] ,
27
  outputs=gr.Textbox(label='Question'),
28
  title="Bangla Question Generation",
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import gradio as gr
 
 
4
 
5
+ def choose_model(model_choice):
6
+ if model_choice=="mt5-small":
7
+ return "jannatul17/squad-bn-qgen-mt5-small-v1"
8
+ elif model_choice=="mt5-base":
9
+ return "Tahsin-Mayeesha/squad-bn-mt5-base2"
10
+ else :
11
+ return "jannatul17/squad-bn-qgen-banglat5-v1"
12
 
13
+
14
+ def generate__questions(model_choice,context,answer):
15
+ model_name = choose_model(model_choice)
16
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  text='answer: '+answer + ' context: ' + context
19
  text_encoding = tokenizer.encode_plus(
20
  text,return_tensors="pt"
 
30
 
31
  return tokenizer.decode(generated_ids[0],skip_special_tokens=True,clean_up_tokenization_spaces=True).replace('question: ',' ')
32
 
33
+ demo = gr.Interface(fn=generate__questions, inputs=[gr.Dropdown(label="Model", choices=["mt5-small","mt5-base","banglat5"],value="banglat5"),
34
+ gr.Textbox(label='Context'),
35
  gr.Textbox(label='Answer')] ,
36
  outputs=gr.Textbox(label='Question'),
37
  title="Bangla Question Generation",