smgc commited on
Commit
f21f2de
1 Parent(s): de8f523

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -49
app.py CHANGED
@@ -1,12 +1,10 @@
1
- from flask import Flask, request, Response, jsonify
2
  import os
3
- import requests
4
  import time
5
- import threading
 
 
 
6
 
7
- app = Flask(__name__)
8
-
9
- # 环境变量
10
  PROJECT_ID = os.getenv('PROJECT_ID')
11
  CLIENT_ID = os.getenv('CLIENT_ID')
12
  CLIENT_SECRET = os.getenv('CLIENT_SECRET')
@@ -21,36 +19,31 @@ token_cache = {
21
  'refresh_promise': None
22
  }
23
 
24
- def get_access_token():
25
  now = time.time()
26
 
27
- # 如果 token 仍然有效,直接返回
28
  if token_cache['access_token'] and now < token_cache['expiry'] - 120:
29
  return token_cache['access_token']
30
 
31
- # 如果已经有一个刷新操作在进行中,等待它完成
32
  if token_cache['refresh_promise']:
33
- token_cache['refresh_promise'].join()
34
  return token_cache['access_token']
35
 
36
- # 开始新的刷新操作
37
- def refresh_token():
38
- try:
39
- response = requests.post(TOKEN_URL, json={
40
  'client_id': CLIENT_ID,
41
  'client_secret': CLIENT_SECRET,
42
  'refresh_token': REFRESH_TOKEN,
43
  'grant_type': 'refresh_token'
44
- })
45
- data = response.json()
46
- token_cache['access_token'] = data['access_token']
47
- token_cache['expiry'] = now + data['expires_in']
48
- finally:
49
- token_cache['refresh_promise'] = None
50
-
51
- token_cache['refresh_promise'] = threading.Thread(target=refresh_token)
52
- token_cache['refresh_promise'].start()
53
- token_cache['refresh_promise'].join()
54
  return token_cache['access_token']
55
 
56
  def get_location():
@@ -60,48 +53,62 @@ def get_location():
60
  def construct_api_url(location):
61
  return f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{location}/publishers/anthropic/models/{MODEL}:streamRawPredict'
62
 
63
- @app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
64
- def handle_request():
65
  if request.method == 'OPTIONS':
66
  return handle_options()
67
 
68
- # 检查 x-api-key
69
  api_key = request.headers.get('x-api-key')
70
  if api_key != API_KEY:
71
- return jsonify({
72
- 'type': 'error',
73
- 'error': {
74
- 'type': 'permission_error',
75
- 'message': 'Your API key does not have permission to use the specified resource.'
76
- }
77
- }), 403
78
-
79
- access_token = get_access_token()
 
 
 
 
 
 
 
 
80
  location = get_location()
81
- model = request.headers.get('model', 'claude-3-5-sonnet@20240620')
82
- if model == 'claude-3-5-sonnet-20240620':
83
- model = 'claude-3-5-sonnet@20240620'
84
-
85
  api_url = construct_api_url(location)
86
 
87
- request_body = request.json
 
88
  if 'anthropic_version' in request_body:
89
  del request_body['anthropic_version']
90
  if 'model' in request_body:
91
  del request_body['model']
92
- request_body['anthropic_version'] = "vertex-2023-10-16"
 
93
 
94
  headers = {
95
  'Authorization': f'Bearer {access_token}',
96
  'Content-Type': 'application/json; charset=utf-8'
97
  }
98
 
99
- response = requests.post(api_url, headers=headers, json=request_body)
100
- return Response(response.content, status=response.status_code, content_type=response.headers['Content-Type'])
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- @app.route('/', methods=['GET'])
103
- def index():
104
- return "Vertex Claude API Proxy", 200
105
 
106
  def handle_options():
107
  headers = {
@@ -109,7 +116,10 @@ def handle_options():
109
  'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
110
  'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model'
111
  }
112
- return '', 204, headers
 
 
 
113
 
114
  if __name__ == '__main__':
115
- app.run(port=8080)
 
 
1
  import os
 
2
  import time
3
+ import json
4
+ import asyncio
5
+ import aiohttp
6
+ from aiohttp import web
7
 
 
 
 
8
  PROJECT_ID = os.getenv('PROJECT_ID')
9
  CLIENT_ID = os.getenv('CLIENT_ID')
10
  CLIENT_SECRET = os.getenv('CLIENT_SECRET')
 
19
  'refresh_promise': None
20
  }
21
 
22
+ async def get_access_token():
23
  now = time.time()
24
 
 
25
  if token_cache['access_token'] and now < token_cache['expiry'] - 120:
26
  return token_cache['access_token']
27
 
 
28
  if token_cache['refresh_promise']:
29
+ await token_cache['refresh_promise']
30
  return token_cache['access_token']
31
 
32
+ async def refresh_token():
33
+ async with aiohttp.ClientSession() as session:
34
+ async with session.post(TOKEN_URL, json={
 
35
  'client_id': CLIENT_ID,
36
  'client_secret': CLIENT_SECRET,
37
  'refresh_token': REFRESH_TOKEN,
38
  'grant_type': 'refresh_token'
39
+ }) as response:
40
+ data = await response.json()
41
+ token_cache['access_token'] = data['access_token']
42
+ token_cache['expiry'] = now + data['expires_in']
43
+
44
+ token_cache['refresh_promise'] = refresh_token()
45
+ await token_cache['refresh_promise']
46
+ token_cache['refresh_promise'] = None
 
 
47
  return token_cache['access_token']
48
 
49
  def get_location():
 
53
  def construct_api_url(location):
54
  return f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{location}/publishers/anthropic/models/{MODEL}:streamRawPredict'
55
 
56
+ async def handle_request(request):
 
57
  if request.method == 'OPTIONS':
58
  return handle_options()
59
 
 
60
  api_key = request.headers.get('x-api-key')
61
  if api_key != API_KEY:
62
+ error_response = web.Response(
63
+ text=json.dumps({
64
+ 'type': 'error',
65
+ 'error': {
66
+ 'type': 'permission_error',
67
+ 'message': 'Your API key does not have permission to use the specified resource.'
68
+ }
69
+ }),
70
+ status=403,
71
+ content_type='application/json'
72
+ )
73
+ error_response.headers['Access-Control-Allow-Origin'] = '*'
74
+ error_response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS, DELETE, HEAD'
75
+ error_response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
76
+ return error_response
77
+
78
+ access_token = await get_access_token()
79
  location = get_location()
 
 
 
 
80
  api_url = construct_api_url(location)
81
 
82
+ request_body = await request.json()
83
+
84
  if 'anthropic_version' in request_body:
85
  del request_body['anthropic_version']
86
  if 'model' in request_body:
87
  del request_body['model']
88
+
89
+ request_body['anthropic_version'] = 'vertex-2023-10-16'
90
 
91
  headers = {
92
  'Authorization': f'Bearer {access_token}',
93
  'Content-Type': 'application/json; charset=utf-8'
94
  }
95
 
96
+ async with aiohttp.ClientSession() as session:
97
+ async with session.post(api_url, json=request_body, headers=headers) as response:
98
+ response_body = await response.read()
99
+ response_headers = response.headers
100
+ response_status = response.status
101
+
102
+ modified_response = web.Response(
103
+ body=response_body,
104
+ status=response_status,
105
+ headers=response_headers
106
+ )
107
+ modified_response.headers['Access-Control-Allow-Origin'] = '*'
108
+ modified_response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
109
+ modified_response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
110
 
111
+ return modified_response
 
 
112
 
113
  def handle_options():
114
  headers = {
 
116
  'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
117
  'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model'
118
  }
119
+ return web.Response(status=204, headers=headers)
120
+
121
+ app = web.Application()
122
+ app.router.add_route('*', '/', handle_request)
123
 
124
  if __name__ == '__main__':
125
+ web.run_app(app, port=8080)