import gradio as gr
from transformers import pipeline
# Function to clean the output by truncating at the last full sentence
def clean_output(text):
    if '.' in text:
        return text[:text.rfind('.') + 1]  # Truncate at the last full sentence
    return text  # Return the text as is if no period is found


# Function to generate the story
def generate_story(title, model_name):
    # Use the text-generation pipeline from Hugging Face
    generator = pipeline('text-generation', model=model_name)

    # Generate the story based on the input title
    story = generator(
        title,
        max_length=230,          # Limit the generated text (story) to 230 tokens
        no_repeat_ngram_size=3,  # Avoid repeating any 3-token sequence
        do_sample=True,          # Enable sampling so temperature and top_p take effect
        temperature=0.8,         # Introduce some randomness for diversity
        top_p=0.95               # Nucleus sampling for more coherent text
    )[0]['generated_text']
    # Clean the generated story to ensure it ends with a full sentence
    cleaned_story = clean_output(story)

    # Return the cleaned story
    return cleaned_story


# Gradio interface setup
demo = gr.Interface(
    fn=generate_story,
    inputs=[
        gr.Textbox(label="Enter Story Title", placeholder="Type a title here..."),  # Title input
        gr.Dropdown(choices=['gpt2', 'gpt2-large', 'EleutherAI/gpt-neo-2.7B'], value='gpt2', label="Select Model")  # Model selection
    ],
    outputs="text",
    title="AI Story Generator",
    description="Generate a creative story using different AI models."
)
demo.launch(share=True)