# Page header text captured along with this source; not part of the program:
#   adarksky's picture
#   -- update
#   9e76bbc verified
#   raw
#   history blame contribute delete
#   No virus
#   1.4 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "adarksky/president-gpt2"

# Load the tokenizer and reuse its end-of-sequence token as the padding token
# (presumably because this tokenizer defines no pad token of its own).
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Load the causal language model once at import time; respond() reuses it.
model = AutoModelForCausalLM.from_pretrained(model_name)
def respond(
    message,
    history,
    min_length,
    max_length,
    temperature,
    top_k,
):
    """Generate a sampled reply to ``message`` with the module-level model.

    Args:
        message: Text typed by the user; used as the generation prompt.
        history: Earlier chat turns supplied by the caller. Unused: each
            reply is conditioned on ``message`` alone.
        min_length: Lower length bound forwarded to ``model.generate``.
        max_length: Upper length bound forwarded to ``model.generate``.
            NOTE(review): in transformers these bounds count prompt tokens
            as well as newly generated ones -- confirm this is intended.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded to text with special tokens
        removed, or a short notice when ``message`` is empty.
    """
    print(message)

    # Guard against an empty prompt: there is no text to continue.
    if not message:
        return "Please enter a message."

    # Calling the tokenizer (instead of .encode) also returns the attention
    # mask, so generate() does not have to guess it while pad == eos.
    encoded = tokenizer(message, return_tensors="pt")

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            # UI sliders may deliver floats; these options must be integers.
            min_length=int(min_length),
            max_length=int(max_length),
            num_return_sequences=1,
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            # Explicit pad id avoids the implicit fallback to EOS (and its warning).
            pad_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# (label, minimum, maximum, default, step) for each extra control, listed in
# the same order as respond()'s parameters after (message, history).
_SLIDER_SPECS = (
    ("Min length", 50, 150, 100, 1),
    ("Max length", 200, 1000, 250, 1),
    ("Temperature", 1, 1.9, 1.2, 0.1),
    ("Top-k", 1, 100, 50, 1),
)

# Chat UI wired to respond(); the sliders expose the generation settings.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(minimum=low, maximum=high, value=default, step=step, label=label)
        for label, low, high, default, step in _SLIDER_SPECS
    ],
)

if __name__ == "__main__":
    # Launch the app only when executed as a script, not on import.
    demo.launch()