prithivMLmods commited on
Commit
9efc887
1 Parent(s): 2f57ec0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -13
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
 
4
 
5
  import spaces
6
  from diffusers import DiffusionPipeline
@@ -47,6 +49,15 @@ style_list = [
47
  STYLE_NAMES = [style["name"] for style in style_list]
48
  DEFAULT_STYLE_NAME = STYLE_NAMES[0]
49
 
 
 
 
 
 
 
 
 
 
50
  @spaces.GPU
51
  def infer(
52
  prompt,
@@ -58,6 +69,7 @@ def infer(
58
  guidance_scale=0.0,
59
  num_inference_steps=4,
60
  style="Style Zero",
 
61
  progress=gr.Progress(track_tqdm=True),
62
  ):
63
 
@@ -70,17 +82,33 @@ def infer(
70
 
71
  generator = torch.Generator().manual_seed(seed)
72
 
73
- image = pipe(
74
- prompt=styled_prompt,
75
- negative_prompt=styled_negative_prompt,
76
- guidance_scale=guidance_scale,
77
- num_inference_steps=num_inference_steps,
78
- width=width,
79
- height=height,
80
- generator=generator,
81
- ).images[0]
82
-
83
- return image, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  examples = [
86
  "A capybara wearing a suit holding a sign that reads Hello World",
@@ -96,7 +124,7 @@ css = """
96
  with gr.Blocks(css=css) as demo:
97
  with gr.Column(elem_id="col-container"):
98
  gr.Markdown(" # [Stable Diffusion 3.5 Large Turbo (8B)](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo)")
99
- gr.Markdown("[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) about the Stable Diffusion 3.5 series. Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), or [download model](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) to run locally with ComfyUI or diffusers.")
100
 
101
  with gr.Row():
102
  prompt = gr.Text(
@@ -112,6 +140,13 @@ with gr.Blocks(css=css) as demo:
112
  result = gr.Image(label="Result", show_label=False)
113
 
114
  with gr.Row(visible=True):
 
 
 
 
 
 
 
115
  style_selection = gr.Radio(
116
  show_label=True,
117
  container=True,
@@ -138,7 +173,7 @@ with gr.Blocks(css=css) as demo:
138
  )
139
 
140
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
141
-
142
  with gr.Row():
143
  width = gr.Slider(
144
  label="Width",
@@ -173,6 +208,8 @@ with gr.Blocks(css=css) as demo:
173
  value=4,
174
  )
175
 
 
 
176
  gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True, cache_mode="lazy")
177
 
178
  gr.on(
@@ -188,6 +225,7 @@ with gr.Blocks(css=css) as demo:
188
  guidance_scale,
189
  num_inference_steps,
190
  style_selection,
 
191
  ],
192
  outputs=[result, seed],
193
  )
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import uuid
5
+ from PIL import Image
6
 
7
  import spaces
8
  from diffusers import DiffusionPipeline
 
49
  STYLE_NAMES = [style["name"] for style in style_list]
50
  DEFAULT_STYLE_NAME = STYLE_NAMES[0]
51
 
52
+ grid_sizes = {
53
+ "2x1": (2, 1),
54
+ "1x2": (1, 2),
55
+ "2x2": (2, 2),
56
+ "2x3": (2, 3),
57
+ "3x2": (3, 2),
58
+ "1x1": (1, 1)
59
+ }
60
+
61
  @spaces.GPU
62
  def infer(
63
  prompt,
 
69
  guidance_scale=0.0,
70
  num_inference_steps=4,
71
  style="Style Zero",
72
+ grid_size="1x1",
73
  progress=gr.Progress(track_tqdm=True),
74
  ):
75
 
 
82
 
83
  generator = torch.Generator().manual_seed(seed)
84
 
85
+ grid_size_x, grid_size_y = grid_sizes.get(grid_size, (2, 2))
86
+ num_images = grid_size_x * grid_size_y
87
+
88
+ images = []
89
+ for _ in range(num_images):
90
+ image = pipe(
91
+ prompt=styled_prompt,
92
+ negative_prompt=styled_negative_prompt,
93
+ guidance_scale=guidance_scale,
94
+ num_inference_steps=num_inference_steps,
95
+ width=width,
96
+ height=height,
97
+ generator=generator,
98
+ ).images[0]
99
+ images.append(image)
100
+
101
+ # Create a grid image
102
+ grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
103
+
104
+ for i, img in enumerate(images[:num_images]):
105
+ grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
106
+
107
+ # Save the grid image
108
+ unique_name = str(uuid.uuid4()) + ".png"
109
+ grid_img.save(unique_name)
110
+
111
+ return unique_name, seed
112
 
113
  examples = [
114
  "A capybara wearing a suit holding a sign that reads Hello World",
 
124
  with gr.Blocks(css=css) as demo:
125
  with gr.Column(elem_id="col-container"):
126
  gr.Markdown(" # [Stable Diffusion 3.5 Large Turbo (8B)](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo)")
127
+ gr.Markdown("[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) about the Stable Diffusion 3.5 series.")
128
 
129
  with gr.Row():
130
  prompt = gr.Text(
 
140
  result = gr.Image(label="Result", show_label=False)
141
 
142
  with gr.Row(visible=True):
143
+ grid_size_selection = gr.Dropdown(
144
+ choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
145
+ value="1x1",
146
+ label="Grid Size"
147
+ )
148
+
149
+ with gr.Row(visible=True):
150
  style_selection = gr.Radio(
151
  show_label=True,
152
  container=True,
 
173
  )
174
 
175
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
176
+
177
  with gr.Row():
178
  width = gr.Slider(
179
  label="Width",
 
208
  value=4,
209
  )
210
 
211
+
212
+
213
  gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True, cache_mode="lazy")
214
 
215
  gr.on(
 
225
  guidance_scale,
226
  num_inference_steps,
227
  style_selection,
228
+ grid_size_selection,
229
  ],
230
  outputs=[result, seed],
231
  )