import os
import json
import logging
from typing import Any, List, Mapping, Optional, Dict

import requests
from streamlit.logger import get_logger
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

from models.custom_parsers import CustomStringOutputParser
from app_config import ENDPOINT_NAMES

logger = get_logger(__name__)


def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
    """Return ``text`` cut at the earliest occurrence of any non-empty stop string.

    With no stop strings (``None`` or empty list) the text is returned unchanged.
    """
    if not stop:
        return text
    cut = len(text)
    for token in stop:
        if not token:
            continue  # an empty stop string would otherwise truncate everything
        idx = text.find(token)
        if idx != -1 and idx < cut:
            cut = idx
    return text[:cut]


class DatabricksCustomBizLLM(LLM):
    """LangChain ``LLM`` backed by a Databricks model-serving endpoint.

    Each call POSTs the prompt together with ``issue``, ``language``,
    ``temperature`` and ``max_tokens`` (each wrapped in a one-element list) to
    ``db_url`` and returns ``predictions[0]["generated_text"]`` from the JSON
    response.
    """

    issue: str
    language: str
    temperature: float = 0.8
    max_tokens: int = 128
    db_url: str
    # Explicit request headers. When left as None they are built at request
    # time from the DATABRICKS_TOKEN environment variable, so a token that is
    # loaded after this module is imported is still picked up.
    headers: Optional[Mapping[str, str]] = None
    # Seconds to wait for the endpoint; None keeps requests' default of no
    # timeout (the call may block indefinitely).
    request_timeout: Optional[float] = None

    @property
    def _llm_type(self) -> str:
        return "custom_databricks_biz"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that distinguish this LLM's outputs (LangChain uses them
        e.g. when building LLM cache keys).

        Headers are deliberately left out because they carry the access token.
        """
        return {
            "issue": self.issue,
            "language": self.language,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "db_url": self.db_url,
        }

    def _request_headers(self) -> Mapping[str, str]:
        """Return the explicit ``headers`` if set, else bearer-token JSON headers."""
        if self.headers is not None:
            return self.headers
        return {
            "Authorization": f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
            "Content-Type": "application/json",
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the endpoint and return the generated text.

        Args:
            prompt: fully rendered prompt text.
            stop: optional stop strings; the returned text is cut at the first
                occurrence of any of them.
            run_manager: LangChain callback manager (unused here).

        Raises:
            Exception: if the endpoint responds with a status other than 200.
        """
        payload = {
            "inputs": {
                "prompt": [prompt],
                "issue": [self.issue],
                "language": [self.language],
                "temperature": [self.temperature],
                "max_tokens": [self.max_tokens],
            }
        }
        response = requests.request(
            method="POST",
            headers=self._request_headers(),
            url=self.db_url,
            data=json.dumps(payload, allow_nan=True),
            timeout=self.request_timeout,
        )
        if response.status_code != 200:
            raise Exception(f"Request failed with status {response.status_code}, {response.text}")
        generated = response.json()["predictions"][0]["generated_text"]
        return _truncate_at_stop(generated, stop)


# NOTE(review): turns are separated by single spaces in this template; confirm
# whether line breaks between the history and the "helper:"/"texter:" turns
# were intended.
_DATABRICKS_TEMPLATE_ = """{history} helper: {input} texter:"""


def get_databricks_biz_chain(source, issue, language, memory, temperature=0.8):
    """Build a ConversationChain that talks to the Databricks "biz" endpoint.

    Args:
        source: key looked up in ``ENDPOINT_NAMES`` to choose the endpoint name;
            unknown keys fall back to ``"conversation_simulator"``.
        issue: forwarded to the endpoint with every request.
        language: forwarded to the endpoint with every request.
        memory: LangChain memory object attached to the chain; the prompt
            template expects a ``history`` variable.
        temperature: sampling temperature forwarded to the endpoint.

    Returns:
        ``(chain, "helper:")`` — the chain and the ``"helper:"`` prefix that also
        appears in the prompt template.

    Raises:
        KeyError: if the ``DATABRICKS_URL`` environment variable is not set.
    """
    prompt = PromptTemplate(
        input_variables=["history", "input"],
        template=_DATABRICKS_TEMPLATE_,
    )
    endpoint_name = ENDPOINT_NAMES.get(source, "conversation_simulator")
    llm = DatabricksCustomBizLLM(
        issue=issue,
        language=language,
        temperature=temperature,
        max_tokens=256,
        # DATABRICKS_URL is a format string with an {endpoint_name} placeholder.
        db_url=os.environ["DATABRICKS_URL"].format(endpoint_name=endpoint_name),
    )
    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomStringOutputParser(),
    )
    # Use the module logger defined above rather than the root logger.
    logger.debug("loaded Databricks Biz model")
    return llm_chain, "helper:"