think / app.py
nroggendorff's picture
Update app.py
f00c92a verified
raw
history blame contribute delete
No virus
1.45 kB
import gradio as gr
from transformers import pipeline
import torch
import re
import os
import spaces
# Module-level setup: runs once at import time, before the UI is built.
# Make CUDA the default device for tensors torch creates from here on.
torch.set_default_device("cuda")
# Model identifier handed to the transformers pipeline below.
model_id = "glides/mistral-eap"
# Text-generation pipeline shared by every request; device placement is
# delegated to the library via device_map="auto".
pipe = pipeline("text-generation", model=model_id, device_map="auto")
# System prompt is read from the "sys" environment variable; a missing
# variable raises KeyError here, so the app fails at startup.
system_prompt = os.environ["sys"]
# Compiled once at import time: the four tagged sections, in order, each with
# at least one character of content and nothing between consecutive sections.
_RULES_PATTERN = re.compile(
    r'<thinking>.+?</thinking><output>.+?</output><reflecting>.+?</reflecting><refined>.+?</refined>'
)


def follows_rules(s):
    """Return True when *s* begins with the four required tagged sections.

    The expected layout is ``<thinking>…</thinking><output>…</output>
    <reflecting>…</reflecting><refined>…</refined>`` with every section
    non-empty. Newline characters are removed before matching, so sections
    may span several lines and tags may be separated by line breaks.

    The match is anchored at the start of the string only: text before
    ``<thinking>`` makes the check fail, text after ``</refined>`` does not.

    Args:
        s: The candidate model output.

    Returns:
        ``True`` if the newline-stripped text starts with the full
        four-section structure, ``False`` otherwise.
    """
    # Dropping "\n" (rather than using DOTALL) also makes tags that were on
    # separate lines directly adjacent, which the pattern requires.
    flattened = s.replace("\n", "")
    return _RULES_PATTERN.match(flattened) is not None
# Token budget handed to the pipeline on every generation attempt.
_MAX_NEW_TOKENS = 2 ** 16
# Cap on generation attempts per request, so a model that keeps producing
# malformed output cannot keep the request (and its GPU slot) busy forever.
_MAX_ATTEMPTS = 5


def _generate_reply(chat):
    """Run the pipeline on *chat* and return the content of the last message."""
    return pipe(chat, max_new_tokens=_MAX_NEW_TOKENS)[0]['generated_text'][-1]['content']


def _trim_to_sections(generated_text):
    """Keep the text from the last ``<thinking>`` to the first ``</refined>`` after it."""
    # Both tags are re-attached unconditionally, mirroring the split points;
    # whether the result is actually well-formed is decided by follows_rules.
    from_thinking = "<thinking>" + generated_text.split("<thinking>")[-1]
    return from_thinking.split("</refined>")[0] + "</refined>"


@spaces.GPU(duration=120)
def predict(input_text, history):
    """Answer *input_text* with the model's ``<refined>`` section.

    Builds a chat transcript from the system prompt, the prior turns, and the
    new user message, then generates until the trimmed output satisfies
    ``follows_rules`` or the attempt budget is exhausted.

    Args:
        input_text: The new user message.
        history: Prior turns; each item is indexed as ``item[0]`` (user text)
            and ``item[1]`` (assistant text, skipped when ``None``).

    Returns:
        The text inside the ``<refined>`` section, stripped of surrounding
        whitespace.

    Raises:
        gr.Error: If no attempt produced output in the required format.
    """
    chat = [{"role": "system", "content": system_prompt}]
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            chat.append({"role": "assistant", "content": item[1]})
    chat.append({"role": "user", "content": input_text})

    for _ in range(_MAX_ATTEMPTS):
        generated_text = _generate_reply(chat)
        # Re-trim on every attempt: validating a stale candidate from an
        # earlier generation would make the retry loop unable to ever succeed.
        candidate = _trim_to_sections(generated_text)
        if follows_rules(candidate):
            model_output = candidate.split("<refined>")[-1].replace("</refined>", "")
            return model_output.strip()
        print(f"model output {generated_text} was found invalid")

    raise gr.Error("The model did not produce a valid response. Please try again.")
# Build the chat UI around predict with the "soft" theme and launch it.
gr.ChatInterface(predict, theme="soft").launch()