multimodalart HF staff commited on
Commit
b021ace
1 Parent(s): 8e92572

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import shutil
3
+ import os
4
+ from gradio import gr
5
+ import torchvision.transforms as T
6
+ import spaces
7
+
8
+ subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
9
+ os.chdir("HairFastGAN")
10
+
11
+ subprocess.run(["git", "clone", "https://huggingface.co/AIRI-Institute/HairFastGAN"], check=True)
12
+
13
+ os.chdir("HairFastGAN")
14
+ subprocess.run(["git", "lfs", "pull"], check=True)
15
+ os.chdir("..")
16
+
17
+ shutil.move("HairFastGAN/pretrained_models", "pretrained_models")
18
+ shutil.move("HairFastGAN/input", "input")
19
+
20
+ shutil.rmtree("HairFastGAN")
21
+
22
+ from hair_swap import HairFast, get_parser
23
+
24
+ hair_fast = HairFast(get_parser().parse_args([]))
25
+
26
+ @spaces.GPU
27
+ def swap_hair(source, target_1, target_2):
28
+ result = hair_fast(face_img, shape_img, color_img)
29
+ final_image = hair_fast.swap(face_path, shape_path, color_path)
30
+ return T.functional.to_pil_image(final_image)
31
+
32
+ with gr.Blocks() as demo:
33
+ gr.Markdown("Start typing below and then click **Run** to see the output.")
34
+ with gr.Row():
35
+ source = gr.Image(label="Photo that you want to replace the hair", type="filepath")
36
+ target_1 = gr.Image(label="Reference hair you want to get", type="filepath")
37
+ target_2 = gr.Image(label="Reference color hair you want to get (optional)", type="filepath")
38
+ btn = gr.Button("Get the haircut")
39
+ output = gr.Image(label="Your result")
40
+
41
+ btn.click(fn=update, inputs=inp, outputs=out)
42
+
43
+ demo.launch()