# ml27b / model.py
import os
from threading import Thread
from typing import Iterator
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login
# model_id = 'meta-llama/Llama-2-13b-chat-hf'
#model_id = 'meta-llama/Llama-2-7b-chat-hf'
#model_id = 'Trelis/Llama-2-7b-chat-hf-sharded-bf16'
model_id = 'jaymojnidar/llama2-finetuned-mydata'
config_model_id = 'jaymojnidar/llama2-finetuned-mydata/adapter_config.json'
model_type = 'PEFT'
if torch.cuda.is_available():
    # Log in to the Hugging Face Hub with the token from the environment
    # (needed to download the model weights).
    tok = os.environ['HF_TOKEN']
    login(new_session=True,
          write_permission=False,
          token=tok)
    if model_type == 'PEFT':
        # Load the sharded Llama-2 base model, then attach the fine-tuned
        # LoRA adapter from the PEFT repo on top of it.
        config = PeftConfig.from_pretrained("jaymojnidar/llama2-finetuned-mydata")
        model = AutoModelForCausalLM.from_pretrained("Trelis/Llama-2-7b-chat-hf-sharded-bf16")
        model = PeftModel.from_pretrained(model, "jaymojnidar/llama2-finetuned-mydata")
    else:
        config = AutoConfig.from_pretrained(model_id, use_auth_token=True)
        config.pretraining_tp = 1
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            torch_dtype=torch.float16,
            # load_in_4bit=True,
            device_map='auto',
            use_auth_token=True
        )
    print("Loaded the model!")
else:
    model = None
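
# Sketch, not part of the original flow: the LoRA weights could optionally be
# merged into the base model so inference skips the adapter wrappers.
# merge_and_unload() is a standard PeftModel method; the flag below is an
# assumption and is left off so the original behaviour is unchanged.
MERGE_ADAPTER = False
if MERGE_ADAPTER and model is not None and model_type == 'PEFT':
    model = model.merge_and_unload()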
tokenizer = AutoTokenizer.from_pretrained(model_id)
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)
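
# For reference, get_prompt() builds the standard Llama-2 chat layout. With an
# empty history it yields (schematically):
#
#   <s>[INST] <<SYS>>
#   {system_prompt}
#   <</SYS>>
#
#   {message} [/INST]
#
# and each prior (user, assistant) turn is appended as
# '{user_input} [/INST] {response} </s><s>[INST] ' before the new message.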
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]
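
# Usage sketch (the limit and this helper are assumptions, not part of the
# original script): a front end would typically guard run() by checking the
# prompt length against the model's context window.
MAX_INPUT_TOKEN_LENGTH = 4000  # hypothetical limit

def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    if get_input_token_length(message, chat_history, system_prompt) > MAX_INPUT_TOKEN_LENGTH:
        raise ValueError(f'Input exceeds {MAX_INPUT_TOKEN_LENGTH} tokens; clear the chat history and retry.')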
def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')

    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )
    # Run generation in a background thread so tokens can be yielded as the
    # streamer receives them.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        print(f"output text ->{text}<- end of text")
        outputs.append(text)
        yield ''.join(outputs)
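
# Minimal usage sketch (assumptions: a CUDA device is available so `model` was
# loaded above; the message and system prompt are placeholders, not from the
# original script).
if __name__ == '__main__' and model is not None:
    history: list[tuple[str, str]] = []
    last = ''
    for last in run('Hello, who are you?', history, 'You are a helpful assistant.'):
        pass  # run() yields the accumulated response after each new token
    print(last)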