NadaAljohani commited on
Commit
c253268
1 Parent(s): 8a11fad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -18
app.py CHANGED
@@ -1,34 +1,41 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
 
 
 
 
 
 
 
4
  # Function to generate the story
5
  def generate_story(title, model_name):
6
  # Use text-generation pipeline from Hugging Face
7
  generator = pipeline('text-generation', model=model_name)
 
8
  # Generate the story based on the input title
9
  story = generator(title,
10
  max_length=230, # Set the maximum length for the generated text (story) to 230 tokens
11
- no_repeat_ngram_size=3, # Avoid repeating any sequence of 3 words (to prevent repetitive text)
12
- temperature=0.8, # Introduce some randomness; higher values make the output more random, lower makes it more deterministic
13
- top_p=0.95 # Use nucleus sampling (top-p sampling) to focus on the top 95% of probable words, making the text more coherent
14
- )
15
- # Return the generated text
16
- return story[0]['generated_text']
 
 
 
 
17
 
18
- # Create the Gradio interface using gr.Interface
19
  demo = gr.Interface(
20
- fn=generate_story, # The function to run
21
- inputs=[ # Inputs for the interface
22
  gr.Textbox(label="Enter Story Title", placeholder="Type a title here..."), # Title input
23
- gr.Dropdown(choices=['gpt2', 'gpt2-large', 'EleutherAI/gpt-neo-2.7B', 'EleutherAI/gpt-j-6B',
24
- 'maldv/badger-writer-llama-3-8b', 'EleutherAI/gpt-neo-1.3B'],
25
- value='gpt2',
26
- label="Choose Model") # Model selection input
27
  ],
28
- outputs=gr.Textbox(label="Generated Story", lines=10), # Output for the generated story
29
- title="AI Story Generator", # Title of the interface
30
- description="Enter a title and choose a model to generate a short story" # A short description
31
  )
32
 
33
- # Launch the interface
34
- demo.launch(share=True)
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
+ # Function to clean the output by truncating at the last full sentence
5
+ def clean_output(text):
6
+ if '.' in text:
7
+ return text[:text.rfind('.')+1] # Truncate at the last full sentence
8
+ return text # Return the text as is if no period is found
9
+
10
  # Function to generate the story
11
  def generate_story(title, model_name):
12
  # Use text-generation pipeline from Hugging Face
13
  generator = pipeline('text-generation', model=model_name)
14
+
15
  # Generate the story based on the input title
16
  story = generator(title,
17
  max_length=230, # Set the maximum length for the generated text (story) to 230 tokens
18
+ no_repeat_ngram_size=3, # Avoid repeating sequences of 3 words
19
+ temperature=0.8, # Introduce some randomness for diversity
20
+ top_p=0.95 # Nucleus sampling for more coherent text
21
+ )[0]['generated_text']
22
+
23
+ # Clean the generated story to ensure it ends with a full sentence
24
+ cleaned_story = clean_output(story)
25
+
26
+ # Return the cleaned story
27
+ return cleaned_story
28
 
29
+ # Gradio interface setup
30
  demo = gr.Interface(
31
+ fn=generate_story,
32
+ inputs=[
33
  gr.Textbox(label="Enter Story Title", placeholder="Type a title here..."), # Title input
34
+ gr.Dropdown(choices=['gpt2', 'gpt2-large', 'EleutherAI/gpt-neo-2.7B'], value='gpt2', label="Select Model") # Model selection
 
 
 
35
  ],
36
+ outputs="text",
37
+ title="AI Story Generator",
38
+ description="Generate a creative story using different AI models."
39
  )
40
 
41
+ demo.launch(share=True)