Update app.py
Browse files
app.py
CHANGED
@@ -12,7 +12,7 @@ pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
|
|
12 |
)
|
13 |
|
14 |
|
15 |
-
def generate_image(prompt: str, inference_steps: int =
|
16 |
rng = jax.random.PRNGKey(int(prng_seed))
|
17 |
rng = jax.random.split(rng, jax.device_count())
|
18 |
p_params = replicate(pipeline_params)
|
@@ -20,6 +20,8 @@ def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
|
|
20 |
num_samples = 1
|
21 |
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
|
22 |
prompt_ids = shard(prompt_ids)
|
|
|
|
|
23 |
|
24 |
images = pipeline(
|
25 |
prompt_ids=prompt_ids,
|
@@ -28,6 +30,8 @@ def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
|
|
28 |
height=128,
|
29 |
width=128,
|
30 |
num_inference_steps=int(inference_steps),
|
|
|
|
|
31 |
jit=True,
|
32 |
).images
|
33 |
|
@@ -267,12 +271,15 @@ with block:
|
|
267 |
minimum=1, maximum=100, default=25, step=1, label="Inference Steps"
|
268 |
)
|
269 |
seed_input = gr.inputs.Number(default=0, label="Seed")
|
|
|
|
|
|
|
270 |
|
271 |
-
ex = gr.Examples(examples=[["A watercolor painting of a bird", 25, 0],["A watercolor painting of an otter",25,0],["Marvel MCU deadpool, red mask, red shirt, red gloves, black shoulders, black elbow pads, black legs, gold buckle, black belt, black mask, white eyes, black boots, fuji low light color 35mm film, downtown Osaka alley at night out of focus in background, neon lights",25,0]], fn=generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input
|
272 |
ex.dataset.headers = [""]
|
273 |
-
negative.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input], outputs=[gallery], postprocess=False)
|
274 |
-
prompt_input.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input], outputs=[gallery], postprocess=False)
|
275 |
-
btn.click(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input], outputs=[gallery], postprocess=False)
|
276 |
|
277 |
#advanced_button.click(
|
278 |
# None,
|
|
|
12 |
)
|
13 |
|
14 |
|
15 |
+
def generate_image(prompt: str,negative_prompt:str , inference_steps: int = 25, prng_seed: int = 0, guidance_scale: float = 9):
|
16 |
rng = jax.random.PRNGKey(int(prng_seed))
|
17 |
rng = jax.random.split(rng, jax.device_count())
|
18 |
p_params = replicate(pipeline_params)
|
|
|
20 |
num_samples = 1
|
21 |
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
|
22 |
prompt_ids = shard(prompt_ids)
|
23 |
+
neg_prompt_ids = pipeline.prepare_inputs([negative_prompt] * num_samples)
|
24 |
+
neg_prompt_ids = shard(neg_prompt_ids)
|
25 |
|
26 |
images = pipeline(
|
27 |
prompt_ids=prompt_ids,
|
|
|
30 |
height=128,
|
31 |
width=128,
|
32 |
num_inference_steps=int(inference_steps),
|
33 |
+
neg_prompt_ids=neg_prompt_ids,
|
34 |
+
guidance_scale =float(guidance_scale),
|
35 |
jit=True,
|
36 |
).images
|
37 |
|
|
|
271 |
minimum=1, maximum=100, default=25, step=1, label="Inference Steps"
|
272 |
)
|
273 |
seed_input = gr.inputs.Number(default=0, label="Seed")
|
274 |
+
guidance_scale = gr.Slider(
|
275 |
+
label="Guidance Scale", minimum=0, maximum=50, value=9, step=0.1
|
276 |
+
)
|
277 |
|
278 |
+
ex = gr.Examples(examples=[["A watercolor painting of a bird","mountain", 25, 0,9],["A watercolor painting of an otter","mountain",25,0,9],["Marvel MCU deadpool, red mask, red shirt, red gloves, black shoulders, black elbow pads, black legs, gold buckle, black belt, black mask, white eyes, black boots, fuji low light color 35mm film, downtown Osaka alley at night out of focus in background, neon lights","mountain",25,0,10]], fn=generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale],outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
|
279 |
ex.dataset.headers = [""]
|
280 |
+
negative.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale], outputs=[gallery], postprocess=False)
|
281 |
+
prompt_input.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale], outputs=[gallery], postprocess=False)
|
282 |
+
btn.click(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale], outputs=[gallery], postprocess=False)
|
283 |
|
284 |
#advanced_button.click(
|
285 |
# None,
|