smgc commited on
Commit
7164dd1
1 Parent(s): 10d3a03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -7,9 +7,7 @@ import socketio
7
  import requests
8
  import logging
9
  from threading import Event
10
-
11
- # 如果使用 GPT 模型的 tokenization,可以引入 tiktoken
12
- # import tiktoken # 如果需要使用 GPT 的 token 化库
13
 
14
  app = Flask(__name__)
15
  logging.basicConfig(level=logging.INFO)
@@ -77,14 +75,14 @@ def normalize_content(content):
77
  # 如果是其他类型,返回空字符串
78
  return ""
79
 
80
- def calculate_input_tokens(text):
81
  """
82
- 计算输入文本的 token 数量。
83
- 这里我们简单地通过空格分词来模拟 token 计数。
84
- 如果使用 GPT 模型,可以使用 tiktoken 库进行 tokenization。
85
  """
86
- # 使用简单的空格分词计数
87
- tokens = text.split()
 
88
  return len(tokens)
89
 
90
  @app.route('/')
@@ -124,7 +122,7 @@ def messages():
124
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
125
 
126
  # 动态计算输入的 token 数量
127
- input_tokens = calculate_input_tokens(previous_messages)
128
 
129
  msg_id = str(uuid.uuid4())
130
  response_event = Event()
@@ -148,7 +146,7 @@ def messages():
148
  "model": model, # 动态模型
149
  "stop_reason": None,
150
  "stop_sequence": None,
151
- "usage": {"input_tokens": input_tokens, "output_tokens": 1}, # 动态 input_tokens
152
  },
153
  })
154
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
@@ -226,11 +224,15 @@ def messages():
226
  if sio.connected:
227
  sio.disconnect()
228
 
 
 
 
 
229
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
230
  yield create_event("message_delta", {
231
  "type": "message_delta",
232
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
233
- "usage": {"output_tokens": len(''.join(response_text))},
234
  })
235
  yield create_event("message_stop", {"type": "message_stop"}) # 确保发送 message_stop 事件
236
 
@@ -298,9 +300,13 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
298
  # 等待响应完成
299
  response_event.wait(timeout=30)
300
 
 
 
 
 
301
  # 生成完整的响应
302
  full_response = {
303
- "content": [{"text": ''.join(response_text), "type": "text"}], # 合并所有文本块
304
  "id": msg_id,
305
  "model": model, # 动态模型
306
  "role": "assistant",
@@ -309,7 +315,7 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
309
  "type": "message",
310
  "usage": {
311
  "input_tokens": input_tokens, # 动态 input_tokens
312
- "output_tokens": len(''.join(response_text)),
313
  },
314
  }
315
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
 
7
  import requests
8
  import logging
9
  from threading import Event
10
+ import tiktoken # 使用 tiktoken 库进行 tokenization
 
 
11
 
12
  app = Flask(__name__)
13
  logging.basicConfig(level=logging.INFO)
 
75
  # 如果是其他类型,返回空字符串
76
  return ""
77
 
78
+ def calculate_tokens(text):
79
  """
80
+ 使用 tiktoken 库来计算输入和输出文本的 token 数量。
81
+ Claude 模型可能有不同的 tokenization 规则,但 tiktoken 是一个很好的近似工具。
 
82
  """
83
+ # 使用 tiktoken 的 GPT-3.5 模型进行 tokenization
84
+ encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") # 使用 gpt-3.5-turbo 作为近似模型
85
+ tokens = encoding.encode(text)
86
  return len(tokens)
87
 
88
  @app.route('/')
 
122
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
123
 
124
  # 动态计算输入的 token 数量
125
+ input_tokens = calculate_tokens(previous_messages)
126
 
127
  msg_id = str(uuid.uuid4())
128
  response_event = Event()
 
146
  "model": model, # 动态模型
147
  "stop_reason": None,
148
  "stop_sequence": None,
149
+ "usage": {"input_tokens": input_tokens, "output_tokens": 0}, # 动态 input_tokens
150
  },
151
  })
152
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
 
224
  if sio.connected:
225
  sio.disconnect()
226
 
227
+ output_text = ''.join(response_text)
228
+ # 动态计算 output_tokens
229
+ output_tokens = calculate_tokens(output_text)
230
+
231
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
232
  yield create_event("message_delta", {
233
  "type": "message_delta",
234
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
235
+ "usage": {"output_tokens": output_tokens}, # 动态 output_tokens
236
  })
237
  yield create_event("message_stop", {"type": "message_stop"}) # 确保发送 message_stop 事件
238
 
 
300
  # 等待响应完成
301
  response_event.wait(timeout=30)
302
 
303
+ output_text = ''.join(response_text)
304
+ # 动态计算 output_tokens
305
+ output_tokens = calculate_tokens(output_text)
306
+
307
  # 生成完整的响应
308
  full_response = {
309
+ "content": [{"text": output_text, "type": "text"}], # 合并所有文本块
310
  "id": msg_id,
311
  "model": model, # 动态模型
312
  "role": "assistant",
 
315
  "type": "message",
316
  "usage": {
317
  "input_tokens": input_tokens, # 动态 input_tokens
318
+ "output_tokens": output_tokens, # 动态 output_tokens
319
  },
320
  }
321
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')