File size: 3,180 Bytes
b021ace
 
 
3a1adec
b021ace
efc9ce3
b021ace
3cf0016
b021ace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c21edc5
 
 
 
 
 
 
 
 
 
d6331eb
29b6ebb
b021ace
 
 
3cf0016
6997827
29b6ebb
 
 
3cf0016
29b6ebb
3cf0016
b021ace
c8890b6
3cf0016
c8890b6
b021ace
 
 
3cf0016
7a82ab9
b021ace
7a82ab9
 
 
 
 
 
 
 
 
0241ca9
033cd3f
 
 
293a1c7
7a82ab9
 
 
 
 
 
 
 
 
 
b021ace
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import subprocess
import shutil
import os
import gradio as gr
import torchvision.transforms as T
import sys
import spaces
from PIL import Image

# Startup bootstrap: download the HairFastGAN code (GitHub) and its model
# assets (Hugging Face), then flatten everything into the directory the script
# was started from. Each step relies on the working directory left by the
# previous one, so the order of these statements matters. check=True makes any
# failing git command raise subprocess.CalledProcessError and abort startup.

# Clone the GitHub repository into ./HairFastGAN and step into it.
subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
os.chdir("HairFastGAN")

# Clone the Hugging Face repository of the same name. The working directory is
# the GitHub checkout, so this one lands in HairFastGAN/HairFastGAN.
subprocess.run(["git", "clone", "https://huggingface.co/AIRI-Institute/HairFastGAN"], check=True)

# Enter the nested checkout only long enough to fetch its Git LFS objects,
# then return to the GitHub checkout.
os.chdir("HairFastGAN")
subprocess.run(["git", "lfs", "pull"], check=True)
os.chdir("..")

# Lift the two directories that are needed out of the nested checkout...
shutil.move("HairFastGAN/pretrained_models", "pretrained_models")
shutil.move("HairFastGAN/input", "input")

# ...and delete what remains of the nested checkout.
shutil.rmtree("HairFastGAN")

# Move every entry of the GitHub checkout (including the two directories moved
# in above, and hidden entries, since os.listdir() returns dotfiles) up one
# level, printing each name as it is moved.
# NOTE(review): shutil.move nests into or fails on an already-existing
# destination, so this presumably assumes the parent directory has no entries
# with the same names -- confirm for restarts / pre-populated directories.
items = os.listdir()

for item in items:
    print(item)
    shutil.move(item, os.path.join('..', item))

# Back to the starting directory; the checkout directory is now empty.
os.chdir("..")

shutil.rmtree("HairFastGAN")

# These imports intentionally come after the clone-and-move steps above rather
# than at the top of the file.
# NOTE(review): presumably `hair_swap` and `utils` are provided by the
# repository contents moved into the working directory -- confirm.
from hair_swap import HairFast, get_parser
from utils.shape_predictor import align_face

# One module-level model instance, created at import time and used by
# swap_hair(). parse_args([]) is given an explicit empty argument list, so the
# process's own command line is not consulted (presumably get_parser() returns
# an argparse parser and its defaults are used -- confirm).
hair_fast = HairFast(get_parser().parse_args([]))

def resize(image_path):
    """Load an uploaded image and normalise it unless it is already 1024x1024.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        The opened image unchanged when its size is exactly (1024, 1024);
        otherwise the first element of ``align_face(img, return_tensors=False)``
        (presumably a face-aligned image -- see ``utils.shape_predictor``).
    """
    picture = Image.open(image_path)

    # Images that already have the expected size are passed through untouched.
    if picture.size == (1024, 1024):
        return picture

    aligned = align_face(picture, return_tensors=False)
    return aligned[0]

@spaces.GPU
def swap_hair(source, target_1, target_2, progress=gr.Progress(track_tqdm=True)):
    """Run the hair swap and return the result as a PIL image.

    Args:
        source: Image whose hair is to be replaced.
        target_1: Reference image for the hair.
        target_2: Optional reference image for the hair colour; when falsy
            (e.g. not provided), ``target_1`` is used in its place.
        progress: Gradio progress tracker created with ``track_tqdm=True``.
            It is not referenced in the body (presumably its presence in the
            signature is what lets Gradio surface tqdm progress -- confirm).

    Returns:
        The output of ``hair_fast.swap`` converted with
        ``torchvision.transforms.functional.to_pil_image``.
    """
    # Fall back to the hair reference when no separate colour reference is given.
    color_reference = target_2 or target_1
    swapped = hair_fast.swap(source, target_1, color_reference)
    return T.functional.to_pil_image(swapped)
    
# Gradio front end: three image inputs (source photo, hair reference, optional
# colour reference), a button that runs swap_hair, and one output image.
with gr.Blocks() as demo:
    # Page header: title and project links.
    gr.Markdown("## HairFastGan")
    gr.Markdown(
        "Gradio demo for [AIRI Institute](https://github.com/AIRI-Institute)'s HairFastGan: [Paper](https://huggingface.co/papers/2404.01094) | [GitHub](https://github.com/AIRI-Institute/HairFastGAN) | [Weights 🤗](https://huggingface.co/AIRI-Institute/HairFastGAN) | [Colab](https://colab.research.google.com/#fileId=https%3A//huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb)"
    )

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            with gr.Row():
                source = gr.Image(
                    label="Photo that you want to replace the hair",
                    type="filepath",
                )
                target_1 = gr.Image(
                    label="Reference hair you want to get",
                    type="filepath",
                )
            # The colour reference is optional, so it sits in a collapsed accordion.
            with gr.Accordion("Reference hair color", open=False):
                target_2 = gr.Image(
                    label="Reference color hair you want to get (optional)",
                    type="filepath",
                )
            btn = gr.Button("Get the haircut")

        # Right column: the generated result.
        with gr.Column():
            output = gr.Image(label="Your result")

    # One example row that fills all three inputs.
    gr.Examples(
        examples=[["michael_cera-min.png", "leo_square-min.png", "pink_hair_celeb-min.png"]],
        inputs=[source, target_1, target_2],
        outputs=output,
    )

    # On upload, each input is passed through resize() and the result is
    # written back into the same component.
    for image_input in (source, target_1, target_2):
        image_input.upload(fn=resize, inputs=image_input, outputs=image_input)

    btn.click(fn=swap_hair, inputs=[source, target_1, target_2], outputs=[output])

    # Citation footer.
    gr.Markdown(
        '''To cite the paper by the authors
```
    @article{nikolaev2024hairfastgan,
      title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
      author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
      journal={arXiv preprint arXiv:2404.01094},
      year={2024}
    }
```
    '''
    )

# Start serving the app.
demo.launch()