smgc committed on
Commit
b73bb4c
1 Parent(s): 6422859

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -58
app.py CHANGED
@@ -8,14 +8,24 @@ import requests
8
  import logging
9
  from threading import Event
10
 
 
 
 
11
  app = Flask(__name__)
12
  logging.basicConfig(level=logging.INFO)
13
 
 
14
  API_KEY = os.environ.get('PPLX_KEY')
 
 
15
  proxy_url = os.environ.get('PROXY_URL')
16
 
 
17
  if proxy_url:
18
- proxies = {'http': proxy_url, 'https': proxy_url}
 
 
 
19
  transport = requests.Session()
20
  transport.proxies.update(proxies)
21
  else:
@@ -23,7 +33,12 @@ else:
23
 
24
  sio = socketio.Client(http_session=transport, logger=True, engineio_logger=True)
25
 
26
- connect_opts = {'transports': ['websocket', 'polling']}
 
 
 
 
 
27
  sio_opts = {
28
  'extraHeaders': {
29
  'Cookie': os.environ.get('PPLX_COOKIE'),
@@ -46,24 +61,31 @@ def validate_api_key():
46
  return None
47
 
48
  def normalize_content(content):
 
 
 
 
49
  if isinstance(content, str):
50
  return content
51
  elif isinstance(content, dict):
 
52
  return json.dumps(content, ensure_ascii=False)
53
  elif isinstance(content, list):
 
54
  return " ".join([normalize_content(item) for item in content])
55
  else:
 
56
  return ""
57
 
58
  def calculate_tokens(text):
59
- return len(text.split())
60
-
61
- def validate_json(data):
62
- try:
63
- json.loads(json.dumps(data))
64
- return True
65
- except json.JSONDecodeError:
66
- return False
67
 
68
  @app.route('/')
69
  def root():
@@ -95,10 +117,13 @@ def messages():
95
 
96
  try:
97
  json_body = request.json
98
- model = json_body.get('model', 'claude-3-opus-20240229')
99
- stream = json_body.get('stream', True)
100
 
 
101
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
 
 
102
  input_tokens = calculate_tokens(previous_messages)
103
 
104
  msg_id = str(uuid.uuid4())
@@ -106,8 +131,10 @@ def messages():
106
  response_text = []
107
 
108
  if not stream:
 
109
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
110
 
 
111
  log_request(request.remote_addr, request.path, 200)
112
 
113
  def generate():
@@ -118,10 +145,10 @@ def messages():
118
  "type": "message",
119
  "role": "assistant",
120
  "content": [],
121
- "model": model,
122
  "stop_reason": None,
123
  "stop_sequence": None,
124
- "usage": {"input_tokens": input_tokens, "output_tokens": 1},
125
  },
126
  })
127
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
@@ -147,16 +174,18 @@ def messages():
147
 
148
  def on_query_progress(data):
149
  nonlocal response_text
150
- try:
151
- if 'text' in data:
152
- text = json.loads(data['text'])
153
- chunk = text['chunks'][-1] if text['chunks'] else None
154
- if chunk:
155
- response_text.append(chunk)
156
- if data.get('final', False):
157
- response_event.set()
158
- except json.JSONDecodeError:
159
- logging.error(f"Failed to parse query progress data: {data}")
 
 
160
 
161
  def on_disconnect():
162
  logging.info("Disconnected from Perplexity AI")
@@ -169,6 +198,7 @@ def messages():
169
 
170
  sio.on('connect', on_connect)
171
  sio.on('query_progress', on_query_progress)
 
172
  sio.on('disconnect', on_disconnect)
173
  sio.on('connect_error', on_connect_error)
174
 
@@ -179,15 +209,11 @@ def messages():
179
  sio.sleep(0.1)
180
  while response_text:
181
  chunk = response_text.pop(0)
182
- event_data = {
183
  "type": "content_block_delta",
184
  "index": 0,
185
  "delta": {"type": "text_delta", "text": chunk},
186
- }
187
- if validate_json(event_data):
188
- yield create_event("content_block_delta", event_data)
189
- else:
190
- logging.error(f"Invalid JSON for content_block_delta: {event_data}")
191
 
192
  except Exception as e:
193
  logging.error(f"Error during socket connection: {str(e)}")
@@ -200,15 +226,16 @@ def messages():
200
  if sio.connected:
201
  sio.disconnect()
202
 
 
203
  output_tokens = calculate_tokens(''.join(response_text))
204
 
205
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
206
  yield create_event("message_delta", {
207
  "type": "message_delta",
208
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
209
- "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
210
  })
211
- yield create_event("message_stop", {"type": "message_stop"})
212
 
213
  return Response(generate(), content_type='text/event-stream')
214
 
@@ -218,6 +245,9 @@ def messages():
218
  return jsonify({"error": str(e)}), 400
219
 
220
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
 
 
 
221
  try:
222
  response_event = Event()
223
  response_text = []
@@ -242,16 +272,15 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
242
 
243
  def on_query_progress(data):
244
  nonlocal response_text
245
- try:
246
- if 'text' in data:
247
- text = json.loads(data['text'])
248
- chunk = text['chunks'][-1] if text['chunks'] else None
249
- if chunk:
250
- response_text.append(chunk)
251
- if data.get('final', False):
252
- response_event.set()
253
- except json.JSONDecodeError:
254
- logging.error(f"Failed to parse query progress data: {data}")
255
 
256
  def on_disconnect():
257
  logging.info("Disconnected from Perplexity AI")
@@ -269,28 +298,26 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
269
 
270
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
271
 
 
272
  response_event.wait(timeout=30)
273
 
 
274
  output_tokens = calculate_tokens(''.join(response_text))
275
 
 
276
  full_response = {
277
- "content": [{"text": ''.join(response_text), "type": "text"}],
278
  "id": msg_id,
279
- "model": model,
280
  "role": "assistant",
281
  "stop_reason": "end_turn",
282
  "stop_sequence": None,
283
  "type": "message",
284
  "usage": {
285
- "input_tokens": input_tokens,
286
- "output_tokens": output_tokens,
287
  },
288
  }
289
-
290
- if not validate_json(full_response):
291
- logging.error(f"Invalid JSON response: {full_response}")
292
- return jsonify({"error": "Invalid response format"}), 500
293
-
294
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
295
 
296
  except Exception as e:
@@ -312,13 +339,9 @@ def server_error(error):
312
  return "Something broke!", 500
313
 
314
  def create_event(event, data):
315
- try:
316
- if isinstance(data, dict):
317
- data = json.dumps(data, ensure_ascii=False)
318
- return f"event: {event}\ndata: {data}\n\n"
319
- except json.JSONDecodeError:
320
- logging.error(f"Failed to serialize event data: {data}")
321
- return f"event: {event}\ndata: {json.dumps({'error': 'Data serialization failed'})}\n\n"
322
 
323
  if __name__ == '__main__':
324
  port = int(os.environ.get('PORT', 8081))
 
8
  import logging
9
  from threading import Event
10
 
11
+ # 如果使用 GPT 模型的 tokenization,可以引入 tiktoken
12
+ # import tiktoken # 如果需要使用 GPT 的 token 化库
13
+
14
  app = Flask(__name__)
15
  logging.basicConfig(level=logging.INFO)
16
 
17
+ # 从环境变量中获取API密钥
18
  API_KEY = os.environ.get('PPLX_KEY')
19
+
20
+ # 代理设置
21
  proxy_url = os.environ.get('PROXY_URL')
22
 
23
+ # 设置代理
24
  if proxy_url:
25
+ proxies = {
26
+ 'http': proxy_url,
27
+ 'https': proxy_url
28
+ }
29
  transport = requests.Session()
30
  transport.proxies.update(proxies)
31
  else:
 
33
 
34
  sio = socketio.Client(http_session=transport, logger=True, engineio_logger=True)
35
 
36
+ # 连接选项
37
+ connect_opts = {
38
+ 'transports': ['websocket', 'polling'], # 允许回退到轮询
39
+ }
40
+
41
+ # 其他选项
42
  sio_opts = {
43
  'extraHeaders': {
44
  'Cookie': os.environ.get('PPLX_COOKIE'),
 
61
  return None
62
 
63
def normalize_content(content):
    """Recursively flatten a message's ``content`` field into plain text.

    Strings pass through unchanged, lists are flattened element by
    element and joined with single spaces, dicts are serialized to JSON
    with non-ASCII characters preserved, and any other type collapses
    to the empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Flatten each element recursively, then glue with spaces.
        flattened = (normalize_content(part) for part in content)
        return " ".join(flattened)
    if isinstance(content, dict):
        # Keep non-ASCII (e.g. CJK) text readable in the serialized form.
        return json.dumps(content, ensure_ascii=False)
    return ""
79
 
80
def calculate_tokens(text):
    """Approximate a token count for *text*.

    This is a deliberately crude stand-in for a real tokenizer
    (e.g. tiktoken for GPT-family models): the text is split on
    whitespace and the number of resulting words is returned, so
    consecutive whitespace contributes nothing to the count.
    """
    return len(text.split())
89
 
90
  @app.route('/')
91
  def root():
 
117
 
118
  try:
119
  json_body = request.json
120
+ model = json_body.get('model', 'claude-3-opus-20240229') # 动态获取模型,默认 claude-3-opus-20240229
121
+ stream = json_body.get('stream', True) # 默认为True
122
 
123
+ # 使用 normalize_content 递归处理 msg['content']
124
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
125
+
126
+ # 动态计算输入的 token 数量
127
  input_tokens = calculate_tokens(previous_messages)
128
 
129
  msg_id = str(uuid.uuid4())
 
131
  response_text = []
132
 
133
  if not stream:
134
+ # 处理 stream 为 false 的情况
135
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
136
 
137
+ # 记录日志:此时请求上下文仍然有效
138
  log_request(request.remote_addr, request.path, 200)
139
 
140
  def generate():
 
145
  "type": "message",
146
  "role": "assistant",
147
  "content": [],
148
+ "model": model, # 动态模型
149
  "stop_reason": None,
150
  "stop_sequence": None,
151
+ "usage": {"input_tokens": input_tokens, "output_tokens": 1}, # 动态 input_tokens
152
  },
153
  })
154
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
 
174
 
175
  def on_query_progress(data):
176
  nonlocal response_text
177
+ if 'text' in data:
178
+ text = json.loads(data['text'])
179
+ chunk = text['chunks'][-1] if text['chunks'] else None
180
+ if chunk:
181
+ response_text.append(chunk)
182
+
183
+ # 检查是否是最终响应
184
+ if data.get('final', False):
185
+ response_event.set()
186
+
187
+ def on_query_complete(data):
188
+ response_event.set()
189
 
190
  def on_disconnect():
191
  logging.info("Disconnected from Perplexity AI")
 
198
 
199
  sio.on('connect', on_connect)
200
  sio.on('query_progress', on_query_progress)
201
+ sio.on('query_complete', on_query_complete)
202
  sio.on('disconnect', on_disconnect)
203
  sio.on('connect_error', on_connect_error)
204
 
 
209
  sio.sleep(0.1)
210
  while response_text:
211
  chunk = response_text.pop(0)
212
+ yield create_event("content_block_delta", {
213
  "type": "content_block_delta",
214
  "index": 0,
215
  "delta": {"type": "text_delta", "text": chunk},
216
+ })
 
 
 
 
217
 
218
  except Exception as e:
219
  logging.error(f"Error during socket connection: {str(e)}")
 
226
  if sio.connected:
227
  sio.disconnect()
228
 
229
+ # 动态计算输出的 token 数量
230
  output_tokens = calculate_tokens(''.join(response_text))
231
 
232
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
233
  yield create_event("message_delta", {
234
  "type": "message_delta",
235
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
236
+ "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, # 动态 output_tokens
237
  })
238
+ yield create_event("message_stop", {"type": "message_stop"}) # 确保发送 message_stop 事件
239
 
240
  return Response(generate(), content_type='text/event-stream')
241
 
 
245
  return jsonify({"error": str(e)}), 400
246
 
247
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
248
+ """
249
+ 处理 stream 为 false 的情况,返回完整的响应。
250
+ """
251
  try:
252
  response_event = Event()
253
  response_text = []
 
272
 
273
  def on_query_progress(data):
274
  nonlocal response_text
275
+ if 'text' in data:
276
+ text = json.loads(data['text'])
277
+ chunk = text['chunks'][-1] if text['chunks'] else None
278
+ if chunk:
279
+ response_text.append(chunk)
280
+
281
+ # 检查是否是最终响应
282
+ if data.get('final', False):
283
+ response_event.set()
 
284
 
285
  def on_disconnect():
286
  logging.info("Disconnected from Perplexity AI")
 
298
 
299
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
300
 
301
+ # 等待响应完成
302
  response_event.wait(timeout=30)
303
 
304
+ # 动态计算输出的 token 数量
305
  output_tokens = calculate_tokens(''.join(response_text))
306
 
307
+ # 生成完整的响应
308
  full_response = {
309
+ "content": [{"text": ''.join(response_text), "type": "text"}], # 合并所有文本块
310
  "id": msg_id,
311
+ "model": model, # 动态模型
312
  "role": "assistant",
313
  "stop_reason": "end_turn",
314
  "stop_sequence": None,
315
  "type": "message",
316
  "usage": {
317
+ "input_tokens": input_tokens, # 动态 input_tokens
318
+ "output_tokens": output_tokens, # 动态 output_tokens
319
  },
320
  }
 
 
 
 
 
321
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
322
 
323
  except Exception as e:
 
339
  return "Something broke!", 500
340
 
341
def create_event(event, data):
    """Format one Server-Sent Events (SSE) frame.

    ``data`` may be a dict — serialized to JSON with non-ASCII
    characters kept intact — or an already-formatted string payload.
    Returns the ``event:``/``data:`` frame terminated by a blank line,
    as required by the SSE wire format.
    """
    payload = json.dumps(data, ensure_ascii=False) if isinstance(data, dict) else data
    return f"event: {event}\ndata: {payload}\n\n"
 
 
 
 
345
 
346
  if __name__ == '__main__':
347
  port = int(os.environ.get('PORT', 8081))