tanveeshsingh committed
Commit 36aee66
1 Parent(s): 9ceec3e
Files changed (1)
  1. app.py +3 -26
app.py CHANGED
@@ -2,12 +2,10 @@ import gradio as gr
 from jinja2 import Template
 import openai
 
-# Replace this with your actual Hugging Face inference endpoint
 API_ENDPOINT = "https://txl0ptjvttfogwt9.us-east-1.aws.endpoints.huggingface.cloud/v1/"
 API_KEY = "NA"
 
 def classify_prompt(conv_prefix, response):
-    # Message template for the system prompt
     template_str = """
 Please review the scoring criteria:
 <BEGIN SCORING CRITERIA>
@@ -28,47 +26,26 @@ Now, please output the following as a JSON object:
 {
   "output": <classification score (0 or 1)> (int datatype),
 }"""
-
-    # Render the template with conversation context
     template = Template(template_str)
     rendered_message = template.render(conv_prefix=conv_prefix, response=response)
-
-    # Call OpenAI API (assuming you have a client set up)
     client = openai.OpenAI(
         base_url=API_ENDPOINT,
         api_key=API_KEY
     )
-
-    # Prepare message payload
     messages = [
         {"role": "system", "content": "\nYou will classify a prompt as either safe or unsafe based on predefined safety criteria. Follow these steps:\n\n1. **Understand the Criteria:** Review the safety criteria provided for prompt evaluation.\n2. **Evaluate the Prompt:** Assess the prompt against these criteria.\n3. **Classify the Prompt:** Determine whether the prompt is safe (1) or unsafe (0).\n4. **Provide Rationale:** Explain your classification decision with a brief rationale.\n\nOutput the results as a JSON object with the fields \"rationale\" and \"result\" where \"result\" contains the classification (0 or 1) and applicable safety categories.\n"},
         {"role": "user", "content": rendered_message}
     ]
-
-    # Make a request to the OpenAI API
     completion = client.chat.completions.create(
         model="tgi",
         messages=messages,
         max_tokens=400
     )
-
-    # Get the response content
     output = completion.choices[0].message.content
-
     return output
 
-def process_inputs(conv_prefix_text, response_content):
-    # Process the input conversation prefix as a list of dictionaries
-    conv_prefix = []
-    for line in conv_prefix_text.split("\n"):
-        if ": " in line:
-            role, content = line.split(": ", 1)
-            conv_prefix.append({"role": role.strip(), "content": content.strip()})
-
-    # Process the assistant's response as a dictionary
+def process_inputs(conv_prefix, response_content):
     response = {"role": "assistant", "content": response_content}
-
-    # Call classify_prompt with the structured data
     output = classify_prompt(conv_prefix, response)
     return output
 
@@ -76,12 +53,12 @@ def process_inputs(conv_prefix_text, response_content):
 demo = gr.Interface(
     fn=process_inputs,
     inputs=[
-        gr.Textbox(lines=8, placeholder="Enter conversation prefix (role: content), one per line", label="Conversation Prefix"),
+        gr.JSON(label="Conversation Prefix (Array of Objects)"),
         gr.Textbox(lines=2, placeholder="Enter the assistant's response", label="Assistant Response")
     ],
     outputs="text",
     title="Prompt Safety Classification",
-    description="Classify a conversation prompt's safety by providing a conversation prefix and an assistant's response."
+    description="Classify a conversation prompt's safety by providing a conversation prefix (array of objects) and an assistant's response."
)
 
 demo.launch()
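
For context (not part of the commit): after this change the Space's first input is a gr.JSON field, so the conversation prefix must already be the array of role/content objects that the removed parsing loop used to build from text. A minimal sketch of calling the updated process_inputs directly, with a made-up conversation for illustration:

conv_prefix = [
    {"role": "user", "content": "Can you help me with my account?"},
    {"role": "assistant", "content": "Sure, what do you need?"},
    {"role": "user", "content": "I forgot my password."},
]

# The prefix is now passed straight through to classify_prompt as a list of dicts;
# before this commit it had to be newline-separated "role: content" text.
result = process_inputs(conv_prefix, "You can reset it from the settings page.")
print(result)

Running this end to end still requires the inference endpoint above to be reachable with a valid API key.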