mishig (HF staff) committed
Commit: eec5743 (merge; parents: af961f3, 812e078)

Fix system prompts (#43)
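In summary, this commit moves the system prompt into the conversation state: `Conversation` gains a `systemMessage` field (see types.ts below), the standalone `systemMessage` variable and its reset-on-unsupported-model reactive block are dropped from InferencePlayground.svelte, and both response handlers read the system message from the `Conversation` they already receive instead of taking it as an extra parameter. The system message is prepended to the outgoing `messages` array only when `isSystemPromptSupported(model)` holds and it has content.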
src/lib/components/InferencePlayground/InferencePlayground.svelte CHANGED

@@ -20,16 +20,17 @@

   export let models: ModelEntryWithTokenizer[];

-  const startMessage: ChatCompletionInputMessage = { role: "user", content: "" };
+  const startMessageUser: ChatCompletionInputMessage = { role: "user", content: "" };
+  const startMessageSystem: ChatCompletionInputMessage = { role: "system", content: "" };

   let conversation: Conversation = {
     model: models[0],
     config: defaultGenerationConfig,
-    messages: [{ ...startMessage }],
+    messages: [{ ...startMessageUser }],
+    systemMessage: startMessageSystem,
     streaming: true,
   };

-  let systemMessage: ChatCompletionInputMessage = { role: "system", content: "" };
   let hfToken: string | undefined = import.meta.env.VITE_HF_TOKEN;
   let viewCode = false;
   let showTokenModal = false;
@@ -39,11 +40,6 @@
   let waitForNonStreaming = true;

   $: systemPromptSupported = isSystemPromptSupported(conversation.model);
-  $: {
-    if (!systemPromptSupported) {
-      systemMessage = { role: "system", content: "" };
-    }
-  }

   function addMessage() {
     conversation.messages = [
@@ -61,8 +57,8 @@
   }

   function reset() {
-    systemMessage.content = "";
-    conversation.messages = [{ ...startMessage }];
+    conversation.systemMessage.content = "";
+    conversation.messages = [{ ...startMessageUser }];
   }

   function abort() {
@@ -98,12 +94,11 @@
           conversation.messages = [...conversation.messages];
         }
       },
-      abortController,
-      systemMessage
+      abortController
     );
   } else {
     waitForNonStreaming = true;
-    const newMessage = await handleNonStreamingResponse(hf, conversation, systemMessage);
+    const newMessage = await handleNonStreamingResponse(hf, conversation);
     // check if the user did not abort the request
     if (waitForNonStreaming) {
       conversation.messages = [...conversation.messages, newMessage];
@@ -162,7 +157,8 @@
         placeholder={systemPromptSupported
           ? "Enter a custom prompt"
           : "System prompt is not supported with the chosen model."}
-        bind:value={systemMessage.content}
+        value={systemPromptSupported ? conversation.systemMessage.content : ""}
+        on:input={e => (conversation.systemMessage.content = e.currentTarget.value)}
         class="absolute inset-x-0 bottom-0 h-full resize-none bg-transparent px-3 pt-10 text-sm outline-none"
       ></textarea>
     </div>
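Note the textarea change in the last hunk: the two-way `bind:value={systemMessage.content}` becomes a one-way `value` plus an `on:input` handler. Svelte's `bind:` requires an assignable expression, so it cannot express "show an empty string when the model lacks system-prompt support"; the split binding displays a blank value in that case while still writing user input through to `conversation.systemMessage.content`.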
src/lib/components/InferencePlayground/InferencePlaygroundCodeSnippets.svelte CHANGED
@@ -8,6 +8,7 @@
   import http from "highlight.js/lib/languages/http";

   import IconCopyCode from "../Icons/IconCopyCode.svelte";
+  import { isSystemPromptSupported } from "./inferencePlaygroundUtils";

   hljs.registerLanguage("javascript", javascript);
   hljs.registerLanguage("python", python);
@@ -46,10 +47,17 @@

   function getMessages() {
     const placeholder = [{ role: "user", content: "Tell me a story" }];
-    let messages = conversation.messages;
+
+    let messages = [...conversation.messages];
     if (messages.length === 1 && messages[0].role === "user" && !messages[0].content) {
       messages = placeholder;
     }
+
+    const { model, systemMessage } = conversation;
+    if (isSystemPromptSupported(model) && systemMessage.content?.length) {
+      messages.unshift(systemMessage);
+    }
+
     return messages;
   }
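For reference, the updated getMessages() logic reads as the following standalone sketch; the Message type and the buildMessages name are illustrative simplifications, not part of the codebase:

// Illustrative only: mirrors getMessages() above with simplified types.
type Message = { role: string; content?: string };

function buildMessages(
  conversationMessages: Message[],
  systemMessage: Message,
  systemPromptSupported: boolean
): Message[] {
  const placeholder: Message[] = [{ role: "user", content: "Tell me a story" }];

  // Copy first so unshift() below cannot mutate the conversation state.
  let messages = [...conversationMessages];
  if (messages.length === 1 && messages[0].role === "user" && !messages[0].content) {
    messages = placeholder;
  }

  // Prepend the system message only when the model supports it and it is non-empty.
  if (systemPromptSupported && systemMessage.content?.length) {
    messages.unshift(systemMessage);
  }
  return messages;
}

The copy via `[...conversation.messages]` also fixes a latent bug in the old version, which would have unshifted into the live conversation array.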
 
src/lib/components/InferencePlayground/inferencePlaygroundUtils.ts CHANGED
@@ -11,17 +11,17 @@ export async function handleStreamingResponse(
   hf: HfInference,
   conversation: Conversation,
   onChunk: (content: string) => void,
-  abortController: AbortController,
-  systemMessage?: ChatCompletionInputMessage
+  abortController: AbortController
 ): Promise<void> {
+  const { model, systemMessage } = conversation;
   const messages = [
-    ...(isSystemPromptSupported(conversation.model) && systemMessage?.content?.length ? [systemMessage] : []),
+    ...(isSystemPromptSupported(model) && systemMessage.content?.length ? [systemMessage] : []),
     ...conversation.messages,
   ];
   let out = "";
   for await (const chunk of hf.chatCompletionStream(
     {
-      model: conversation.model.id,
+      model: model.id,
       messages,
       temperature: conversation.config.temperature,
       max_tokens: conversation.config.maxTokens,
@@ -37,16 +37,16 @@ export async function handleStreamingResponse(

 export async function handleNonStreamingResponse(
   hf: HfInference,
-  conversation: Conversation,
-  systemMessage?: ChatCompletionInputMessage
+  conversation: Conversation
 ): Promise<ChatCompletionInputMessage> {
+  const { model, systemMessage } = conversation;
   const messages = [
-    ...(isSystemPromptSupported(conversation.model) && systemMessage?.content?.length ? [systemMessage] : []),
+    ...(isSystemPromptSupported(model) && systemMessage.content?.length ? [systemMessage] : []),
     ...conversation.messages,
   ];

   const response = await hf.chatCompletion({
-    model: conversation.model,
+    model: model.id,
     messages,
     temperature: conversation.config.temperature,
     max_tokens: conversation.config.maxTokens,
src/lib/components/InferencePlayground/types.ts CHANGED
@@ -6,6 +6,7 @@ export type Conversation = {
   model: ModelEntryWithTokenizer;
   config: GenerationConfig;
   messages: ChatCompletionInputMessage[];
+  systemMessage: ChatCompletionInputMessage;
   streaming: boolean;
 };
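With the new required field, a `Conversation` is initialized as in the component above (a sketch mirroring that initialization, with the start messages inlined):

const conversation: Conversation = {
  model: models[0],
  config: defaultGenerationConfig,
  messages: [{ role: "user", content: "" }],
  systemMessage: { role: "system", content: "" },
  streaming: true,
};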