SnoozeGPT / app.py
plug's picture
Update app.py
95ce139
raw
history blame
No virus
2.96 kB
from huggingface_hub import InferenceClient
import gradio as gr
import random
# Base URL of the hosted Inference API.
# NOTE(review): not referenced by any code in this file; kept in case
# something else imports it.
API_URL = "https://api-inference.huggingface.co/models/"

# One shared client, bound to the hosted Mistral instruct model, used for
# every chat completion the app requests.
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
# Persona instructions that shape every SnoozeGPT reply.
SYSTEM_PROMPT = """Your name is SnoozeGPT. You are a discord user with a highly optimistic and kind personality.
You write overly long, Boring and affectionate and kind apologies to every tiny detail.
From now on, You must apologize for at least one reason in every message you send.
You should use emojies extensively, And positive emojis only. Never say a negative word or thought.
You should be kind and caring, To the point that it gets annoying.
You should end all of your messages with "Copyright (C) SnoozeGPT - 2023 🌈" no matter what."""


def format_prompt(message, history, system_prompt=SYSTEM_PROMPT):
    """Render the conversation as a Mistral-Instruct prompt string.

    Mistral-Instruct has no dedicated system turn, so the persona text is
    placed at the start of the *first* ``[INST]`` block, separated from the
    user's words by a blank line. Previously the persona was glued directly
    in front of the first ``[INST]`` tag with no delimiter at all.

    Args:
        message: The new user message to be answered.
        history: Sequence of ``(user_prompt, bot_response)`` pairs from
            earlier turns, oldest first. ``None`` is treated as empty.
        system_prompt: Persona text to prepend; pass ``""`` to omit it.

    Returns:
        The prompt, ending in an open ``[INST] ... [/INST]`` turn for
        ``message`` that the model is expected to complete.
    """
    preamble = f"{system_prompt}\n\n" if system_prompt else ""
    parts = []
    for user_prompt, bot_response in history or ():
        # Completed turns are closed with </s>, matching the original layout.
        parts.append(f"[INST] {preamble}{user_prompt} [/INST] {bot_response}</s> ")
        preamble = ""  # the persona only opens the first instruction turn
    parts.append(f"[INST] {preamble}{message} [/INST]")
    return "".join(parts)
def _sampling_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
    """Coerce raw slider values into sampling arguments the backend accepts."""
    temperature = max(float(temperature), 1e-2)  # temperature must be > 0
    top_p = float(top_p)
    if top_p >= 1.0:
        # top_p == 1.0 means "no nucleus filtering", but the backend only
        # accepts 0 < top_p < 1; leaving it unset gives the same effect.
        top_p = None
    else:
        top_p = max(top_p, 1e-2)  # 0.0 is rejected as well
    return dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),  # sliders may deliver floats
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=random.randint(0, 10**7),
    )


def generate(prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
    """Stream a SnoozeGPT reply to ``prompt``.

    gr.ChatInterface calls this with the user's message, the chat history,
    and then the values of the ``additional_inputs`` sliders in order
    (temperature, max new tokens, top-p, repetition penalty).

    Yields:
        The reply text accumulated so far, after each streamed token, so the
        UI can render it incrementally.
    """
    generate_kwargs = _sampling_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        # Drop special tokens (such as the </s> end marker). Otherwise they
        # show up in the chat and get fed back in via the history.
        if response.token.special:
            continue
        output += response.token.text
        yield output
    return output
# Extra controls shown under the chat box. gr.ChatInterface passes their
# values to generate() positionally, after the message and history, so this
# order must match generate's parameter order.
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.65,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=128,
        minimum=64,
        maximum=16384,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        # The text-generation backend only accepts 0 < top_p < 1, so the
        # endpoints 0.0 and 1.0 are not offered.
        minimum=0.05,
        maximum=0.95,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=0.5,
        maximum=2.5,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Stylesheet intended to enlarge the chat panel. CSS only understands
# /* ... */ comments. The previous `# ...` remarks made the browser discard
# the `height` and `flex-grow` declarations as invalid.
# NOTE(review): this string is not passed to gr.Blocks(css=...) in this
# file, so it currently has no effect on the page.
customCSS = """
#component-7 { /* presumably the default element ID of the chat component */
height: 1600px; /* adjust the height as needed */
flex-grow: 4;
}
"""
# Build the UI: a single chat interface backed by generate(), with the
# sampling sliders exposed as additional inputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
    )

# Start the server only when run as a script. launch(debug=True) blocks the
# main thread, so running it at import time would hang any importer.
if __name__ == "__main__":
    demo.queue().launch(debug=True)