smgc commited on
Commit
10d3a03
1 Parent(s): 8eb01dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -8,6 +8,9 @@ import requests
8
  import logging
9
  from threading import Event
10
 
 
 
 
11
  app = Flask(__name__)
12
  logging.basicConfig(level=logging.INFO)
13
 
@@ -74,6 +77,16 @@ def normalize_content(content):
74
  # 如果是其他类型,返回空字符串
75
  return ""
76
 
 
 
 
 
 
 
 
 
 
 
77
  @app.route('/')
78
  def root():
79
  log_request(request.remote_addr, request.path, 200)
@@ -110,13 +123,16 @@ def messages():
110
  # 使用 normalize_content 递归处理 msg['content']
111
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
112
 
 
 
 
113
  msg_id = str(uuid.uuid4())
114
  response_event = Event()
115
  response_text = []
116
 
117
  if not stream:
118
  # 处理 stream 为 false 的情况
119
- return handle_non_stream(previous_messages, msg_id, model)
120
 
121
  # 记录日志:此时请求上下文仍然有效
122
  log_request(request.remote_addr, request.path, 200)
@@ -132,7 +148,7 @@ def messages():
132
  "model": model, # 动态模型
133
  "stop_reason": None,
134
  "stop_sequence": None,
135
- "usage": {"input_tokens": 8, "output_tokens": 1},
136
  },
137
  })
138
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
@@ -225,7 +241,7 @@ def messages():
225
  log_request(request.remote_addr, request.path, 400)
226
  return jsonify({"error": str(e)}), 400
227
 
228
- def handle_non_stream(previous_messages, msg_id, model):
229
  """
230
  处理 stream 为 false 的情况,返回完整的响应。
231
  """
@@ -292,7 +308,7 @@ def handle_non_stream(previous_messages, msg_id, model):
292
  "stop_sequence": None,
293
  "type": "message",
294
  "usage": {
295
- "input_tokens": 8,
296
  "output_tokens": len(''.join(response_text)),
297
  },
298
  }
 
8
  import logging
9
  from threading import Event
10
 
11
+ # 如果使用 GPT 模型的 tokenization,可以引入 tiktoken
12
+ # import tiktoken # 如果需要使用 GPT 的 token 化库
13
+
14
  app = Flask(__name__)
15
  logging.basicConfig(level=logging.INFO)
16
 
 
77
  # 如果是其他类型,返回空字符串
78
  return ""
79
 
80
+ def calculate_input_tokens(text):
81
+ """
82
+ 计算输入文本的 token 数量。
83
+ 这里我们简单地通过空格分词来模拟 token 计数。
84
+ 如果使用 GPT 模型,可以使用 tiktoken 库进行 tokenization。
85
+ """
86
+ # 使用简单的空格分词计数
87
+ tokens = text.split()
88
+ return len(tokens)
89
+
90
  @app.route('/')
91
  def root():
92
  log_request(request.remote_addr, request.path, 200)
 
123
  # 使用 normalize_content 递归处理 msg['content']
124
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
125
 
126
+ # 动态计算输入的 token 数量
127
+ input_tokens = calculate_input_tokens(previous_messages)
128
+
129
  msg_id = str(uuid.uuid4())
130
  response_event = Event()
131
  response_text = []
132
 
133
  if not stream:
134
  # 处理 stream 为 false 的情况
135
+ return handle_non_stream(previous_messages, msg_id, model, input_tokens)
136
 
137
  # 记录日志:此时请求上下文仍然有效
138
  log_request(request.remote_addr, request.path, 200)
 
148
  "model": model, # 动态模型
149
  "stop_reason": None,
150
  "stop_sequence": None,
151
+ "usage": {"input_tokens": input_tokens, "output_tokens": 1}, # 动态 input_tokens
152
  },
153
  })
154
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
 
241
  log_request(request.remote_addr, request.path, 400)
242
  return jsonify({"error": str(e)}), 400
243
 
244
+ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
245
  """
246
  处理 stream 为 false 的情况,返回完整的响应。
247
  """
 
308
  "stop_sequence": None,
309
  "type": "message",
310
  "usage": {
311
+ "input_tokens": input_tokens, # 动态 input_tokens
312
  "output_tokens": len(''.join(response_text)),
313
  },
314
  }