Brasd99 committed on
Commit
903851d
1 Parent(s): 5ea94b0

Переход на Sage

Browse files
Files changed (3) hide show
  1. app.py +11 -32
  2. config.json +2 -2
  3. requirements.txt +2 -1
app.py CHANGED
@@ -3,8 +3,7 @@ import time
3
  import json
4
  import requests
5
  import gradio as gr
6
- import torch
7
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
8
 
9
  with open("config.json", "r") as f:
10
  config = json.load(f)
@@ -15,41 +14,21 @@ max_attempts = config["MAX_ATTEMPS"]
15
  wait_time = config["WAIT_TIME"]
16
  chatgpt_url = config["CHATGPT_URL"]
17
  system_prompt = config["SYSTEM_PROMPT"]
18
- sber_gpt = config["SBER_GRT"]
19
- use_sber = config["USE_SBER"]
20
-
21
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
-
23
- if use_sber:
24
- tokenizer = GPT2Tokenizer.from_pretrained(sber_gpt)
25
- model = GPT2LMHeadModel.from_pretrained(sber_gpt).to(DEVICE)
26
-
27
- def generate(
28
- model, tok, text,
29
- do_sample=True, max_length=10000, repetition_penalty=5.0,
30
- top_k=5, top_p=0.95, temperature=1,
31
- num_beams=10,
32
- no_repeat_ngram_size=3
33
- ):
34
- input_ids = tok.encode(text, return_tensors="pt").to(DEVICE)
35
- out = model.generate(
36
- input_ids.to(DEVICE),
37
- max_length=max_length,
38
- repetition_penalty=repetition_penalty,
39
- do_sample=do_sample,
40
- top_k=top_k, top_p=top_p, temperature=temperature,
41
- num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size
42
- )
43
- return list(map(tok.decode, out))[0]
44
 
45
  def get_answer(question: str) -> Dict[str, Any]:
46
- if use_sber:
47
- content = generate(model, tokenizer, question)
 
 
48
  return {
49
  'status': True,
50
- 'content': content
51
  }
52
-
53
  headers = {
54
  'Content-Type': 'application/json; charset=utf-8'
55
  }
 
3
  import json
4
  import requests
5
  import gradio as gr
6
+ import poe
 
7
 
8
  with open("config.json", "r") as f:
9
  config = json.load(f)
 
14
  wait_time = config["WAIT_TIME"]
15
  chatgpt_url = config["CHATGPT_URL"]
16
  system_prompt = config["SYSTEM_PROMPT"]
17
+ use_sage = config["USE_SAGE"]
18
+ sage_token = config["SAGE_TOKEN"]
19
+
20
+ client = poe.Client(sage_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def get_answer(question: str) -> Dict[str, Any]:
23
+
24
+ if use_sage:
25
+ for chunk in client.send_message("capybara", question, with_chat_break=True):
26
+ pass
27
  return {
28
  'status': True,
29
+ 'content': chunk["text"]
30
  }
31
+
32
  headers = {
33
  'Content-Type': 'application/json; charset=utf-8'
34
  }
config.json CHANGED
@@ -4,7 +4,7 @@
4
  "MAX_ATTEMPS": 5,
5
  "WAIT_TIME": 1,
6
  "CHATGPT_URL": "https://free.churchless.tech/v1/chat/completions",
7
- "SBER_GRT": "ai-forever/rugpt3small_based_on_gpt2",
8
- "USE_SBER": 0,
9
  "SYSTEM_PROMPT": "Your task is to give the most detailed answer to the question posed. At the beginning of the question, there are tags in square brackets specifying the subject of the question. It is necessary to answer in the language of the user's question"
10
  }
 
4
  "MAX_ATTEMPS": 5,
5
  "WAIT_TIME": 1,
6
  "CHATGPT_URL": "https://free.churchless.tech/v1/chat/completions",
7
+ "USE_SAGE": 1,
8
+ "SAGE_TOKEN": "PGUXiyEZKRHcMoij9AjxXw%3D%3D",
9
  "SYSTEM_PROMPT": "Your task is to give the most detailed answer to the question posed. At the beginning of the question, there are tags in square brackets specifying the subject of the question. It is necessary to answer in the language of the user's question"
10
  }
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  torch
2
- transformers
 
 
1
  torch
2
+ transformers
3
+ poe-api