multimodalart HF staff commited on
Commit
3cf0016
1 Parent(s): c8890b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  import torchvision.transforms as T
6
  import sys
7
  import spaces
 
8
 
9
  subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
10
  os.chdir("HairFastGAN")
@@ -34,20 +35,36 @@ from hair_swap import HairFast, get_parser
34
 
35
  hair_fast = HairFast(get_parser().parse_args([]))
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @spaces.GPU
38
  def swap_hair(source, target_1, target_2, progress=gr.Progress(track_tqdm=True)):
 
39
  final_image = hair_fast.swap(source, target_1, target_2)
40
  return T.functional.to_pil_image(final_image)
41
 
42
  with gr.Blocks() as demo:
43
- gr.Markdown("Start typing below and then click **Run** to see the output.")
44
  with gr.Row():
45
  source = gr.Image(label="Photo that you want to replace the hair", type="filepath")
46
  target_1 = gr.Image(label="Reference hair you want to get", type="filepath")
47
  target_2 = gr.Image(label="Reference color hair you want to get (optional)", type="filepath")
48
  btn = gr.Button("Get the haircut")
49
  output = gr.Image(label="Your result")
50
-
 
 
 
51
  btn.click(fn=swap_hair, inputs=[source, target_1, target_2], outputs=[output])
52
 
53
  demo.launch()
 
5
  import torchvision.transforms as T
6
  import sys
7
  import spaces
8
+ from PIL import Image
9
 
10
  subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
11
  os.chdir("HairFastGAN")
 
35
 
36
  hair_fast = HairFast(get_parser().parse_args([]))
37
 
38
+ def resize(image_path):
39
+ img = Image.open("image_path")
40
+ square_size = 1024
41
+
42
+ left = (img.width - square_size) / 2
43
+ top = (img.height - square_size) / 2
44
+ right = (img.width + square_size) / 2
45
+ bottom = (img.height + square_size) / 2
46
+
47
+ img_cropped = img.crop((left, top, right, bottom))
48
+ return img_cropped
49
+
50
  @spaces.GPU
51
  def swap_hair(source, target_1, target_2, progress=gr.Progress(track_tqdm=True)):
52
+ target_2 = target_2 if target_2 else target_1
53
  final_image = hair_fast.swap(source, target_1, target_2)
54
  return T.functional.to_pil_image(final_image)
55
 
56
  with gr.Blocks() as demo:
57
+ gr.Markdown("## HairFastGan")
58
  with gr.Row():
59
  source = gr.Image(label="Photo that you want to replace the hair", type="filepath")
60
  target_1 = gr.Image(label="Reference hair you want to get", type="filepath")
61
  target_2 = gr.Image(label="Reference color hair you want to get (optional)", type="filepath")
62
  btn = gr.Button("Get the haircut")
63
  output = gr.Image(label="Your result")
64
+ gr.Examples(examples=[("michael_cera-min.png", "leo_square-min.png", "pink_hair_celeb-min.png")])
65
+ source.upload(fn=resize, input=source, output=source)
66
+ target_1.upload(fn=resize, input=target_1, output=target_1)
67
+ target_2.upload(fn=resize, input=target_2, output=target_2)
68
  btn.click(fn=swap_hair, inputs=[source, target_1, target_2], outputs=[output])
69
 
70
  demo.launch()