AnswerMate / app.py
Brasd99's picture
Bug fix
ad9e41c
raw
history blame contribute delete
No virus
4.83 kB
from typing import List, Tuple, Dict, Any
import time
import json
import requests
import gradio as gr
import poe
import os
# Load application settings from config.json in the working directory.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

max_questions_count = config['MAX_QUESTIONS_COUNT']
max_tags_count = config['MAX_TAGS_COUNT']
# Historically the key was spelled 'MAX_ATTEMPS'; accept the correct spelling
# too, but fall back to the legacy key so existing config files keep working.
max_attempts = config['MAX_ATTEMPTS'] if 'MAX_ATTEMPTS' in config else config['MAX_ATTEMPS']
wait_time = config['WAIT_TIME']
chatgpt_url = config['CHATGPT_URL']
system_prompt = config['SYSTEM_PROMPT']
use_sage = config['USE_SAGE']
# The Poe token is only needed for the Sage backend. Requiring it
# unconditionally made the ChatGPT-only mode crash with KeyError when the
# environment variable was not set.
sage_token = os.environ['SAGE_TOKEN'] if use_sage else None
def _ask_sage(question: str, client: poe.Client) -> Dict[str, Any]:
    """Ask Poe's 'capybara' bot ``question`` through ``client``.

    The streamed reply is drained and the last chunk's ``'text'`` is used as
    the answer. The message is then deleted from the chat. A failure to delete
    is only logged, because the answer has already been received.

    Returns:
        ``{'status': True, 'content': <answer>}``, or ``{'status': False}`` if
        the stream yielded no chunks.
    """
    last_chunk = None
    for last_chunk in client.send_message('capybara', question, with_chat_break=True):
        pass
    if last_chunk is None:
        # An empty stream leaves nothing to read an answer from.
        return {'status': False}
    content = last_chunk['text']
    try:
        client.delete_message(last_chunk['messageId'])
    except Exception as error:  # cleanup is best-effort; keep the answer
        print(f'Failed to delete Sage message: {error!r}')
    return {'status': True, 'content': content}


def _ask_chatgpt(question: str, timeout: float) -> Dict[str, Any]:
    """POST ``question`` to ``chatgpt_url`` as a gpt-3.5-turbo chat request.

    ``system_prompt`` is sent as the system message.

    Returns:
        ``{'status': True, 'content': <reply>}`` on success, or
        ``{'status': False}`` on a network/HTTP error or an unexpected body.
    """
    headers = {
        'Content-Type': 'application/json; charset=utf-8'
    }
    payload = {
        'model': 'gpt-3.5-turbo',
        'messages': [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': question},
        ],
    }
    try:
        response = requests.post(
            chatgpt_url,
            headers=headers,
            data=json.dumps(payload),
            timeout=timeout,
        )
        response.raise_for_status()
        content = response.json()['choices'][0]['message']['content']
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as error:
        # ValueError: body is not JSON. Key/Index/TypeError: JSON body lacks
        # the expected choices[0].message.content shape.
        print(f'ChatGPT request failed: {error!r}')
        return {'status': False}
    return {'status': True, 'content': content}


def get_answer(question: str, client: poe.Client=None, timeout: float = 120.0) -> Dict[str, Any]:
    """Ask the configured backend one question.

    The Poe/Sage backend is used when ``use_sage`` is true; otherwise the
    ChatGPT-compatible endpoint at ``chatgpt_url`` is used.

    Args:
        question: The full prompt text to send.
        client: Poe client; required when ``use_sage`` is true, ignored otherwise.
        timeout: Seconds to wait for the ChatGPT HTTP request. Without a
            timeout, a stalled endpoint would block the caller indefinitely.
            This value is not applied to the Sage backend.

    Returns:
        ``{'status': True, 'content': <answer>}`` on success, or
        ``{'status': False}`` on failure so that the caller can retry.

    Raises:
        ValueError: If ``use_sage`` is true but no ``client`` was supplied.
    """
    if use_sage:
        if client is None:
            raise ValueError('A poe.Client is required when USE_SAGE is enabled')
        try:
            return _ask_sage(question, client)
        except Exception as error:  # poe's exception types are not visible here
            # Report failure instead of crashing so the caller's retry loop runs.
            print(f'Sage request failed: {error!r}')
            return {'status': False}
    return _ask_chatgpt(question, timeout)
def format_results(results: List[Tuple[str, str]]) -> str:
    """Render (question, answer) pairs as a numbered plain-text report.

    Each pair becomes a "Question №N" line followed by an "Answer" line.
    Consecutive pairs are separated by a dashed rule and a blank line, and
    surrounding whitespace of the whole report is trimmed.
    """
    divider = '--------------------------------------\n\n'
    sections = [
        f'Question №{number}: {question}\nAnswer: {answer}\n'
        for number, (question, answer) in enumerate(results, start=1)
    ]
    return divider.join(sections).strip()
def validate_and_get_tags(tags: str) -> List[str]:
    """Parse the tags textbox into a list of tags.

    Each line is stripped of surrounding whitespace and blank lines are
    discarded.

    Args:
        tags: Raw textbox contents, one tag per line.

    Returns:
        The non-blank, stripped tags in input order.

    Raises:
        gr.Error: If there is no non-blank tag, or more than
            ``max_tags_count`` tags.
    """
    parsed_tags = []
    for raw_line in tags.split('\n'):
        cleaned = raw_line.strip()
        if cleaned:
            parsed_tags.append(cleaned)

    # An input with any non-whitespace character always yields at least one tag,
    # so an empty list means the textbox was blank.
    if not parsed_tags:
        raise gr.Error('Validation error. It is necessary to set at least one tag')
    if len(parsed_tags) > max_tags_count:
        raise gr.Error(f'Validation error. The maximum allowed number of tags is {max_tags_count}.')
    return parsed_tags
def validate_and_get_questions(questions: str) -> List[str]:
    """Parse the questions textbox into a list of questions.

    Each line is stripped of surrounding whitespace and blank lines are
    discarded.

    Args:
        questions: Raw textbox contents, one question per line.

    Returns:
        The non-blank, stripped questions in input order.

    Raises:
        gr.Error: If there is no non-blank question, or more than
            ``max_questions_count`` questions.
    """
    parsed_questions = list(filter(None, map(str.strip, questions.split('\n'))))

    # A blank textbox (only whitespace) produces no questions at all.
    if len(parsed_questions) == 0:
        raise gr.Error('Validation error. It is necessary to ask at least one question')
    if len(parsed_questions) > max_questions_count:
        raise gr.Error(f'Validation error. The maximum allowed number of questions is {max_questions_count}.')
    return parsed_questions
def find_answers(tags: str, questions: str, progress=gr.Progress()) -> str:
    """Answer each question (prefixed with the tags) and format the report.

    Every question is sent as ``"[tag1][tag2] question"`` via ``get_answer``,
    retried up to ``max_attempts`` times with ``wait_time`` seconds between
    attempts. If every attempt fails, an error placeholder is used as the answer.

    Args:
        tags: Raw textbox contents, one tag per line.
        questions: Raw textbox contents, one question per line.
        progress: Gradio progress tracker; ``progress.tqdm`` reports per-question
            progress to the UI.

    Returns:
        The report produced by ``format_results``.

    Raises:
        gr.Error: If the tags or questions fail validation.
    """
    tag_list = validate_and_get_tags(tags)
    question_list = validate_and_get_questions(questions)
    print(f'New attempt to get answers. Got {len(tag_list)} tags and {len(question_list)} questions')
    print(f'Tags: {tag_list}')
    print(f'Questions: {question_list}')

    tags_str = ''.join(f'[{tag}]' for tag in tag_list)

    # `client` must exist on both paths. Previously it was only assigned when
    # use_sage was true, so the ChatGPT path raised UnboundLocalError below.
    client = poe.Client(sage_token) if use_sage else None

    # Always make at least one attempt; with max_attempts <= 0 the retry loop
    # never ran and questions silently vanished from the report.
    attempts = max(1, max_attempts)

    results = []
    for question in progress.tqdm(question_list):
        # Pause between questions (presumably to rate-limit the backend).
        time.sleep(wait_time)
        tagged_question = f'{tags_str} {question}'
        for attempt in range(attempts):
            answer = get_answer(tagged_question, client)
            if answer['status']:
                results.append((question, answer['content']))
                break
            if attempt < attempts - 1:
                time.sleep(wait_time)
        else:
            # Reached only when no attempt succeeded (the loop never hit `break`).
            results.append((question, 'An error occurred while receiving data.'))
    return format_results(results)
title = '<h1 style="text-align:center">AnswerMate</h1>'

# Build the UI: two input textboxes side by side, a button, and an output box.
with gr.Blocks(theme='soft', title='AnswerMate') as blocks:
    gr.HTML(title)
    gr.Markdown('The service allows you to get answers to all questions on the specified topic.')
    with gr.Row():
        tags_input = gr.Textbox(
            label=f'Enter tags (each line is a separate tag). Maximum: {max_tags_count}.',
            placeholder='.NET\nC#',
            lines=max_tags_count
        )
        questions_input = gr.Textbox(
            label=f'Enter questions (each line is a separate question). Maximum: {max_questions_count}.',
            placeholder='What is inheritance, encapsulation, abstraction, polymorphism?\nWhat is CLR?',
            lines=max_questions_count
        )
    process_button = gr.Button('Find answers')
    outputs = gr.Textbox(label='Output', placeholder='Output will appear here')
    process_button.click(fn=find_answers, inputs=[tags_input, questions_input], outputs=outputs)

# Launch only when run as a script, so importing this module (e.g. from tests
# or tooling) does not start a web server.
if __name__ == '__main__':
    # NOTE: `concurrency_count` is the Gradio 3.x queue API (removed in
    # Gradio 4); 1 means requests are processed one at a time.
    blocks.queue(concurrency_count=1).launch()