"""Gradio demo that generates Yelp-style reviews with a fine-tuned BART model."""
import os

import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = 'eliolio/bart-finetuned-yelpreviews'
# Access token is read from the environment; it is None when the variable
# `private_token` is not set.
access_token = os.environ.get('private_token')

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)

# Number of reviews sampled per click. Must match the number of output
# textboxes wired to the button below.
NUM_REVIEWS = 3


def create_prompt(stars, useful, funny, cool) -> str:
    """Build the text prompt fed to the model from the four slider values."""
    return f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"


def postprocess(review: str) -> str:
    """Trim a generated review at its last period.

    Everything from the last ``'.'`` onward (the period included) is removed,
    which discards a trailing unfinished sentence. If the text contains no
    period it is returned unchanged.
    """
    dot = review.rfind('.')
    if dot == -1:
        # rfind() signals "not found" with -1; slicing with it would silently
        # chop off the last character of the review.
        return review
    return review[:dot]


def generate_reviews(stars, useful, funny, cool):
    """Sample ``NUM_REVIEWS`` reviews for the given attribute values.

    Returns a tuple of ``NUM_REVIEWS`` post-processed review strings, one per
    output textbox.
    """
    text = create_prompt(stars, useful, funny, cool)
    inputs = tokenizer(text, return_tensors='pt')
    out = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        do_sample=True,
        num_return_sequences=NUM_REVIEWS,
        temperature=1.2,
        top_p=0.9,
    )
    reviews = [
        postprocess(tokenizer.decode(sequence, skip_special_tokens=True))
        for sequence in out
    ]
    return tuple(reviews)


css = """
    #ctr {text-align: center;}
    #btn {color: white; background: linear-gradient( 90deg,  rgba(255,166,0,1) 14.7%, rgba(255,99,97,1) 73% );}
"""

md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""

demo = gr.Blocks(css=css)
with demo:
    with gr.Row():
        gr.Markdown(md_text, elem_id='ctr')

    with gr.Row():
        # gr.Slider(value=...) replaces the legacy gr.inputs.Slider(default=...)
        # namespace, which is deprecated in Gradio 3.x and removed in 4.x.
        stars = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="stars")
        useful = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="useful")
        funny = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="funny")
        cool = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="cool")

    with gr.Row():
        button = gr.Button("Generate reviews !", elem_id='btn')

    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")

    button.click(
        fn=generate_reviews,
        inputs=[stars, useful, funny, cool],
        outputs=[output1, output2, output3],
    )

demo.launch()