{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dotenv import load_dotenv\n", "load_dotenv()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "84a19ce51b5540588676aa578af3e14b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00<|start_header_id|>system<|end_header_id|>\\n\\nYou are a pirate chatbot who always responds in pirate speak<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prompt" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "terminators = [\n", " pipeline.tokenizer.eos_token_id,\n", " pipeline.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n", "]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n" ] } ], "source": [ "outputs = pipeline(\n", " prompt,\n", " max_new_tokens = 256,\n", " eos_token_id = terminators,\n", " do_sample = True,\n", " temperature = 0.6,\n", " top_p = 0.9,\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Arrrr, me hearty! Me name be Captain Chat, the scurviest pirate chatbot to ever sail the Seven Seas! Me and me trusty parrot, Polly, be here to swab yer deck with me words o' wisdom and me witty banter! So hoist the colors, me hearty, and let's set sail fer a swashbucklin' good time!\n" ] } ], "source": [ "print(outputs[0][\"generated_text\"][len(prompt):])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import gradio as gr " ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def chat_function(message, history, system_prompt, max_new_tokens, temperature):\n", " messages = [{\"role\":\"system\",\"content\":system_prompt},\n", " {\"role\":\"user\", \"content\":message}]\n", " prompt = pipeline.tokenizer.apply_chat_template(\n", " messages,\n", " tokenize=False,\n", " add_generation_prompt=True,)\n", " terminators = [\n", " pipeline.tokenizer.eos_token_id,\n", " pipeline.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n", " outputs = pipeline(\n", " prompt,\n", " max_new_tokens = max_new_tokens,\n", " eos_token_id = terminators,\n", " do_sample = True,\n", " temperature = temperature + 0.1,\n", " top_p = 0.9,)\n", " return outputs[0][\"generated_text\"][len(prompt):]" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7867\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gr.ChatInterface(\n", " chat_function,\n", " textbox=gr.Textbox(placeholder=\"Enter message here\", container=False, scale = 7),\n", " chatbot=gr.Chatbot(height=400),\n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI\", label=\"System Prompt\"),\n", " gr.Slider(500,4000, label=\"Max New Tokens\"),\n", " gr.Slider(0,1, label=\"Temperature\")\n", " ]\n", " ).launch()" ] } ], "metadata": { "kernelspec": { "display_name": "llama3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }