File size: 4,210 Bytes
e8bac0f
 
 
 
 
44568a0
e8bac0f
 
 
 
 
 
cde7a7b
 
e8bac0f
 
cde7a7b
 
 
 
 
 
bcac619
83746e4
 
 
6ab04f4
83746e4
 
 
cde7a7b
83746e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8bcde6
83746e4
 
 
6bda5d8
bcac619
83746e4
7826a10
a8032bb
 
 
cde7a7b
 
 
 
 
 
 
a8032bb
7826a10
 
a8032bb
 
 
 
 
 
 
 
 
 
 
 
 
cde7a7b
a8032bb
 
 
 
 
 
44568a0
 
 
 
 
 
 
a8032bb
a6549b1
 
 
 
 
 
 
 
 
 
a8032bb
a6549b1
 
 
 
 
 
 
 
83746e4
 
cde7a7b
44568a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import aiohttp
import os
import json
from collections import deque
import asyncio

# Hugging Face API token, read from the environment; required at import time.
TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")

if not TOKEN:
    raise ValueError("API token is not set. Please set the HUGGINGFACE_API_TOKEN environment variable.")

# Confirm the token was picked up without echoing any of its characters:
# even a prefix/suffix leaks part of the secret (and all of a short token).
print(f"API Token: loaded ({len(TOKEN)} characters)")

# Rolling conversation store: at most the 10 most recent
# (user_message, assistant_reply) tuples. It is module-level, so it is shared
# by every call made in this process.
memory = deque(maxlen=10)

async def test_api(timeout=30):
    """Send one GET to the model's inference endpoint and print the reply body.

    This is a start-up smoke test for the API token. Network failures and
    timeouts are printed instead of raised. Without that, ``asyncio.run``
    would propagate them and a merely diagnostic probe would stop the app
    from launching.

    Args:
        timeout: Total seconds allowed for the request. ``None`` disables the
            limit.
    """
    headers = {"Authorization": f"Bearer {TOKEN}"}
    url = "https://api-inference.huggingface.co/models/mistralai/Mistral-Nemo-Instruct-2407"
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url, headers=headers) as response:
                print(f"Test API response: {await response.text()}")
    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
        print(f"Test API request failed: {e!r}")

def _extract_stream_content(raw_line):
    """Return the text fragment carried by one line of a chat-completions stream.

    The ``/v1/chat/completions`` route is OpenAI-compatible. With
    ``"stream": True`` it sends server-sent-event lines of the form
    ``data: {json}``, whose text lives under ``choices[0]["delta"]["content"]``,
    and finishes with ``data: [DONE]``. A plain JSON line carrying
    ``choices[0]["message"]["content"]`` is accepted too, so a non-streamed
    reply still works.

    Args:
        raw_line: One line from ``aiohttp``'s ``response.content`` iterator
            (``bytes``); ``str`` is accepted as well.

    Returns:
        The content fragment. Returns ``""`` for blank or keep-alive lines, the
        ``[DONE]`` sentinel, malformed JSON, or chunks without text content.
    """
    if isinstance(raw_line, bytes):
        raw_line = raw_line.decode("utf-8", errors="replace")
    line = raw_line.strip()
    if line.startswith("data:"):
        line = line[len("data:"):].strip()
    if not line or line == "[DONE]":
        return ""
    try:
        payload = json.loads(line)
    except json.JSONDecodeError:
        # SSE comments (": keep-alive"), "event:" lines, or other non-JSON noise.
        return ""
    if not isinstance(payload, dict):
        return ""
    choices = payload.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    choice = choices[0] if isinstance(choices[0], dict) else {}
    part = choice.get("delta") or choice.get("message") or {}
    content = part.get("content") if isinstance(part, dict) else None
    # The final streamed chunk often has content == None; treat it as empty.
    return content if isinstance(content, str) else ""


async def respond(
    message,
    history: list[tuple[str, str]],
    system_message="AI Assistant Role",
    max_tokens=512,
    temperature=0.7,
    top_p=0.95,
):
    """Stream a chat completion for *message*, yielding the growing reply.

    The prompt consists of a system message, then the most recent turns from
    the module-level ``memory`` deque (not *history*), then *message*. Each
    yielded value is the whole reply accumulated so far. On an HTTP error, an
    exception, or an empty reply, a user-facing error string is yielded
    instead. When the call finishes, is closed early, or fails, the turn
    ``(message, reply_text)`` is appended to ``memory``.

    Args:
        message: The user's new message.
        history: Chat history supplied by Gradio; currently unused.
        system_message: Role text appended after the fixed language instruction.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated assistant reply, or an error message.
    """
    # NOTE(review): ``memory`` is module-global, so all sessions share one
    # conversation. Gradio's per-session *history* may be the intended
    # source - confirm.
    system_prefix = "System: Respond in the same language as the input (English, Korean, Chinese, Japanese, etc.)."
    # A space keeps the fixed instruction and the role text as separate sentences.
    full_system_message = f"{system_prefix} {system_message}"

    # Snapshot prior turns rather than reserving ``memory[-1]`` for this
    # message. ``memory`` is shared, so a concurrent request could append in
    # between and ``memory[-1]`` would then be its entry, not ours. Leave one
    # slot free so the prompt covers the same window as a deque that already
    # held the current message.
    past_turns = list(memory)
    if memory.maxlen is not None:
        keep = max(memory.maxlen - 1, 0)
        past_turns = past_turns[max(len(past_turns) - keep, 0):]

    messages = [{"role": "system", "content": full_system_message}]
    for user_text, bot_text in [*past_turns, (message, None)]:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if bot_text:
            messages.append({"role": "assistant", "content": bot_text})

    headers = {
        "Authorization": f"Bearer {TOKEN}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "mistralai/Mistral-Nemo-Instruct-2407",
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "messages": messages,
        "stream": True
    }

    # Bound before the try so the ``finally`` below can always record it.
    response_text = ""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post("https://api-inference.huggingface.co/v1/chat/completions", headers=headers, json=payload) as response:
                print(f"Response status: {response.status}")
                if response.status != 200:
                    error_text = await response.text()
                    print(f"Error response: {error_text}")
                    yield "An API response error occurred. Please try again."
                    return

                # ``response.content`` yields the body one line at a time.
                async for raw_line in response.content:
                    fragment = _extract_stream_content(raw_line)
                    if fragment:
                        response_text += fragment
                        yield response_text

                if not response_text:
                    yield "I apologize, but I couldn't generate a response. Please try again."
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        yield f"An error occurred: {str(e)}"
    finally:
        # Runs on success, error, early return, or consumer-side close.
        memory.append((message, response_text))

async def chat(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio streaming callback that relays each partial reply from ``respond``.

    If an exception escapes ``respond`` while streaming, it is printed and
    turned into one final error string. The UI then receives a message rather
    than an exception.
    """
    try:
        reply_stream = respond(message, history, system_message, max_tokens, temperature, top_p)
        async for partial_reply in reply_stream:
            yield partial_reply
    except Exception as err:
        print(f"Chat function error: {str(err)}")
        yield f"An error occurred in the chat function: {str(err)}"

# Gradio theme name passed to the chat interface.
theme = "Nymbo/Nymbo_Theme"

# Page CSS: hides the footer element.
css = """
footer {
    visibility: hidden;
}
"""

# Chat UI wired to ``chat``. The extra inputs below are passed to ``chat``
# after (message, history), in this order: system message, max tokens,
# temperature, top-p.
demo = gr.ChatInterface(
    fn=chat,
    additional_inputs=[
        gr.Textbox(label="System message", value="AI Assistant Role"),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=512, step=1),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.7, step=0.1),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, value=0.95, step=0.05),
    ],
    theme=theme,
    css=css,
)

if __name__ == "__main__":
    # One-off request to the inference endpoint before the UI starts serving.
    asyncio.run(test_api())
    queued_app = demo.queue()
    queued_app.launch(show_error=True, max_threads=20)