smgc commited on
Commit
3fdf5e9
1 Parent(s): 1297078

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, Response, jsonify
2
+ import os
3
+ import requests
4
+ import time
5
+ import threading
6
+
7
+ app = Flask(__name__)
8
+
9
+ # 环境变量
10
+ PROJECT_ID = os.getenv('PROJECT_ID')
11
+ CLIENT_ID = os.getenv('CLIENT_ID')
12
+ CLIENT_SECRET = os.getenv('CLIENT_SECRET')
13
+ REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
14
+ API_KEY = os.getenv('API_KEY')
15
+
16
+ TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'
17
+
18
+ token_cache = {
19
+ 'access_token': '',
20
+ 'expiry': 0,
21
+ 'refresh_promise': None
22
+ }
23
+
24
+ def get_access_token():
25
+ now = time.time()
26
+
27
+ # 如果 token 仍然有效,直接返回
28
+ if token_cache['access_token'] and now < token_cache['expiry'] - 120:
29
+ return token_cache['access_token']
30
+
31
+ # 如果已经有一个刷新操作在进行中,等待它完成
32
+ if token_cache['refresh_promise']:
33
+ token_cache['refresh_promise'].join()
34
+ return token_cache['access_token']
35
+
36
+ # 开始新的刷新操作
37
+ def refresh_token():
38
+ try:
39
+ response = requests.post(TOKEN_URL, json={
40
+ 'client_id': CLIENT_ID,
41
+ 'client_secret': CLIENT_SECRET,
42
+ 'refresh_token': REFRESH_TOKEN,
43
+ 'grant_type': 'refresh_token'
44
+ })
45
+ data = response.json()
46
+ token_cache['access_token'] = data['access_token']
47
+ token_cache['expiry'] = now + data['expires_in']
48
+ finally:
49
+ token_cache['refresh_promise'] = None
50
+
51
+ token_cache['refresh_promise'] = threading.Thread(target=refresh_token)
52
+ token_cache['refresh_promise'].start()
53
+ token_cache['refresh_promise'].join()
54
+ return token_cache['access_token']
55
+
56
+ def get_location():
57
+ current_seconds = time.localtime().tm_sec
58
+ return 'europe-west1' if current_seconds < 30 else 'us-east5'
59
+
60
+ def construct_api_url(location):
61
+ return f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{location}/publishers/anthropic/models/{MODEL}:streamRawPredict'
62
+
63
+ @app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
64
+ def handle_request():
65
+ if request.method == 'OPTIONS':
66
+ return handle_options()
67
+
68
+ # 检查 x-api-key
69
+ api_key = request.headers.get('x-api-key')
70
+ if api_key != API_KEY:
71
+ return jsonify({
72
+ 'type': 'error',
73
+ 'error': {
74
+ 'type': 'permission_error',
75
+ 'message': 'Your API key does not have permission to use the specified resource.'
76
+ }
77
+ }), 403
78
+
79
+ access_token = get_access_token()
80
+ location = get_location()
81
+ model = request.headers.get('model', 'claude-3-5-sonnet@20240620')
82
+ if model == 'claude-3-5-sonnet-20240620':
83
+ model = 'claude-3-5-sonnet@20240620'
84
+
85
+ api_url = construct_api_url(location)
86
+
87
+ request_body = request.json
88
+ if 'anthropic_version' in request_body:
89
+ del request_body['anthropic_version']
90
+ if 'model' in request_body:
91
+ del request_body['model']
92
+ request_body['anthropic_version'] = "vertex-2023-10-16"
93
+
94
+ headers = {
95
+ 'Authorization': f'Bearer {access_token}',
96
+ 'Content-Type': 'application/json; charset=utf-8'
97
+ }
98
+
99
+ response = requests.post(api_url, headers=headers, json=request_body)
100
+ return Response(response.content, status=response.status_code, content_type=response.headers['Content-Type'])
101
+
102
+ @app.route('/', methods=['GET'])
103
+ def index():
104
+ return "Vertex Claude API Proxy", 200
105
+
106
+ def handle_options():
107
+ headers = {
108
+ 'Access-Control-Allow-Origin': '*',
109
+ 'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
110
+ 'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model'
111
+ }
112
+ return '', 204, headers
113
+
114
+ if __name__ == '__main__':
115
+ app.run(port=8080)