smgc commited on
Commit
1b7f3dd
1 Parent(s): 546eb40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -38
app.py CHANGED
@@ -7,32 +7,7 @@ import socketio
7
  import requests
8
  import logging
9
  from threading import Event
10
- import tiktoken # 引入 tiktoken 库
11
- from tiktoken import Encoding
12
-
13
- def local_encoding_for_model(model_name: str):
14
- """
15
- 从本地加载编码文件并返回一个 Encoding 对象。
16
- """
17
- local_encoding_path = '/app/cl100k_base.tiktoken'
18
- if os.path.exists(local_encoding_path):
19
- with open(local_encoding_path, 'rb') as f:
20
- encoding_data = f.read() # 读取本地编码文件的字节内容
21
-
22
- # 构造一个 Encoding 对象
23
- return Encoding(
24
- name="cl100k_base", # 编码的名称
25
- pat_str="", # 正则表达式(如果有)
26
- mergeable_ranks={}, # 合并的 rank 数据(通常是从文件或其他地方加载)
27
- special_tokens={}, # 特殊 token 映射
28
- explicit_n_vocab=None # 可选的词汇表大小
29
- )
30
- else:
31
- raise FileNotFoundError(f"Local encoding file not found at {local_encoding_path}")
32
-
33
- # 替换 tiktoken 的 encoding_for_model 函数
34
- tiktoken.encoding_for_model = local_encoding_for_model
35
-
36
 
37
  app = Flask(__name__)
38
  logging.basicConfig(level=logging.INFO)
@@ -100,14 +75,20 @@ def normalize_content(content):
100
  # 如果是其他类型,返回空字符串
101
  return ""
102
 
103
- def calculate_tokens_via_tiktoken(text, model="gpt-3.5-turbo"):
104
  """
105
- 使用 tiktoken 库根据 GPT 模型计算 token 数量。
106
- Claude 模型与 GPT 模型的 token 计算机制类似,因此可以使用 tiktoken。
 
107
  """
108
- encoding = tiktoken.encoding_for_model(model) # 获取模型的编码器
109
- tokens = encoding.encode(text) # 对文本进行 tokenization
110
- return len(tokens)
 
 
 
 
 
111
 
112
  @app.route('/')
113
  def root():
@@ -145,8 +126,8 @@ def messages():
145
  # 使用 normalize_content 递归处理 msg['content']
146
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
147
 
148
- # 动态计算输入的 token 数量,使用 tiktoken 进行 tokenization
149
- input_tokens = calculate_tokens_via_tiktoken(previous_messages, model="gpt-3.5-turbo")
150
 
151
  msg_id = str(uuid.uuid4())
152
  response_event = Event()
@@ -248,8 +229,8 @@ def messages():
248
  if sio.connected:
249
  sio.disconnect()
250
 
251
- # 动态计算输出的 token 数量,使用 tiktoken 进行 tokenization
252
- output_tokens = calculate_tokens_via_tiktoken(''.join(response_text), model="gpt-3.5-turbo")
253
 
254
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
255
  yield create_event("message_delta", {
@@ -323,8 +304,8 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
323
  # 等待响应完成
324
  response_event.wait(timeout=30)
325
 
326
- # 动态计算输出的 token 数量,使用 tiktoken 进行 tokenization
327
- output_tokens = calculate_tokens_via_tiktoken(''.join(response_text), model="gpt-3.5-turbo")
328
 
329
  # 生成完整的响应
330
  full_response = {
 
7
  import requests
8
  import logging
9
  from threading import Event
10
+ import re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  app = Flask(__name__)
13
  logging.basicConfig(level=logging.INFO)
 
75
  # 如果是其他类型,返回空字符串
76
  return ""
77
 
78
+ def calculate_tokens(text):
79
  """
80
+ 改进的 token 计算方法。
81
+ - 对于英文和有空格的文本,使用空格分词。
82
+ - 对于中文等没有空格的文本,使用字符级分词。
83
  """
84
+ # 首先判断文本是否包含大量非 ASCII 字符(如中文)
85
+ if re.search(r'[^\x00-\x7F]', text):
86
+ # 如果包含非 ASCII 字符,使用字符级分词
87
+ return len(text)
88
+ else:
89
+ # 否则使用空格分词
90
+ tokens = text.split()
91
+ return len(tokens)
92
 
93
  @app.route('/')
94
  def root():
 
126
  # 使用 normalize_content 递归处理 msg['content']
127
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
128
 
129
+ # 动态计算输入的 token 数量
130
+ input_tokens = calculate_tokens(previous_messages)
131
 
132
  msg_id = str(uuid.uuid4())
133
  response_event = Event()
 
229
  if sio.connected:
230
  sio.disconnect()
231
 
232
+ # 动态计算输出的 token 数量
233
+ output_tokens = calculate_tokens(''.join(response_text))
234
 
235
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
236
  yield create_event("message_delta", {
 
304
  # 等待响应完成
305
  response_event.wait(timeout=30)
306
 
307
+ # 动态计算输出的 token 数量
308
+ output_tokens = calculate_tokens(''.join(response_text))
309
 
310
  # 生成完整的响应
311
  full_response = {