"""AnswerMate: a Gradio app that answers a batch of questions on a tagged topic.

Answers come either from a local GPT-2 model (when ``USE_SBER`` is set in
``config.json``) or from an HTTP chat-completions endpoint (``CHATGPT_URL``).
"""
from typing import List, Tuple, Dict, Any
import time
import json

import requests
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# JSON text is UTF-8; being explicit keeps non-ASCII values (e.g. the system
# prompt) from depending on the platform's default locale encoding.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

max_questions_count = config["MAX_QUESTIONS_COUNT"]
max_tags_count = config["MAX_TAGS_COUNT"]
# NOTE: "MAX_ATTEMPS" and "SBER_GRT" are spelled exactly as they are looked up
# here; config.json has to use the same spelling.
max_attempts = config["MAX_ATTEMPS"]
wait_time = config["WAIT_TIME"]
chatgpt_url = config["CHATGPT_URL"]
system_prompt = config["SYSTEM_PROMPT"]
sber_gpt = config["SBER_GRT"]
use_sber = config["USE_SBER"]
# Optional key, so existing config files keep working: seconds to wait for the
# HTTP endpoint before the attempt is counted as failed.
request_timeout = config.get("REQUEST_TIMEOUT", 120)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The local model is loaded only when it is going to be used.
if use_sber:
    tokenizer = GPT2Tokenizer.from_pretrained(sber_gpt)
    model = GPT2LMHeadModel.from_pretrained(sber_gpt).to(DEVICE)


def generate(
    model, tok, text,
    do_sample=True, max_length=10000, repetition_penalty=5.0,
    top_k=5, top_p=0.95, temperature=1,
    num_beams=None,
    no_repeat_ngram_size=3
):
    """Generate text with ``model`` from the prompt ``text``.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tok: Tokenizer exposing ``encode`` and ``decode``.
        text: Prompt to encode and feed to the model.
        do_sample, max_length, repetition_penalty, top_k, top_p, temperature,
        num_beams, no_repeat_ngram_size: Forwarded unchanged to
            ``model.generate``.

    Returns:
        The decoded first sequence returned by ``model.generate``.
    """
    # NOTE(review): max_length=10000 may exceed the model's context window;
    # confirm against the configured checkpoint.
    input_ids = tok.encode(text, return_tensors="pt").to(DEVICE)
    out = model.generate(
        input_ids,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size
    )
    # Only the first sequence is needed, so only that one is decoded.
    return tok.decode(out[0])


def get_answer(question: str) -> Dict[str, Any]:
    """Ask the configured backend a single question.

    Uses the local model when ``use_sber`` is truthy, otherwise POSTs a
    chat-completions payload to ``chatgpt_url``.

    Args:
        question: The question text (already prefixed with tags by the caller).

    Returns:
        ``{'status': True, 'content': <answer>}`` on success, or
        ``{'status': False}`` when the HTTP request fails or the response
        does not have the expected structure.
    """
    if use_sber:
        content = generate(model, tokenizer, question)
        return {
            'status': True,
            'content': content
        }

    headers = {
        'Content-Type': 'application/json; charset=utf-8'
    }
    payload = {
        'model': 'gpt-3.5-turbo',
        'messages': [
            {
                'role': 'system',
                'content': system_prompt
            },
            {
                'role': 'user',
                'content': question
            }
        ]
    }

    try:
        # Without a timeout a stalled connection would block the single
        # queue worker indefinitely.
        response = requests.post(
            chatgpt_url,
            headers=headers,
            data=json.dumps(payload),
            timeout=request_timeout
        )
        response.raise_for_status()
        content = response.json()['choices'][0]['message']['content']
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as error:
        # Network/HTTP failures, undecodable JSON, or an unexpected response
        # structure: report failure so the caller can retry.
        print(f'Failed to get an answer: {error!r}')
        return {
            'status': False
        }

    return {
        'status': True,
        'content': content
    }


def format_results(results: List[Tuple[str, str]]) -> str:
    """Render (question, answer) pairs as a numbered, human-readable report.

    Args:
        results: Pairs of question text and answer text, in display order.

    Returns:
        The entries separated by a dashed line, with surrounding whitespace
        stripped. An empty list yields an empty string.
    """
    separator = '--------------------------------------\n\n'
    entries = [
        f'Question №{number}: {question}\nAnswer: {answer}\n'
        for number, (question, answer) in enumerate(results, start=1)
    ]
    return separator.join(entries).strip()


def validate_and_get_tags(tags: str) -> List[str]:
    """Split the tags textbox value into a list of non-empty, trimmed tags.

    Args:
        tags: Raw textbox content, one tag per line.

    Returns:
        The trimmed, non-blank lines.

    Raises:
        gr.Error: If there is no tag at all or more than ``max_tags_count``.
    """
    if not tags.strip():
        raise gr.Error('Validation error. It is necessary to set at least one tag')

    parsed_tags = [tag.strip() for tag in tags.split('\n') if tag.strip()]
    if len(parsed_tags) > max_tags_count:
        raise gr.Error(f'Validation error. The maximum allowed number of tags is {max_tags_count}.')

    return parsed_tags


def validate_and_get_questions(questions: str) -> List[str]:
    """Split the questions textbox value into non-empty, trimmed questions.

    Args:
        questions: Raw textbox content, one question per line.

    Returns:
        The trimmed, non-blank lines.

    Raises:
        gr.Error: If there is no question at all or more than
            ``max_questions_count``.
    """
    if not questions.strip():
        raise gr.Error('Validation error. It is necessary to ask at least one question')

    parsed_questions = [question.strip() for question in questions.split('\n') if question.strip()]
    if len(parsed_questions) > max_questions_count:
        raise gr.Error(f'Validation error. The maximum allowed number of questions is {max_questions_count}.')

    return parsed_questions


def find_answers(tags: str, questions: str, progress=gr.Progress()) -> str:
    """Answer every question, prefixing each one with the given tags.

    Each question is tried up to ``max_attempts`` times, sleeping
    ``wait_time`` seconds before every question and between attempts.

    Args:
        tags: Raw tags textbox content, one tag per line.
        questions: Raw questions textbox content, one question per line.
        progress: Gradio progress tracker used to wrap the question loop.

    Returns:
        The formatted report produced by ``format_results``.

    Raises:
        gr.Error: If the tags or the questions fail validation.
    """
    tags = validate_and_get_tags(tags)
    questions = validate_and_get_questions(questions)

    print(f'New attempt to get answers. Got {len(tags)} tags and {len(questions)} questions')
    print(f'Tags: {tags}')
    print(f'Questions: {questions}')

    tags_str = ''.join(f'[{tag}]' for tag in tags)

    results = []
    for question in progress.tqdm(questions):
        time.sleep(wait_time)
        tagged_question = f'{tags_str} {question}'
        for attempt in range(max_attempts):
            if attempt:
                # Pause between consecutive attempts, not before the first.
                time.sleep(wait_time)
            answer = get_answer(tagged_question)
            if answer['status']:
                results.append((question, answer['content']))
                break
        else:
            # Every attempt failed (or max_attempts <= 0): still report the
            # question so it is not silently missing from the output.
            results.append((question, 'An error occurred while receiving data.'))

    return format_results(results)


# NOTE(review): the heading markup around the title appears to have been
# stripped from this file; only the visible text is kept — confirm.
title = '\n\nAnswerMate\n\n'

with gr.Blocks(theme='soft', title='AnswerMate') as blocks:
    gr.HTML(title)
    gr.Markdown('The service allows you to get answers to all questions on the specified topic.')
    with gr.Row():
        tags_input = gr.Textbox(
            label=f'Enter tags (each line is a separate tag). Maximum: {max_tags_count}.',
            placeholder='.NET\nC#',
            lines=max_tags_count
        )
        questions_input = gr.Textbox(
            label=f'Enter questions (each line is a separate question). Maximum: {max_questions_count}.',
            placeholder='What is inheritance, encapsulation, abstraction, polymorphism?\nWhat is CLR?',
            lines=max_questions_count
        )
    process_button = gr.Button('Find answers')
    outputs = gr.Textbox(label='Output', placeholder='Output will appear here')
    process_button.click(fn=find_answers, inputs=[tags_input, questions_input], outputs=outputs)

blocks.queue(concurrency_count=1).launch()