Severian commited on
Commit
3e312b7
1 Parent(s): 58895ca

Update llm_handler.py

Browse files
Files changed (1) hide show
  1. llm_handler.py +30 -0
llm_handler.py CHANGED
@@ -1,9 +1,13 @@
1
  from openai import OpenAI
2
  from params import OPENAI_MODEL, OPENAI_API_KEY
 
3
 
4
  # Create an instance of the OpenAI class
5
  client = OpenAI(api_key=OPENAI_API_KEY)
6
 
 
 
 
7
  def send_to_chatgpt(msg_list):
8
  try:
9
  completion = client.chat.completions.create(
@@ -25,8 +29,34 @@ def send_to_chatgpt(msg_list):
25
  print(f"Error in send_to_chatgpt: {str(e)}")
26
  return f"Error: {str(e)}", None
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def send_to_llm(provider, msg_list):
29
  if provider == "llamanet":
 
 
30
  return send_to_chatgpt(msg_list)
31
  else:
32
  raise ValueError(f"Unknown provider: {provider}")
 
1
  from openai import OpenAI
2
  from params import OPENAI_MODEL, OPENAI_API_KEY
3
+ import llamanet
4
 
5
  # Create an instance of the OpenAI class
6
  client = OpenAI(api_key=OPENAI_API_KEY)
7
 
8
+ # Initialize LlamaNet client
9
+ llamanet_client = llamanet.Client()
10
+
11
  def send_to_chatgpt(msg_list):
12
  try:
13
  completion = client.chat.completions.create(
 
29
  print(f"Error in send_to_chatgpt: {str(e)}")
30
  return f"Error: {str(e)}", None
31
 
32
+ def send_to_llamanet(msg_list):
33
+ try:
34
+ # Convert msg_list to the format expected by LlamaNet
35
+ llamanet_messages = [{"role": msg["role"], "content": msg["content"]} for msg in msg_list]
36
+
37
+ # Send request to LlamaNet
38
+ response = llamanet_client.chat.completions.create(
39
+ model="llamanet",
40
+ messages=llamanet_messages,
41
+ stream=True
42
+ )
43
+
44
+ llamanet_response = ""
45
+ for chunk in response:
46
+ if chunk.choices[0].delta.content is not None:
47
+ llamanet_response += chunk.choices[0].delta.content
48
+
49
+ # LlamaNet doesn't provide usage information
50
+ llamanet_usage = None
51
+ return llamanet_response, llamanet_usage
52
+ except Exception as e:
53
+ print(f"Error in send_to_llamanet: {str(e)}")
54
+ return f"Error: {str(e)}", None
55
+
56
  def send_to_llm(provider, msg_list):
57
  if provider == "llamanet":
58
+ return send_to_llamanet(msg_list)
59
+ elif provider == "openai":
60
  return send_to_chatgpt(msg_list)
61
  else:
62
  raise ValueError(f"Unknown provider: {provider}")