const crypto = require('crypto');
const TextStream = require('./TextStream');
const { RecursiveCharacterTextSplitter } = require('langchain/text_splitter');
const { ChatOpenAI } = require('langchain/chat_models/openai');
const { loadSummarizationChain } = require('langchain/chains');
const { refinePrompt } = require('./prompts/refinePrompt');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models');

/**
 * Abstract base for chat clients: message bookkeeping, history loading,
 * context-window trimming/summarizing and persistence.
 *
 * Subclasses must implement `setOptions`, `getCompletion`, `sendCompletion`,
 * `getSaveOptions`, `buildMessages` and `getBuildMessagesOptions`.
 *
 * NOTE(review): several methods read `this.options`, `this.modelOptions`,
 * `this.maxContextTokens`, `this.shouldRefineContext` and call
 * `this.getTokenCount`, none of which are assigned in this class —
 * presumably the subclass provides them (e.g. in `setOptions`); confirm.
 */
class BaseClient {
  /**
   * @param {string} apiKey - Stored on the instance as `this.apiKey`.
   * @param {Object} [options]
   * @param {string} [options.sender] - Sender label for responses; falls back to 'AI'.
   */
  constructor(apiKey, options = {}) {
    this.apiKey = apiKey;
    this.sender = options.sender || 'AI';
    this.contextStrategy = null;
    this.currentDateString = new Date().toLocaleDateString('en-us', {
      year: 'numeric',
      month: 'long',
      day: 'numeric',
    });
  }

  /** Abstract. Called by `setMessageOptions` with the per-message options. */
  setOptions() {
    throw new Error('Method \'setOptions\' must be implemented.');
  }

  /** Abstract. */
  getCompletion() {
    throw new Error('Method \'getCompletion\' must be implemented.');
  }

  /** Abstract. `sendMessage` awaits it as `sendCompletion(payload, opts)` and stores the result as the response text. */
  async sendCompletion() {
    throw new Error('Method \'sendCompletion\' must be implemented.');
  }

  /** Abstract. Its result is spread into the conversation record by `saveMessageToDatabase`. */
  getSaveOptions() {
    throw new Error('Subclasses must implement getSaveOptions');
  }

  /** Abstract. `sendMessage` expects `{ prompt, tokenCountMap, promptTokens }` from it. */
  async buildMessages() {
    throw new Error('Subclasses must implement buildMessages');
  }

  /** Abstract. Its result is passed as the third argument to `buildMessages`. */
  getBuildMessagesOptions() {
    throw new Error('Subclasses must implement getBuildMessagesOptions');
  }

  /**
   * Feeds `text` through a `TextStream`, reporting progress to `onProgress`.
   *
   * @param {string} text
   * @param {Function} onProgress - Passed to `TextStream#processTextStream`.
   * @param {Object} [options] - Passed to the `TextStream` constructor.
   */
  async generateTextStream(text, onProgress, options = {}) {
    const stream = new TextStream(text, options);
    await stream.processTextStream(onProgress);
  }

  /**
   * Applies per-message options, generates any missing ids, and loads the
   * conversation history into `this.currentMessages`.
   *
   * @param {Object} [opts]
   * @returns {Promise<Object>} `opts` plus `user`, `conversationId`,
   *   `parentMessageId`, `userMessageId`, `responseMessageId`, `saveOptions`.
   */
  async setMessageOptions(opts = {}) {
    if (opts && typeof opts === 'object') {
      this.setOptions(opts);
    }

    // `opts` may be null (the default only covers undefined), so read it defensively.
    const user = opts?.user || null;
    const conversationId = opts?.conversationId || crypto.randomUUID();
    const parentMessageId = opts?.parentMessageId || '00000000-0000-0000-0000-000000000000';
    const userMessageId = opts?.overrideParentMessageId || crypto.randomUUID();
    const responseMessageId = crypto.randomUUID();
    const saveOptions = this.getSaveOptions();
    this.abortController = opts?.abortController || new AbortController();
    this.currentMessages = (await this.loadHistory(conversationId, parentMessageId)) ?? [];

    return {
      ...opts,
      user,
      conversationId,
      parentMessageId,
      userMessageId,
      responseMessageId,
      saveOptions,
    };
  }

  /**
   * Builds the message record for the user's input.
   *
   * @param {Object} params
   * @param {string} params.messageId
   * @param {string} params.parentMessageId
   * @param {string} params.conversationId
   * @param {string} params.text
   * @returns {Object} Message with `sender: 'User'` and `isCreatedByUser: true`.
   */
  createUserMessage({ messageId, parentMessageId, conversationId, text }) {
    const userMessage = {
      messageId,
      parentMessageId,
      conversationId,
      sender: 'User',
      text,
      isCreatedByUser: true,
    };
    return userMessage;
  }

  /**
   * Resolves message options, creates the user message and fires the
   * optional `getIds` / `onStart` callbacks found on `opts`.
   *
   * @param {string} message - The user's text.
   * @param {Object} [opts]
   * @returns {Promise<Object>} `opts` plus `user`, `conversationId`,
   *   `responseMessageId`, `saveOptions` and `userMessage`.
   */
  async handleStartMethods(message, opts) {
    const { user, conversationId, parentMessageId, userMessageId, responseMessageId, saveOptions } =
      await this.setMessageOptions(opts);

    const userMessage = this.createUserMessage({
      messageId: userMessageId,
      parentMessageId,
      conversationId,
      text: message,
    });

    if (typeof opts?.getIds === 'function') {
      opts.getIds({
        userMessage,
        conversationId,
        responseMessageId,
      });
    }

    if (typeof opts?.onStart === 'function') {
      opts.onStart(userMessage);
    }

    return {
      ...opts,
      user,
      conversationId,
      responseMessageId,
      saveOptions,
      userMessage,
    };
  }

  /**
   * Inserts `instructions` immediately before the last message.
   *
   * @param {Array} messages
   * @param {*} instructions - When falsy, `messages` is returned as-is (same array).
   * @returns {Array} A new array, or the original `messages` when there are no instructions.
   */
  addInstructions(messages, instructions) {
    const payload = [];
    if (!instructions) {
      return messages;
    }
    if (messages.length > 1) {
      payload.push(...messages.slice(0, -1));
    }

    payload.push(instructions);

    if (messages.length > 0) {
      payload.push(messages[messages.length - 1]);
    }

    return payload;
  }

  /**
   * Persists token counts (and refined-summary fields) for every message in
   * `this.currentMessages` except the last one.
   *
   * @param {Object} tokenCountMap - Maps messageId to token count; may carry a
   *   `refined` entry of the form `{ messageId, content, tokenCount }`.
   */
  async handleTokenCountMap(tokenCountMap) {
    // The last entry is the user message, which is handled by the caller.
    const lastIndex = this.currentMessages.length - 1;

    for (let i = 0; i < lastIndex; i++) {
      const message = this.currentMessages[i];
      const { messageId } = message;
      const update = {};
      const refined = tokenCountMap.refined;

      // Require `refined` to exist: comparing against `refined?.messageId` alone
      // would match a message without an id and then dereference undefined.
      if (refined && messageId === refined.messageId) {
        if (this.options.debug) {
          console.debug(`Adding refined props to ${messageId}.`);
        }

        update.refinedMessageText = refined.content;
        update.refinedTokenCount = refined.tokenCount;
      }

      if (message.tokenCount && !update.refinedTokenCount) {
        if (this.options.debug) {
          console.debug(`Skipping ${messageId}: already had a token count.`);
        }
        continue;
      }

      const tokenCount = tokenCountMap[messageId];
      if (tokenCount) {
        message.tokenCount = tokenCount;
        update.tokenCount = tokenCount;
        await this.updateMessageInDatabase({ messageId, ...update });
      }
    }
  }

  /**
   * Joins messages into one string, one `name-or-role:\ncontent` section per message.
   *
   * @param {Array<{name?: string, role?: string, content: string}>} messages
   * @returns {string}
   */
  concatenateMessages(messages) {
    return messages.reduce((acc, message) => {
      const nameOrRole = message.name ?? message.role;
      return acc + `${nameOrRole}:\n${message.content}\n\n`;
    }, '');
  }

  /**
   * Summarizes messages that did not fit the context window using a
   * LangChain 'refine' summarization chain.
   *
   * @param {Array} messagesToRefine - Messages with `role` and `content`.
   * @param {number} remainingContextTokens - Only used for debug output.
   * @returns {Promise<Object|null>} `{ role: 'assistant', content, tokenCount }`,
   *   or `null` when the chain call fails (the error is logged, not rethrown).
   */
  async refineMessages(messagesToRefine, remainingContextTokens) {
    const model = new ChatOpenAI({ temperature: 0 });
    const chain = loadSummarizationChain(model, {
      type: 'refine',
      verbose: this.options.debug,
      refinePrompt,
    });
    const splitter = new RecursiveCharacterTextSplitter({
      chunkSize: 1500,
      chunkOverlap: 100,
    });

    // User and non-user messages are summarized as two separate documents.
    const userMessages = this.concatenateMessages(
      messagesToRefine.filter((m) => m.role === 'user'),
    );
    const assistantMessages = this.concatenateMessages(
      messagesToRefine.filter((m) => m.role !== 'user'),
    );
    const userDocs = await splitter.createDocuments([userMessages], [], {
      chunkHeader: 'DOCUMENT NAME: User Message\n\n---\n\n',
      appendChunkOverlapHeader: true,
    });
    const assistantDocs = await splitter.createDocuments([assistantMessages], [], {
      chunkHeader: 'DOCUMENT NAME: Assistant Message\n\n---\n\n',
      appendChunkOverlapHeader: true,
    });
    const input_documents = userDocs.concat(assistantDocs);

    if (this.options.debug) {
      console.debug('Refining messages...');
    }

    try {
      const res = await chain.call({
        input_documents,
        signal: this.abortController.signal,
      });

      const refinedMessage = {
        role: 'assistant',
        content: res.output_text,
        tokenCount: this.getTokenCount(res.output_text),
      };

      if (this.options.debug) {
        console.debug('Refined messages', refinedMessage);
        console.debug(
          `remainingContextTokens: ${remainingContextTokens}, after refining: ${
            remainingContextTokens - refinedMessage.tokenCount
          }`,
        );
      }

      return refinedMessage;
    } catch (e) {
      console.error('Error refining messages');
      console.error(e);
      return null;
    }
  }

  /**
   * Selects the newest messages that fit within `this.maxContextTokens`.
   *
   * Iterates from newest to oldest, adding messages to the context until the
   * limit would be exceeded. With `this.shouldRefineContext` enabled, a message
   * that does not fit (and possibly the most recently kept one) is collected in
   * `messagesToRefine` instead; otherwise iteration stops at the first message
   * that does not fit. Both arrays are built with `push`/`pop` and reversed at
   * the end. After each kept message the loop yields to the event loop via
   * `setImmediate`.
   *
   * @param {Array} messages - Messages with a `tokenCount` property, ordered oldest to newest.
   * @returns {Promise<Object>} `{ context, remainingContextTokens, messagesToRefine, refineIndex }`:
   *   `context` is the kept messages; `remainingContextTokens` is the budget
   *   left after them; `messagesToRefine` is the messages set aside for
   *   summarizing; `refineIndex` is an index into `messages` recorded while
   *   setting messages aside, or -1 when none were.
   */
  async getMessagesWithinTokenLimit(messages) {
    let currentTokenCount = 0;
    const context = [];
    const messagesToRefine = [];
    let refineIndex = -1;
    let remainingContextTokens = this.maxContextTokens;

    for (let i = messages.length - 1; i >= 0; i--) {
      const message = messages[i];
      const newTokenCount = currentTokenCount + message.tokenCount;
      const exceededLimit = newTokenCount > this.maxContextTokens;
      const shouldRefine = exceededLimit && this.shouldRefineContext;
      const refineNextMessage = i !== 0 && i !== 1 && context.length > 0;

      if (shouldRefine) {
        messagesToRefine.push(message);

        if (refineIndex === -1) {
          refineIndex = i;
        }

        if (refineNextMessage) {
          // Also give up the most recently kept message and return its tokens to the budget.
          refineIndex = i + 1;
          const removedMessage = context.pop();
          messagesToRefine.push(removedMessage);
          currentTokenCount -= removedMessage.tokenCount;
          remainingContextTokens = this.maxContextTokens - currentTokenCount;
        }
        continue;
      } else if (exceededLimit) {
        break;
      }

      context.push(message);
      currentTokenCount = newTokenCount;
      remainingContextTokens = this.maxContextTokens - currentTokenCount;
      await new Promise((resolve) => setImmediate(resolve));
    }

    return {
      context: context.reverse(),
      remainingContextTokens,
      messagesToRefine: messagesToRefine.reverse(),
      refineIndex,
    };
  }

  /**
   * Fits the conversation into the context window, summarizing overflow when
   * there is any, and builds the messageId-to-token-count map.
   *
   * @param {Object} params
   * @param {Object} [params.instructions] - Inserted before the last message when present.
   * @param {Array} params.orderedMessages - Messages carrying `messageId`.
   * @param {Array} params.formattedMessages - Messages carrying `tokenCount`, sent as the payload.
   * @returns {Promise<Object>} `{ payload, tokenCountMap, promptTokens, messages }`.
   */
  async handleContextStrategy({ instructions, orderedMessages, formattedMessages }) {
    let payload = this.addInstructions(formattedMessages, instructions);
    let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);

    const { context, remainingContextTokens: contextBudget, messagesToRefine, refineIndex } =
      await this.getMessagesWithinTokenLimit(payload);
    let remainingContextTokens = contextBudget;

    payload = context;
    let refinedMessage;

    // Calculate the difference in length to determine how many messages were discarded if any
    const diff = orderedWithInstructions.length - payload.length;

    if (this.options.debug) {
      console.debug('<---------------------------------DIFF--------------------------------->');
      console.debug(
        `Difference between payload (${payload.length}) and orderedWithInstructions (${orderedWithInstructions.length}): ${diff}`,
      );
      console.debug(
        'remainingContextTokens, this.maxContextTokens (1/2)',
        remainingContextTokens,
        this.maxContextTokens,
      );
    }

    // If the difference is positive, slice the orderedWithInstructions array
    if (diff > 0) {
      orderedWithInstructions = orderedWithInstructions.slice(diff);
    }

    if (messagesToRefine.length > 0) {
      refinedMessage = await this.refineMessages(messagesToRefine, remainingContextTokens);
      // refineMessages resolves to null when summarizing fails; in that case
      // continue with the trimmed payload instead of dereferencing null.
      if (refinedMessage) {
        payload.unshift(refinedMessage);
        remainingContextTokens -= refinedMessage.tokenCount;
      }
    }

    if (this.options.debug) {
      console.debug(
        'remainingContextTokens, this.maxContextTokens (2/2)',
        remainingContextTokens,
        this.maxContextTokens,
      );
    }

    const tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
      if (!message.messageId) {
        return map;
      }

      if (refinedMessage && index === refineIndex) {
        map.refined = { ...refinedMessage, messageId: message.messageId };
      }

      map[message.messageId] = payload[index].tokenCount;
      return map;
    }, {});

    const promptTokens = this.maxContextTokens - remainingContextTokens;

    if (this.options.debug) {
      console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->');
      console.debug('Payload:', payload);
      console.debug('Token Count Map:', tokenCountMap);
      console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens);
    }

    return { payload, tokenCountMap, promptTokens, messages: orderedWithInstructions };
  }

  /**
   * Sends a user message: builds the prompt, saves the user message, requests
   * the completion and saves the response.
   *
   * @param {string} message - The user's text.
   * @param {Object} [opts] - Per-message options; see `setMessageOptions` and `handleStartMethods`.
   * @returns {Promise<Object>} The saved response message, without its `tokenCount` property.
   */
  async sendMessage(message, opts = {}) {
    const { user, conversationId, responseMessageId, saveOptions, userMessage } =
      await this.handleStartMethods(message, opts);

    this.user = user;
    // It's not necessary to push to currentMessages
    // depending on subclass implementation of handling messages
    this.currentMessages.push(userMessage);

    let {
      prompt: payload,
      tokenCountMap,
      promptTokens,
    } = await this.buildMessages(
      this.currentMessages,
      // When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
      // this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
      userMessage.messageId,
      this.getBuildMessagesOptions(opts),
    );

    if (this.options.debug) {
      console.debug('payload');
      console.debug(payload);
    }

    if (tokenCountMap) {
      // Diagnostics are gated on the debug flag like the rest of this class.
      if (this.options.debug) {
        console.dir(tokenCountMap, { depth: null });
      }
      if (tokenCountMap[userMessage.messageId]) {
        userMessage.tokenCount = tokenCountMap[userMessage.messageId];
        if (this.options.debug) {
          console.log('userMessage.tokenCount', userMessage.tokenCount);
          console.log('userMessage', userMessage);
        }
      }

      // Strips `tokenCount` in place (the payload objects themselves are mutated).
      payload = payload.map((message) => {
        const messageWithoutTokenCount = message;
        delete messageWithoutTokenCount.tokenCount;
        return messageWithoutTokenCount;
      });

      // Deliberately not awaited so the completion request is not delayed,
      // but a failure must not surface as an unhandled rejection.
      Promise.resolve(this.handleTokenCountMap(tokenCountMap)).catch((err) => {
        console.error('Error updating token counts');
        console.error(err);
      });
    }

    await this.saveMessageToDatabase(userMessage, saveOptions, user);
    const responseMessage = {
      messageId: responseMessageId,
      conversationId,
      parentMessageId: userMessage.messageId,
      isCreatedByUser: false,
      model: this.modelOptions.model,
      sender: this.sender,
      text: await this.sendCompletion(payload, opts),
      promptTokens,
    };

    if (tokenCountMap && this.getTokenCountForResponse) {
      responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
      responseMessage.completionTokens = responseMessage.tokenCount;
    }
    await this.saveMessageToDatabase(responseMessage, saveOptions, user);
    delete responseMessage.tokenCount;
    return responseMessage;
  }

  /**
   * @param {string} conversationId
   * @param {*} [user]
   * @returns {Promise<*>} Whatever `getConvo(user, conversationId)` resolves to.
   */
  async getConversation(conversationId, user = null) {
    return await getConvo(user, conversationId);
  }

  /**
   * Loads the stored messages of a conversation and returns the chain that
   * ends at `parentMessageId`, ordered root first.
   *
   * @param {string} conversationId
   * @param {string|null} [parentMessageId]
   * @returns {Promise<Array>} Ordered messages, mapped through
   *   `this.getMessageMapMethod()` when the subclass defines it.
   */
  async loadHistory(conversationId, parentMessageId = null) {
    if (this.options.debug) {
      console.debug('Loading history for conversation', conversationId, parentMessageId);
    }

    const messages = (await getMessages({ conversationId })) || [];

    if (messages.length === 0) {
      return [];
    }

    let mapMethod = null;
    if (this.getMessageMapMethod) {
      mapMethod = this.getMessageMapMethod();
    }

    return this.constructor.getMessagesForConversation(messages, parentMessageId, mapMethod);
  }

  /**
   * Saves a message, then saves the conversation record it belongs to.
   *
   * @param {Object} message - Saved with `unfinished: false` and `cancelled: false`.
   * @param {Object} endpointOptions - Spread into the conversation record.
   * @param {*} [user]
   */
  async saveMessageToDatabase(message, endpointOptions, user = null) {
    await saveMessage({ ...message, unfinished: false, cancelled: false });
    await saveConvo(user, {
      conversationId: message.conversationId,
      endpoint: this.options.endpoint,
      ...endpointOptions,
    });
  }

  /**
   * @param {Object} message - Passed straight to `updateMessage`.
   */
  async updateMessageInDatabase(message) {
    await updateMessage(message);
  }

  /**
   * Iterate through messages, building an array based on the parentMessageId.
   * Each message has an id (`messageId`, falling back to `id`) and a parentMessageId.
   * The parentMessageId is the id of the message that this message is a reply to.
   *
   * The walk stops at the first id that has no matching message, at a falsy
   * parentMessageId, or when an id repeats (so malformed data with a parent
   * cycle cannot loop forever).
   *
   * @param {Array} messages
   * @param {string} parentMessageId - Id of the message to start walking up from.
   * @param {Function|null} [mapMethod] - Applied to each ordered message when given.
   * @returns {*[]} An array containing the messages in the order they should be displayed, starting with the root message.
   */
  static getMessagesForConversation(messages, parentMessageId, mapMethod = null) {
    if (!messages || messages.length === 0) {
      return [];
    }

    // Index once by id; the first message with a given id wins, matching Array#find.
    const messagesById = new Map();
    for (const msg of messages) {
      const messageId = msg.messageId ?? msg.id;
      if (!messagesById.has(messageId)) {
        messagesById.set(messageId, msg);
      }
    }

    const chain = [];
    const visitedIds = new Set();
    let currentMessageId = parentMessageId;
    while (currentMessageId && !visitedIds.has(currentMessageId)) {
      const message = messagesById.get(currentMessageId);
      if (!message) {
        break;
      }
      visitedIds.add(currentMessageId);
      chain.push(message);
      currentMessageId = message.parentMessageId;
    }

    // Collected newest to oldest; flip so the root comes first.
    const orderedMessages = chain.reverse();

    if (mapMethod) {
      return orderedMessages.map(mapMethod);
    }

    return orderedMessages;
  }

  /**
   * Algorithm adapted from "6. Counting tokens for chat API calls" of
   * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
   *
   * An additional 2 tokens need to be added for metadata after all messages have been counted.
   *
   * @param {*} message - Only string-valued properties other than `tokenCount` are counted.
   * @returns {number} Token count for the message, including per-message overhead.
   */
  getTokenCountForMessage(message) {
    let tokensPerMessage;
    let nameAdjustment;
    if (this.modelOptions.model.startsWith('gpt-4')) {
      tokensPerMessage = 3;
      nameAdjustment = 1;
    } else {
      tokensPerMessage = 4;
      nameAdjustment = -1;
    }

    if (this.options.debug) {
      console.debug('getTokenCountForMessage', message);
    }

    // Map each property of the message to the number of tokens it contains
    const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
      if (key === 'tokenCount' || typeof value !== 'string') {
        return 0;
      }
      // Count the number of tokens in the property value
      const numTokens = this.getTokenCount(value);

      // Adjust by `nameAdjustment` tokens if the property key is 'name'
      const adjustment = key === 'name' ? nameAdjustment : 0;
      return numTokens + adjustment;
    });

    if (this.options.debug) {
      console.debug('propertyTokenCounts', propertyTokenCounts);
    }

    // Sum the number of tokens in all properties and add `tokensPerMessage` for metadata
    return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
  }
}

module.exports = BaseClient;