File size: 1,538 Bytes
88e1248 dc32044 c8b5fa1 dc32044 88e1248 dc32044 c8b5fa1 39ec5b7 c8b5fa1 88e1248 c8b5fa1 dc32044 c8b5fa1 dc32044 c8b5fa1 f85d258 dc32044 adf79f2 dc32044 adf79f2 dc32044 6dec8ee dc32044 adf79f2 dc32044 adf79f2 dc32044 88e1248 dc32044 6dec8ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from typing import Dict, Any
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
# Module-level logger named after this module (standard logging convention).
LOGGER = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the ROOT logger as an import-time side
# effect — normally avoided in libraries; presumably intentional here since
# this module is the endpoint entry point. Confirm against the deployment.
logging.basicConfig(level=logging.INFO)
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a LoRA-adapted causal LM.

    Loads the PEFT adapter config stored at ``path``, pulls the referenced
    base model in 8-bit quantized form plus its tokenizer, and wraps the base
    model with the LoRA adapter weights.
    """

    def __init__(self, path: str = ""):
        """Load base model, tokenizer, and LoRA adapter from ``path``.

        Args:
            path: Local directory (or hub repo id) containing the PEFT
                adapter config and weights.
        """
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,  # bitsandbytes 8-bit quantization to cut memory
            device_map='auto',  # let accelerate place layers across devices
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Load the Lora model: attach adapter weights on top of the base model.
        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one text-generation inference request.

        Args:
            data: Payload dict with ``inputs`` (the text prompt, required)
                and optional ``parameters`` (kwargs forwarded verbatim to
                ``model.generate``).

        Returns:
            ``{"generated_text": <decoded first output sequence>}``.

        Raises:
            ValueError: If ``inputs`` is missing from the payload.
        """
        # Lazy %-args: the message is only formatted when INFO is enabled.
        LOGGER.info("Received data: %s", data)
        # Get inputs. Use .get (not .pop) so the caller's payload dict is
        # not mutated as a side effect of handling the request.
        prompt = data.get("inputs")
        parameters = data.get("parameters")
        if prompt is None:
            raise ValueError("Missing prompt.")
        # Preprocess
        inputs = self.tokenizer(prompt, return_tensors="pt")
        # NOTE(review): tokenized tensors stay on CPU while device_map='auto'
        # may place the model on GPU — confirm whether inputs need
        # .to(self.model.device) in this deployment.
        # Forward
        LOGGER.info("Start generation.")
        # `parameters or {}` covers both None and an empty dict identically.
        output = self.model.generate(**inputs, **(parameters or {}))
        # Postprocess: decode only the first (and only) generated sequence.
        prediction = self.tokenizer.decode(output[0])
        LOGGER.info("Generated text: %s", prediction)
        return {"generated_text": prediction}
|