import gradio as gr
from huggingface_hub import InferenceClient
import os
from gradio_client import Client  # client for the remote image-generation API

# ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ Hugging Face API ํ† ํฐ์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus", token=os.getenv("HF_TOKEN"))

# ์ด๋ฏธ์ง€ ์ƒ์„ฑ API ํด๋ผ์ด์–ธํŠธ ์„ค์ •
client = Client("http://211.233.58.202:7960/")
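# The return shape of /infer_t2i is defined by the remote app, not by this
# script; if it is unclear, gradio_client can print the remote endpoint
# signatures for inspection:
#     client.view_api()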

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    system_prefix = "System: Reply in the same language as the input (English, Korean, Chinese, Japanese, etc.)."

    messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}]
    # Append the prior conversation turns.
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # Append the current user message.
    messages.append({"role": "user", "content": message})

    # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์š”์ฒญ์„ API์— ์ „์†กํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.
    try:
        result = client.predict(
            prompt=message,
            seed=123,
            randomize_seed=False,
            width=1024,
            height=576,
            guidance_scale=5,
            num_inference_steps=28,
            api_name="/infer_t2i"
        )
        # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ URL์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
        if 'url' in result:
            return result['url']
        else:
            return "์ด๋ฏธ์ง€ ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."
    except Exception as e:
        return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"

theme = "Nymbo/Nymbo_Theme"
css = """
footer {
    visibility: hidden;
}
"""

# Configure the Gradio chat interface.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are an AI assistant.", label="System Prompt"),
        gr.Slider(minimum=1, maximum=2000, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
        ),
    ],
    theme=theme,
    css=css
)

if __name__ == "__main__":
    demo.launch()