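"""Gradio chat app that forwards each user message to a remote text-to-image
API (via gradio_client) and replies with the generated image's URL."""
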
import gradio as gr
from huggingface_hub import InferenceClient
import os
from gradio_client import Client  # client for the image generation API
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
# The Hugging Face API token is read from the environment.
hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus", token=os.getenv("HF_TOKEN"))

# Set up the image generation API client
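# NOTE: hardcoded, plain-HTTP endpoint; if it is unreachable, every request in
# respond() falls through to the except branch below and returns the error text.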
client = Client("http://211.233.58.202:7960/")

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    system_prefix = "System: Reply in the same language as the input (English, Korean, Chinese, Japanese, etc.)."
    full_system_message = f"{system_prefix} {system_message}"

    messages = [{"role": "system", "content": full_system_message}]
    # Append the previous conversation turns.
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # Append the current user message.
    messages.append({"role": "user", "content": message})

    # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์š”์ฒญ์„ API์— ์ „์†กํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.
    try:
        result = client.predict(
            prompt=message,
            seed=123,
            randomize_seed=False,
            width=1024,
            height=576,
            guidance_scale=5,
            num_inference_steps=28,
            api_name="/infer_t2i"
        )
        # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ URL์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
        if 'url' in result:
            return result['url']
        else:
            return "์ด๋ฏธ์ง€ ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."

    try:
        result = client.predict(...)
        if 'url' in result:
            return result['url']
        else:
            logging.error("์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: %s", result.get('error', '์•Œ ์ˆ˜ ์—†๋Š” ์˜ค๋ฅ˜'))
            return "์ด๋ฏธ์ง€ ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."
    except Exception as e:
        logging.error("API ์š”์ฒญ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: %s", str(e))
        return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
        
theme = "Nymbo/Nymbo_Theme"
css = """
footer {
    visibility: hidden;
}
"""

# Configure the Gradio chat interface.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are an AI assistant.", label="System Prompt"),
        gr.Slider(minimum=1, maximum=2000, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
        ),
    ],
    theme=theme,
    css=css
)
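
# Gradio passes the additional_inputs to respond() positionally after
# (message, history): system_message, max_tokens, temperature, top_p.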

if __name__ == "__main__":
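    # launch() serves the app locally; share=True would also create a public link.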
    demo.launch()