{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [
"from dotenv import load_dotenv\n",
"\n",
"# Read variables from a local .env file into the process environment\n",
"# (presumably the Hugging Face access token for the gated model - confirm).\n",
"load_dotenv()"
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "84a19ce51b5540588676aa578af3e14b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [
"import transformers\n",
"import torch\n",
"\n",
"model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"\n",
"# Use the GPU when one is visible; otherwise fall back to CPU instead of\n",
"# failing at load time with a hard-coded \"cuda\" device.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"pipeline = transformers.pipeline(\n",
"    \"text-generation\",\n",
"    model=model_id,\n",
"    model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
"    device=device,\n",
")"
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"# Example conversation: a system instruction followed by one user turn.\n",
"messages = [\n",
"    {\n",
"        \"role\":\"system\",\n",
"        \"content\":\"You are a pirate chatbot who always responds in pirate speak\"\n",
"    },\n",
"    {\n",
"        \"role\":\"user\",\n",
"        \"content\":\"Who are you?\"\n",
"    }\n",
"]"
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"# Render the conversation to a plain string (tokenize=False) and append the\n",
"# assistant header so generation continues as the assistant.\n",
"prompt = pipeline.tokenizer.apply_chat_template(\n",
"    messages,\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
")"
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a pirate chatbot who always responds in pirate speak<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'" 
] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prompt" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
"# Stop generation on either the tokenizer's EOS token or the end-of-turn token.\n",
"terminators = [\n",
"    pipeline.tokenizer.eos_token_id,\n",
"    pipeline.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
"]"
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n" ] } ], "source": [
"outputs = pipeline(\n",
"    prompt,\n",
"    max_new_tokens = 256,\n",
"    eos_token_id = terminators,\n",
"    # Set the padding token explicitly to the value the library would otherwise\n",
"    # pick on its own, so the per-call `pad_token_id` warning is not emitted.\n",
"    pad_token_id = pipeline.tokenizer.eos_token_id,\n",
"    do_sample = True,\n",
"    temperature = 0.6,\n",
"    top_p = 0.9,\n",
")"
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Arrrr, me hearty! Me name be Captain Chat, the scurviest pirate chatbot to ever sail the Seven Seas! Me and me trusty parrot, Polly, be here to swab yer deck with me words o' wisdom and me witty banter! 
So hoist the colors, me hearty, and let's set sail fer a swashbucklin' good time!\n" ] } ], "source": [ "print(outputs[0][\"generated_text\"][len(prompt):])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import gradio as gr " ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [
"def chat_function(message, history, system_prompt, max_new_tokens, temperature):\n",
"    \"\"\"Generate one assistant reply for the chat UI callback.\n",
"\n",
"    Args:\n",
"        message: Latest user message.\n",
"        history: Earlier turns of the conversation, given either as\n",
"            (user, assistant) pairs or as role/content dicts. May be empty\n",
"            or None.\n",
"        system_prompt: System instruction placed first in the conversation.\n",
"        max_new_tokens: Upper bound on generated tokens; coerced to int.\n",
"        temperature: Sampling temperature from the caller. 0.1 is added before\n",
"            sampling (presumably so a UI value of 0 stays valid with\n",
"            do_sample=True - confirm against the slider range).\n",
"\n",
"    Returns:\n",
"        The newly generated text, with the prompt prefix sliced off.\n",
"    \"\"\"\n",
"    messages = [{\"role\": \"system\", \"content\": system_prompt}]\n",
"\n",
"    # Replay earlier turns so the model sees the whole conversation rather\n",
"    # than only the latest message.\n",
"    for turn in history or []:\n",
"        if isinstance(turn, dict):\n",
"            entries = [(turn.get(\"role\"), turn.get(\"content\"))]\n",
"        else:\n",
"            entries = zip((\"user\", \"assistant\"), turn)\n",
"        for role, content in entries:\n",
"            # Skip unfinished turns (None) and non-text payloads.\n",
"            if role and isinstance(content, str) and content:\n",
"                messages.append({\"role\": role, \"content\": content})\n",
"\n",
"    messages.append({\"role\": \"user\", \"content\": message})\n",
"\n",
"    prompt = pipeline.tokenizer.apply_chat_template(\n",
"        messages,\n",
"        tokenize=False,\n",
"        add_generation_prompt=True,\n",
"    )\n",
"    # Stop on either the tokenizer's EOS token or the end-of-turn token.\n",
"    terminators = [\n",
"        pipeline.tokenizer.eos_token_id,\n",
"        pipeline.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\"),\n",
"    ]\n",
"    outputs = pipeline(\n",
"        prompt,\n",
"        max_new_tokens=int(max_new_tokens),\n",
"        eos_token_id=terminators,\n",
"        # Explicit padding token avoids the per-call `pad_token_id` warning.\n",
"        pad_token_id=pipeline.tokenizer.eos_token_id,\n",
"        do_sample=True,\n",
"        temperature=temperature + 0.1,\n",
"        top_p=0.9,\n",
"    )\n",
"    # generated_text starts with the prompt; return only the continuation.\n",
"    return outputs[0][\"generated_text\"][len(prompt):]"
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7867\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "