File size: 1,658 Bytes
88e1248 dc32044 c8b5fa1 7bf309f c8b5fa1 dc32044 88e1248 7bf309f dc32044 c8b5fa1 39ec5b7 c8b5fa1 88e1248 c8b5fa1 dc32044 c8b5fa1 dc32044 c8b5fa1 f85d258 dc32044 adf79f2 dc32044 7bf309f dc32044 6dec8ee dc32044 7bf309f dc32044 7bf309f dc32044 88e1248 dc32044 6dec8ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from typing import Dict, Any
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
import torch.cuda
# Configure the root logger at INFO level so the LOGGER.info(...) progress
# messages below are actually emitted.
logging.basicConfig(level=logging.INFO)
# Named logger for this module.
LOGGER = logging.getLogger(__name__)

# Device that tokenized inputs are moved to before generation:
# CUDA when a GPU is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class EndpointHandler:
    """Inference handler that serves a LoRA (PEFT) adapter on a causal LM.

    The base checkpoint named in the adapter's ``PeftConfig`` is loaded in
    8-bit with ``device_map='auto'``. The adapter weights found at ``path``
    are applied on top, and each request is answered with one ``generate``
    call.
    """

    def __init__(self, path=""):
        """Load the tokenizer, the 8-bit base model and the LoRA adapter.

        Args:
            path (str): Location of the PEFT adapter. It is passed to both
                ``PeftConfig.from_pretrained`` and ``PeftModel.from_pretrained``.
        """
        config = PeftConfig.from_pretrained(path)
        # The adapter config records which base checkpoint it was trained on;
        # both the base model and its tokenizer are fetched from there.
        # NOTE(review): newer transformers releases deprecate the bare
        # `load_in_8bit` kwarg in favour of
        # `quantization_config=BitsAndBytesConfig(load_in_8bit=True)`;
        # kept as-is for compatibility with the installed version — confirm.
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Load the LoRA adapter on top of the base model.
        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a completion for a single text prompt.

        Args:
            data (Dict): The request payload. ``"inputs"`` holds the text
                prompt. The optional ``"parameters"`` mapping is forwarded
                as keyword arguments to ``generate``.

        Returns:
            Dict: ``{"generated_text": text}``, where ``text`` is the first
            sequence returned by ``generate``, decoded with the tokenizer's
            default ``decode`` settings. For a decoder-only model this
            presumably includes the prompt tokens too (confirm against
            clients).

        Raises:
            ValueError: If ``"inputs"`` is missing or ``None``.
        """
        LOGGER.info("Received data: %s", data)
        # Read (rather than pop) so the caller's payload is left unmodified.
        prompt = data.get("inputs")
        parameters = data.get("parameters") or {}
        if prompt is None:
            raise ValueError("Missing prompt.")

        # Preprocess: keep the attention mask next to the token ids so
        # generate() does not have to infer it from the pad token.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(device)

        # Forward
        LOGGER.info("Start generation.")
        output = self.model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            **parameters,
        )

        # Postprocess: only the first returned sequence is decoded.
        prediction = self.tokenizer.decode(output[0])
        LOGGER.info("Generated text: %s", prediction)
        return {"generated_text": prediction}
|