AkiKagura commited on
Commit
3fd8421
1 Parent(s): 962e235

add txt2img

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -8,7 +8,7 @@ from io import BytesIO
8
  import os
9
  MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')
10
 
11
- #from diffusers import StableDiffusionPipeline
12
  from diffusers import StableDiffusionImg2ImgPipeline
13
 
14
  def empty_checker(images, **kwargs): return images, False
@@ -19,10 +19,16 @@ YOUR_TOKEN=MY_SECRET_TOKEN
19
 
20
  device="cpu"
21
 
 
22
  img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("AkiKagura/mkgen-diffusion", duse_auth_token=YOUR_TOKEN)
23
  img_pipe.safety_checker = empty_checker
24
  img_pipe.to(device)
25
 
 
 
 
 
 
26
  source_img = gr.Image(source="upload", type="filepath", label="init_img | 512*512 px")
27
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[1], height="auto")
28
 
@@ -42,7 +48,10 @@ def infer(source_img, prompt, guide, steps, seed, strength):
42
  source_image = resize(512, source_img)
43
  source_image.save('source.png')
44
 
45
- images_list = img_pipe([prompt] * 1, init_image=source_image, strength=strength, guidance_scale=guide, num_inference_steps=steps)
 
 
 
46
  images = []
47
 
48
  for i, image in enumerate(images_list["images"]):
 
8
  import os
9
  MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')
10
 
11
+ from diffusers import StableDiffusionPipeline
12
  from diffusers import StableDiffusionImg2ImgPipeline
13
 
14
  def empty_checker(images, **kwargs): return images, False
 
19
 
20
  device="cpu"
21
 
22
+ # img2img pipeline
23
  img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("AkiKagura/mkgen-diffusion", duse_auth_token=YOUR_TOKEN)
24
  img_pipe.safety_checker = empty_checker
25
  img_pipe.to(device)
26
 
27
+ # txt2img pipeline
28
+ pipe = StableDiffusionPipeline.from_pretrained("AkiKagura/mkgen-diffusion", duse_auth_token=YOUR_TOKEN)
29
+ pipe.safety_checker = empty_checker
30
+ pipe.to(device)
31
+
32
  source_img = gr.Image(source="upload", type="filepath", label="init_img | 512*512 px")
33
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[1], height="auto")
34
 
 
48
  source_image = resize(512, source_img)
49
  source_image.save('source.png')
50
 
51
+ if source_image is None:
52
+ images_list = pipe([prompt] * 1, guidance_scale=guide, num_inference_steps=steps)
53
+ else:
54
+ images_list = img_pipe([prompt] * 1, init_image=source_image, strength=strength, guidance_scale=guide, num_inference_steps=steps)
55
  images = []
56
 
57
  for i, image in enumerate(images_list["images"]):