abhishek's picture
abhishek HF staff
adapter merger
c445497
raw
history blame contribute delete
No virus
1.84 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
def merge(base_model, trained_adapter, token):
    """Merge a PEFT adapter into its base model and push the result to the Hub.

    The base model is loaded in float16, the adapter is applied on top of it
    and folded into the base weights, and the merged model together with the
    base model's tokenizer is pushed to the ``trained_adapter`` repository.

    Args:
        base_model: Hub repo id of the base causal language model.
        trained_adapter: Hub repo id of the trained adapter. It is also the
            repository the merged model and tokenizer are pushed to.
        token: Hugging Face token passed to every download and upload call.

    Returns:
        A Gradio update that sets the output component's value to a success
        message.

    Raises:
        gr.Error: If ``base_model`` or ``trained_adapter`` is blank.
    """
    # Values come straight from text boxes; pasted ids and tokens frequently
    # carry stray whitespace or a trailing newline.
    base_model = (base_model or "").strip()
    trained_adapter = (trained_adapter or "").strip()
    if isinstance(token, str):
        token = token.strip()

    # Fail fast with a message the UI can display, before any large download.
    if not base_model or not trained_adapter:
        raise gr.Error(
            "Both the base model and the trained adapter must be provided."
        )

    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        token=token,
    )
    model = PeftModel.from_pretrained(base, trained_adapter, token=token)

    try:
        tokenizer = AutoTokenizer.from_pretrained(base_model, token=token)
    except RecursionError:
        # Retry with an explicit unk_token when the default load recurses
        # (presumably a quirk of some checkpoints' tokenizer config).
        tokenizer = AutoTokenizer.from_pretrained(
            base_model, unk_token="<unk>", token=token
        )

    # Fold the adapter weights into the base model and drop the PEFT wrapper.
    model = model.merge_and_unload()

    print("Saving target model")
    model.push_to_hub(trained_adapter, token=token)
    tokenizer.push_to_hub(trained_adapter, token=token)

    # gr.update() is the version-independent way to update an output; the
    # per-component ``Markdown.update`` classmethod is deprecated/removed in
    # newer Gradio releases.
    return gr.update(
        value="Model successfully merged and pushed! Please shutdown/pause this space"
    )
# Single-page UI: three text inputs (write token, base model id, adapter id),
# one button that runs `merge`, and a Markdown area for the status message.
with gr.Blocks() as demo:
    gr.Markdown("## AutoTrain Merge Adapter")
    gr.Markdown("Please duplicate this space and attach a GPU in order to use it.")

    # Rendered as a password field so the token is not shown on screen.
    token = gr.Textbox(
        label="Hugging Face Write Token",
        value="",
        lines=1,
        max_lines=1,
        interactive=True,
        type="password",
    )
    base_model = gr.Textbox(
        label="Base Model (e.g. meta-llama/Llama-2-7b-chat-hf)",
        value="",
        lines=1,
        max_lines=1,
        interactive=True,
    )
    trained_adapter = gr.Textbox(
        label="Trained Adapter Model (e.g. username/autotrain-my-llama)",
        value="",
        lines=1,
        max_lines=1,
        interactive=True,
    )

    submit = gr.Button(value="Merge & Push")
    # Markdown is display-only, so no `interactive` flag is passed: it has no
    # effect on this component and newer Gradio releases reject the keyword.
    op = gr.Markdown()

    # Argument order here must match merge(base_model, trained_adapter, token).
    submit.click(merge, inputs=[base_model, trained_adapter, token], outputs=[op])


if __name__ == "__main__":
    demo.launch()