File size: 4,751 Bytes
eae27ac
 
6ac74f2
5c7018b
 
 
cb4650d
 
 
5eea717
eae27ac
90b8c65
 
 
6ac74f2
 
 
 
 
 
 
b710584
6ac74f2
2f00cc5
6ac74f2
 
 
ff0dcf7
eae27ac
 
 
 
 
 
 
2f00cc5
 
 
 
 
 
 
eae27ac
 
2f00cc5
 
 
 
5eea717
 
 
 
 
 
cb4650d
5eea717
 
2f00cc5
 
 
6ac74f2
2f00cc5
 
 
 
 
 
 
 
 
10f2f6d
2f00cc5
5eea717
90b8c65
 
10f2f6d
2f00cc5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

import pickle

vectorizer = pickle.load(open("tfidf.pickle", "rb"))
# clf = pickle.load(open("classifier.pickle", "rb"))

example_context = "ফলস্বরূপ, ১৯৭৯ সালে, সনি এবং ফিলিপস একটি নতুন ডিজিটাল অডিও ডিস্ক ডিজাইন করার জন্য প্রকৌশলীদের একটি যৌথ টাস্ক ফোর্স গঠন করে। ইঞ্জিনিয়ার কিস শুহামার ইমমিনক এবং তোশিতাদা দোই এর নেতৃত্বে, গবেষণাটি লেজার এবং অপটিক্যাল ডিস্ক প্রযুক্তিকে এগিয়ে নিয়ে যায়। এক বছর পরীক্ষা-নিরীক্ষা ও আলোচনার পর টাস্ক ফোর্স রেড বুক সিডি-ডিএ স্ট্যান্ডার্ড তৈরি করে। প্রথম প্রকাশিত হয় ১৯৮০ সালে। আইইসি কর্তৃক ১৯৮৭ সালে আন্তর্জাতিক মান হিসেবে আনুষ্ঠানিকভাবে এই মান গৃহীত হয় এবং ১৯৯৬ সালে বিভিন্ন সংশোধনী মানের অংশ হয়ে ওঠে।'"
example_answer = "১৯৮০"

def choose_model(model_choice):
  if model_choice=="mt5-small":
    return "jannatul17/squad-bn-qgen-mt5-small-v1"
  elif model_choice=="mt5-base":
    return "Tahsin-Mayeesha/squad-bn-mt5-base2"
  else :
    return "jannatul17/squad-bn-qgen-banglat5-v1"


def generate_questions(model_choice,context,answer,numReturnSequences=1,num_beams=None,do_sample=False,top_p=None,top_k=None,temperature=None):
  model_name = choose_model(model_choice)
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
  tokenizer = AutoTokenizer.from_pretrained(model_name)  
  text='answer: '+answer + ' context: ' + context
  text_encoding = tokenizer.encode_plus(
      text,return_tensors="pt"
  )
  model.eval()
  generated_ids =  model.generate(
    input_ids=text_encoding['input_ids'],
    attention_mask=text_encoding['attention_mask'],
    max_length=120,
    num_beams=num_beams,
    do_sample=do_sample,
    top_k = top_k,
    top_p = top_p,
    temperature = temperature,
    num_return_sequences=numReturnSequences
  )
  
  text = []
  for id in generated_ids:
    text.append(tokenizer.decode(id,skip_special_tokens=True,clean_up_tokenization_spaces=True).replace('question: ',' '))
  
  #question = " ".join(text) 
  #correctness_pred = clf.predict(vectorizer.transform([question]))[0]
  #if correctness_pred == 1:
  #  correctness = "Correct"
  #else : 
  #  correctness = "Incorrect"
  
  #return question, correctness 
   return question 
  
  
demo = gr.Interface(fn=generate_questions, inputs=[gr.Dropdown(label="Model", choices=["mt5-small","mt5-base","banglat5"],value="banglat5"),
                                                    gr.Textbox(label='Context'),
                                                    gr.Textbox(label='Answer'),
                                                    # hyperparameters
                                                    gr.Slider(1, 3, 1, step=1, label="Num return Sequences"),
                                                    # beam search
                                                    gr.Slider(1, 10,value=None, step=1, label="Beam width"),
                                                    # top-k/top-p
                                                    gr.Checkbox(label="Do Random Sample",value=False),
                                                    gr.Slider(0, 50, value=None, step=1, label="Top K"),
                                                    gr.Slider(0, 1, value=None, label="Top P/Nucleus Sampling"),
                                                    gr.Slider(0, 1, value=None, label="Temperature") ] ,
                                                    # output
                                                    outputs=[gr.Textbox(label='Question')],
                                                    examples=[["banglat5",example_context,example_answer]],
                                                    cache_examples=False,
                                                    title="Bangla Question Generation")
demo.launch()