File size: 3,147 Bytes
431af92
760bfde
d5d52b4
431af92
760bfde
54b4787
760bfde
 
 
 
 
 
 
 
d5d52b4
760bfde
 
 
d5d52b4
8f4de93
760bfde
 
 
 
 
b9fc382
760bfde
d5d52b4
760bfde
 
 
 
 
b9fc382
760bfde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5d52b4
760bfde
d5d52b4
 
 
760bfde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5d52b4
760bfde
 
 
 
 
 
 
 
de84b65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import subprocess
import sys

import gradio as gr
from huggingface_hub import HfApi, hf_hub_download

# Fetch the diffusers repository so its checkpoint-conversion script
# (./diffs/scripts/...) is available locally. `git clone` refuses to clone
# into an existing non-empty directory, so skip the clone when "diffs" is
# already present (e.g. after an app restart in the same working directory).
if not os.path.isdir("diffs"):
    subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers.git", "diffs"])

def error_str(error, title="Error"):
    """Build a Markdown error snippet for the output panel.

    The first line is a level-4 heading containing *title*. The second line
    holds ``str(error)`` prefixed by twelve spaces. Note that Markdown
    renderers may treat such an indented line as a code block.
    """
    heading_line = f"#### {title}"
    message_line = f"            {error}"
    return "\n".join((heading_line, message_line))

def url_to_model_id(model_id_str):
    """Return a Hub repo id for a model name or ``https://huggingface.co/`` URL.

    Input that is not a Hub URL is returned with surrounding whitespace
    removed. For a Hub URL, the repo id is built from the first (at most two)
    path segments after the host. Trailing slashes, query strings, fragments
    and extra path parts such as ``/tree/main`` therefore do not end up in
    the id.

    Args:
        model_id_str: A repo id like ``"user/model"`` or a Hub URL.

    Returns:
        The repo id string, e.g. ``"user/model"``.
    """
    prefix = "https://huggingface.co/"
    model_id_str = model_id_str.strip()
    if not model_id_str.startswith(prefix):
        return model_id_str

    # Drop any query string / fragment, then ignore empty segments so that
    # trailing or doubled slashes do not produce an empty name.
    path = model_id_str[len(prefix):].split("?", 1)[0].split("#", 1)[0]
    segments = [segment for segment in path.split("/") if segment]
    return "/".join(segments[:2])
    
def get_ckpt_names(model_id = "nitrosocke/mo-di-diffusion"):
    """List the ``.ckpt`` files of a Hub model repo for the checkpoint picker.

    Args:
        model_id: Repo id (``"user/model"``) or ``https://huggingface.co/`` URL.

    Returns:
        A 3-tuple for the Gradio outputs ``(error_output, radio_ckpts,
        col_convert)``:

        - Success: ``(None, update(choices=<ckpt files>, visible=True),
          update(visible=True))``.
        - Failure: an error Markdown string, followed by ``None, None``.
    """
    # Reject empty and whitespace-only input before hitting the Hub.
    if not model_id or not model_id.strip():
        return error_str("Please enter a model name.", title="Invalid input"), None, None

    try:
        repo_files = HfApi().list_repo_files(url_to_model_id(model_id.strip()))
    except Exception as e:  # UI boundary: show Hub/network errors to the user
        return error_str(e), None, None

    ckpt_files = [f for f in repo_files if f.endswith(".ckpt")]
    if not ckpt_files:
        return error_str("No checkpoint files found in the model repo."), None, None

    return None, gr.update(choices=ckpt_files, visible=True), gr.update(visible=True)

def convert(model_id, ckpt_name, token=None):
    """Download a checkpoint from the Hub and run the Diffusers conversion script.

    The converted model is written to a directory named after the repo id
    (``--dump_path model_id``), relative to the current working directory.

    Args:
        model_id: Repo id or ``https://huggingface.co/`` URL of the source model.
        ckpt_name: File name of the checkpoint inside that repo.
        token: Optional Hugging Face access token for the download. Empty
            values are treated as "no token".

    Returns:
        Markdown text for the output panel. This is an error message if no
        checkpoint was chosen, or if the download or conversion failed.
        Otherwise it lists the files in the current working directory.
    """
    # NOTE: never hard-code access tokens as defaults; they come from the UI.
    if not ckpt_name:
        return error_str("Please choose a checkpoint to convert.", title="Invalid input")

    model_id = url_to_model_id(model_id)

    # 1. Download the checkpoint file (authenticated when a token is given).
    try:
        ckpt_path = hf_hub_download(repo_id=model_id, filename=ckpt_name, token=token or None)
    except Exception as e:  # UI boundary: report instead of crashing the handler
        return error_str(e, title="Download failed")

    # 2. Run the conversion script with this app's own interpreter so it sees
    #    the same installed packages; capture output to report failures.
    result = subprocess.run(
        [
            sys.executable,
            "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
            "--checkpoint_path",
            ckpt_path,
            "--dump_path",
            model_id,
        ],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        details = result.stderr.strip() or f"exit code {result.returncode}"
        return error_str(details, title="Conversion failed")

    # List files in the current directory and return them.
    files = [f for f in os.listdir(".") if os.path.isfile(f)]
    return f"""files in current directory:
    {files}"""

    
with gr.Blocks() as demo:

    with gr.Row():

        # Step 1: credentials and model selection.
        with gr.Column(scale=11):
            with gr.Group():
                gr.Markdown("## 1. Load model info")
                # Masked input so the access token is not shown in plain text.
                input_token = gr.Textbox(
                    max_lines=1,
                    label="Hugging Face token",
                    placeholder="hf_...",
                    type="password",
                )
                gr.Markdown("Get your token [here](https://huggingface.co/settings/tokens).")
                input_model = gr.Textbox(
                    max_lines=1,
                    label="Model name or URL",
                    placeholder="username/model_name",
                )

            btn_get_ckpts = gr.Button("Load")

        # Step 2: hidden until get_ckpt_names finds checkpoints and reveals it.
        with gr.Column(scale=10, visible=False) as col_convert:
            gr.Markdown("## 2. Convert to Diffusers🧨")
            radio_ckpts = gr.Radio(label="Choose a checkpoint to convert", visible=False)
            btn_convert = gr.Button("Convert")

    error_output = gr.Markdown(label="Output")
    # Outputs order must match get_ckpt_names' 3-tuple return.
    btn_get_ckpts.click(
        fn=get_ckpt_names,
        inputs=[input_model],
        outputs=[error_output, radio_ckpts, col_convert],
        scroll_to_output=True
    )

    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_ckpts, input_token],
        outputs=error_output,
        scroll_to_output=True
    )

demo.launch()