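# Nebula / merge.py
# Uploaded by Kqte ("Upload 12 files", commit 64407ae, 1.83 kB).
#
# Merges a LoRA adapter into a Llama 3 8B base model and writes greedy
# action-sequence generations for the validation and test CSV splits.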
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
from peft import LoraConfig, PeftModel
# Base checkpoint used for BOTH the model and the tokenizer; defined once so
# the two loads cannot silently point at different checkpoints.
BASE_MODEL_PATH = "/path/to/meta-llama3-8b"
device_map = "auto"

# Load the base model in fp16; device placement is delegated to the library
# via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
# Attach the LoRA adapter, then merge its weights into the base model so the
# rest of the script works with a plain (non-PEFT) model object.
model = PeftModel.from_pretrained(model, "/path/to/llama3-8b-adapter", device_map=device_map)
model = model.merge_and_unload()

# NOTE(review): trust_remote_code=True allows the checkpoint directory to run
# arbitrary Python code; a stock Llama 3 tokenizer should not need it — confirm
# and drop if so.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH, trust_remote_code=True)
# Pad with the token id immediately after EOS, presumably so padding is not
# confused with EOS. NOTE(review): for Llama 3 this id is presumably one of the
# reserved special tokens — confirm it matches what training used.
tokenizer.pad_token_id = tokenizer.eos_token_id + 1
model.config.pad_token_id = tokenizer.pad_token_id

# Greedy decoding (do_sample=False); max_length is the generation length cap
# passed through to the model's generate call.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=4096, do_sample=False)
print("Padding side:", tokenizer.padding_side)

# Each CSV is loaded as its own single-split dataset.
val_dataset = load_dataset("csv", data_files={'val': '/path/to/actseq-val-new.csv'})["val"]
test_dataset = load_dataset("csv", data_files={'test': '/path/to/actseq-test-new.csv'})["test"]
def formatting_prompts_func(example):
    """Build one generation prompt per entry of the ``dial_with_actions`` column.

    Args:
        example: A mapping of column name to column values (e.g. a batch dict
            or a ``datasets.Dataset``) that has a ``'dial_with_actions'`` column.

    Returns:
        A list of prompt strings, one per entry of
        ``example['dial_with_actions']``, in the same order.
    """
    # Look the column up exactly once. On a ``datasets.Dataset``,
    # ``example[column]`` can materialize the entire column, so re-indexing it
    # on every iteration made the original loop quadratic in the row count.
    dialogues = example['dial_with_actions']
    # NOTE(review): if the tokenizer also prepends <|begin_of_text|> by default,
    # the model sees it twice. This literal presumably mirrors the training-time
    # prompt format, so it is intentionally left unchanged — confirm before editing.
    return [
        f"<|begin_of_text|>Predict the action sequence (AS) for the Minecraft excerpt:\n {dialogue}\n ### AS:"
        for dialogue in dialogues
    ]
def _write_generations(texts, path):
    """Run ``pipe`` on each prompt in *texts* and write each generation to *path*.

    Every prompt is echoed to stdout as a progress trace. Each generation is
    written to the file followed by a newline, in the order of *texts*.
    """
    # ``with`` closes the file (flushing buffered output) even if generation
    # raises partway through; an explicit UTF-8 encoding avoids depending on
    # the platform's default encoding for the generated text.
    with open(path, "w", encoding="utf-8") as out:
        for text in texts:
            print(text)
            # NOTE(review): by default the text-generation pipeline presumably
            # returns the prompt plus the continuation, and the prompts contain
            # newlines, so records in the output file are not one-per-line —
            # confirm the downstream parser expects this layout.
            print(pipe(text)[0]["generated_text"], file=out)


val_texts = formatting_prompts_func(val_dataset)
test_texts = formatting_prompts_func(test_dataset)
print("Val Length:", len(val_texts))
print("Test Length:", len(test_texts))

_write_generations(val_texts, "/path/to/val-output-file")
_write_generations(test_texts, "/path/to/test-output-file")