# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description:
"""
import platform
from loguru import logger
from src.base_model import BaseLLMModel
from src.presets import LOCAL_MODELS
class ChatGLMClient(BaseLLMModel):
def __init__(self, model_name, user_name=""):
super().__init__(model_name=model_name, user=user_name)
import torch
from transformers import AutoModel, AutoTokenizer
system_name = platform.system()
logger.info(f"Loading model from {model_name}")
if model_name in LOCAL_MODELS:
model_path = LOCAL_MODELS[model_name]
else:
model_path = model_name
self.CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # int4-quantized checkpoints get special handling below (the MPS branch is skipped for them)
        quantized = "int4" in model_name
        # NOTE: device_map='auto' generally requires the `accelerate` package to be installed
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map='auto', torch_dtype='auto')
if torch.cuda.is_available():
logger.info("CUDA is available, using CUDA")
model = model.half().cuda()
        # MPS acceleration still has some issues, so it is only tried for
        # unquantized local models on macOS
        elif system_name == "Darwin" and model_path is not None and not quantized:
logger.info("Running on macOS, using MPS")
# running on macOS and model already downloaded
model = model.half().to("mps")
else:
logger.info("GPU is not available, using CPU")
model = model.float()
model = model.eval()
logger.info(f"Model loaded from {model_path}")
self.CHATGLM_MODEL = model
    def _get_glm3_style_input(self):
        # ChatGLM3 takes the role/content message list as history;
        # the latest user message is popped off and passed separately as the query
        history = self.history
        query = history.pop()["content"]
return history, query
    def _get_glm2_style_input(self):
        # ChatGLM2 takes history as [user, assistant] string pairs, e.g.
        # ["hi", "hello!", "how are you?"] -> history=[["hi", "hello!"]], query="how are you?"
        history = [x["content"] for x in self.history]
        query = history.pop()
        logger.debug(f"{history}")
        assert len(history) % 2 == 0, f"History should have an even number of turns. Current history: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
return history, query
    def _get_glm_style_input(self):
        # ChatGLM2 models use string-pair history; everything else uses the ChatGLM3-style message list
        if "glm2" in self.model_name:
return self._get_glm2_style_input()
else:
return self._get_glm3_style_input()
def get_answer_at_once(self):
history, query = self._get_glm_style_input()
logger.debug(f"{history}")
        response, _ = self.CHATGLM_MODEL.chat(
            self.CHATGLM_TOKENIZER, query, history=history)
        # the character count of the reply serves as a rough token count
        return response, len(response)
def get_answer_stream_iter(self):
history, query = self._get_glm_style_input()
        logger.debug(f"{history}")
        # stream_chat yields progressively longer partial responses
        for response, history in self.CHATGLM_MODEL.stream_chat(
self.CHATGLM_TOKENIZER,
query,
history,
max_length=self.token_upper_limit,
top_p=self.top_p,
temperature=self.temperature,
):
yield response
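

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: "chatglm3-6b" is either a key in LOCAL_MODELS
    # or a valid Hugging Face model id, and BaseLLMModel exposes a mutable `history`
    # list of {"role": ..., "content": ...} dicts as used by the methods above.
    client = ChatGLMClient("chatglm3-6b", user_name="demo")
    client.history.append({"role": "user", "content": "Write a short poem about spring."})
    answer, _ = client.get_answer_at_once()
    print(answer)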