Shabbir-Anjum commited on
Commit
6f473bc
1 Parent(s): 5a44408

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -1,26 +1,31 @@
1
- import torch
2
  import streamlit as st
3
  from diffusers import StableDiffusion3Pipeline
4
 
5
- # Load the model
6
  pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
7
- pipeline = pipeline.to("cuda") # Move model to GPU if available
8
 
9
- # Streamlit UI
 
 
 
 
10
  def main():
11
- st.title("Stable Diffusion 3 Medium Demo")
12
- prompt = st.text_input("Enter your prompt:", "A cat holding a sign that says hello world")
13
-
14
- if st.button("Generate Image"):
15
- with st.spinner("Generating..."):
16
- try:
17
- image = pipeline(prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0).images[0]
18
- st.image(image, caption="Generated Image", use_column_width=True)
19
- except Exception as e:
20
- st.error(f"Error: {e}")
21
-
22
- if __name__ == "__main__":
23
- main()
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
 
 
1
  import streamlit as st
2
  from diffusers import StableDiffusion3Pipeline
3
 
4
+ # Load the Diffusion pipeline
5
  pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
 
6
 
7
+ def generate_prompt(prompt_text):
8
+ # Generate response using the Diffusion model
9
+ response = pipeline(prompt_text, top_p=0.9, max_length=100)[0]['generated_text']
10
+ return response
11
+
12
  def main():
13
+ st.title('Diffusion Model Prompt Generator')
14
+
15
+ # Text input for the prompt
16
+ prompt_text = st.text_area("Enter your prompt here:", height=200)
 
 
 
 
 
 
 
 
 
17
 
18
+ # Button to generate prompt
19
+ if st.button("Generate"):
20
+ if prompt_text:
21
+ with st.spinner('Generating...'):
22
+ generated_text = generate_prompt(prompt_text)
23
+ st.success('Generation complete!')
24
+ st.text_area('Generated Text:', value=generated_text, height=400)
25
+ else:
26
+ st.warning('Please enter a prompt.')
27
+
28
+ if __name__ == '__main__':
29
+ main()
30
 
31