nroggendorff committed on
Commit
fa8a9aa
1 Parent(s): a3ed834

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library
import re

# Third-party
import gradio as gr
import spaces
import torch
from transformers import pipeline

# Run all new tensors on the GPU by default (required for the ZeroGPU space).
torch.set_default_device("cuda")

# Load the chat model once at startup; device_map="auto" lets transformers
# place the weights on the available accelerator(s).
model_id = "glides/llama-eap"
pipe = pipeline("text-generation", model=model_id, device_map="auto")

# The system prompt lives in a local file named "sys" alongside this script.
with open("sys", "r") as f:
    system_prompt = f.read()
def follows_rules(s):
    """Check that *s*, with newlines removed, begins with the required
    tag sequence — thinking, output, reflecting, refined — each section
    non-empty. Trailing text after </refined> is permitted (re.match only
    anchors at the start of the string).
    """
    required = re.compile(
        r'<thinking>.+?</thinking>'
        r'<output>.+?</output>'
        r'<reflecting>.+?</reflecting>'
        r'<refined>.+?</refined>'
    )
    flattened = s.replace("\n", "")
    return required.match(flattened) is not None
@spaces.GPU(duration=120)
def predict(input_text, history):
    """Chat handler for gr.ChatInterface.

    Builds the full transcript (system prompt, prior turns from *history*
    as (user, assistant) pairs, then the new user message), samples the
    model until the reply matches the required
    <thinking>/<output>/<reflecting>/<refined> structure, and returns only
    the text inside the <refined> section.
    """
    chat = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            chat.append({"role": "assistant", "content": assistant_msg})
    chat.append({"role": "user", "content": input_text})

    # BUG FIX: the original retry loop regenerated `generated_text` but
    # never re-derived the trimmed string it validated, so an invalid
    # first sample looped forever. Each iteration now trims and
    # re-validates the fresh sample.
    while True:
        generated_text = pipe(chat, max_new_tokens=2 ** 16)[0]['generated_text'][-1]['content']

        # Drop anything before the first <thinking> tag and after the
        # closing </refined> tag before checking the structure.
        trimmed = "<thinking>" + generated_text.split("<thinking>")[-1]
        trimmed = trimmed.split("</refined>")[0] + "</refined>"

        if follows_rules(trimmed):
            break
        print(f"model output {generated_text} was found invalid")

    # Return only the <refined> section's contents.
    return trimmed.split("<refined>")[-1].replace("</refined>", "")
# Wire the chat handler into a Gradio chat UI and start serving.
demo = gr.ChatInterface(predict, theme="soft")
demo.launch()