|
from typing import Dict, List, Any |
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from peft import PeftConfig, PeftModel |
|
|
|
|
|
class EndpointHandler():
    """Text-generation handler backed by a PEFT adapter on a causal LM.

    On construction the adapter config found at ``path`` is read, the base
    model it names is loaded in 8-bit with ``device_map='auto'``, and the
    adapter weights are applied on top of it. Calling the instance runs
    ``generate`` for a single request payload and returns the decoded text.
    """

    def __init__(self, path=""):
        """Load the tokenizer, the base model, and the PEFT adapter.

        Args:
            path: Location of the PEFT adapter (passed unchanged to
                ``PeftConfig.from_pretrained`` and
                ``PeftModel.from_pretrained``).
        """
        config = PeftConfig.from_pretrained(path)

        # The adapter config records which base checkpoint it was trained on;
        # both the model and the tokenizer are loaded from that checkpoint.
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            load_in_8bit=True,
            device_map='auto',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.base_model_name_or_path
        )
        self.model = PeftModel.from_pretrained(base_model, path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for the prompt contained in ``data``.

        The payload is only read, never modified.

        Args:
            data: Request payload with the following keys:

                prompt (str): Text to continue. Required.
                temperature (float, optional): Defaults to 0.5. Only
                    influences the output when sampling is enabled.
                eos_token_id (int, optional): Defaults to
                    ``tokenizer.eos_token_id``.
                early_stopping (bool, optional): Defaults to ``True``.
                repetition_penalty (float, optional): Defaults to 0.3.
                max_new_tokens (int, optional): Defaults to 100.
                do_sample (bool, optional): Forwarded to ``generate`` only
                    when present; otherwise the model's own generation
                    configuration decides.

        Returns:
            One ``{"generated_text": str}`` dict per generated sequence.

        Raises:
            ValueError: If ``data`` has no ``prompt`` (or it is ``None``).
        """
        prompt = data.get("prompt")
        # Validate first so a bad request fails before any tokenization work.
        if prompt is None:
            raise ValueError("No prompt provided.")

        # NOTE(review): a repetition_penalty below 1.0 rewards repetition in
        # transformers rather than discouraging it; the 0.3 default is kept
        # for backward compatibility -- confirm it is intentional.
        generation_kwargs = {
            "temperature": data.get("temperature", 0.5),
            "eos_token_id": data.get(
                "eos_token_id", self.tokenizer.eos_token_id
            ),
            "early_stopping": data.get("early_stopping", True),
            "repetition_penalty": data.get("repetition_penalty", 0.3),
            "max_new_tokens": data.get("max_new_tokens", 100),
        }
        # Pass do_sample through only on request so the default decoding
        # strategy stays whatever the model's generation config specifies.
        if "do_sample" in data:
            generation_kwargs["do_sample"] = data["do_sample"]

        # The model is placed by device_map='auto', so the encoded inputs
        # must be moved onto the model's device before generation.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(
            self.model.device
        )
        prediction = self.model.generate(**inputs, **generation_kwargs)

        # generate() yields token ids; decode them so the result matches the
        # declared return type and can be serialized in a response.
        texts = self.tokenizer.batch_decode(
            prediction, skip_special_tokens=True
        )
        return [{"generated_text": text} for text in texts]
|
|