File size: 3,732 Bytes
3fdf5e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import hmac
import os
import threading
import time

import requests
from flask import Flask, request, Response, jsonify

# Flask application object on which the proxy routes are registered.
app = Flask(__name__)

# Deployment settings, read once from the environment at import time.
# Each value is None when the corresponding variable is not set.
PROJECT_ID = os.environ.get('PROJECT_ID')
CLIENT_ID = os.environ.get('CLIENT_ID')
CLIENT_SECRET = os.environ.get('CLIENT_SECRET')
REFRESH_TOKEN = os.environ.get('REFRESH_TOKEN')
API_KEY = os.environ.get('API_KEY')

# Google OAuth endpoint that exchanges the refresh token for an access token.
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Shared cache for the current access token. 'expiry' is compared against
# time.time(), i.e. seconds since the epoch; 'refresh_promise' is a slot for
# an in-flight refresh, if any.
token_cache = {'access_token': '', 'expiry': 0, 'refresh_promise': None}

# Serializes token refreshes so concurrent requests trigger at most one.
_token_refresh_lock = threading.Lock()

def get_access_token():
    """Return a Google OAuth access token, refreshing the cached one if needed.

    The cached token is reused until 120 seconds before its recorded expiry.
    Otherwise a new token is requested from TOKEN_URL with the configured
    client credentials and refresh token, and the cache is updated.

    The refresh runs under a lock: threads arriving during a refresh wait for
    it and then reuse its result instead of issuing their own request.

    Returns:
        The access token string.

    Raises:
        RuntimeError: If the token endpoint cannot be reached, answers with an
            HTTP error, or its reply lacks the expected fields. The original
            exception is attached as ``__cause__``.
    """
    def cached_token_is_fresh():
        # Keep a 120 s safety margin so a token never expires mid-request.
        return bool(token_cache['access_token']) and time.time() < token_cache['expiry'] - 120

    if cached_token_is_fresh():
        return token_cache['access_token']

    with _token_refresh_lock:
        # Another thread may have completed a refresh while we were waiting.
        if cached_token_is_fresh():
            return token_cache['access_token']

        # Timestamp taken before the request so the stored expiry errs early.
        requested_at = time.time()
        try:
            response = requests.post(
                TOKEN_URL,
                json={
                    'client_id': CLIENT_ID,
                    'client_secret': CLIENT_SECRET,
                    'refresh_token': REFRESH_TOKEN,
                    'grant_type': 'refresh_token',
                },
                # Bound the wait: every request needing a token blocks here.
                timeout=30,
            )
            response.raise_for_status()
            data = response.json()
            access_token = data['access_token']
            expires_in = data['expires_in']
        except (requests.RequestException, KeyError, TypeError, ValueError) as exc:
            raise RuntimeError('Failed to refresh the Google access token') from exc

        # Update the cache only after the whole reply has been validated.
        token_cache['access_token'] = access_token
        token_cache['expiry'] = requested_at + expires_in
        return access_token

def get_location():
    """Pick the Vertex AI region from the current local-time second.

    Returns:
        'europe-west1' during seconds 0-29 of each minute, 'us-east5' otherwise.
    """
    in_first_half_minute = time.localtime().tm_sec < 30
    if in_first_half_minute:
        return 'europe-west1'
    return 'us-east5'

def construct_api_url(location, model='claude-3-5-sonnet@20240620'):
    """Build the Vertex AI ``streamRawPredict`` URL for an Anthropic model.

    Args:
        location: Vertex AI region; used as the host prefix and in the path.
        model: Model identifier placed in the URL path. Defaults to the
            proxy's default model so single-argument callers keep working.

    Returns:
        The endpoint URL as a string.
    """
    return (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}'
        f'/locations/{location}/publishers/anthropic/models/{model}:streamRawPredict'
    )

@app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
def handle_request():
    """Authenticate a messages request and forward it to Vertex AI.

    OPTIONS requests are answered by ``handle_options``. For POST, the
    ``x-api-key`` header must match API_KEY. The JSON body is forwarded with
    its ``model`` field removed and ``anthropic_version`` forced to
    ``vertex-2023-10-16``. The target model is taken from the ``model``
    request header, falling back to ``claude-3-5-sonnet@20240620``.

    Returns:
        The upstream response body, status code and content type; a 403 JSON
        error when the API key is missing or wrong; a 400 JSON error when the
        body is not a JSON object.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    # Require a key on both sides. With API_KEY unset and no header sent,
    # a plain `!=` would compare None with None and let the request through.
    api_key = request.headers.get('x-api-key')
    key_matches = bool(API_KEY) and bool(api_key) and hmac.compare_digest(
        api_key.encode('utf-8'), API_KEY.encode('utf-8')
    )
    if not key_matches:
        return jsonify({
            'type': 'error',
            'error': {
                'type': 'permission_error',
                'message': 'Your API key does not have permission to use the specified resource.'
            }
        }), 403

    # Validate the body before spending a token refresh on a bad request.
    request_body = request.get_json(silent=True)
    if not isinstance(request_body, dict):
        return jsonify({
            'type': 'error',
            'error': {
                'type': 'invalid_request_error',
                'message': 'Request body must be a JSON object.'
            }
        }), 400

    access_token = get_access_token()
    location = get_location()

    model = request.headers.get('model', 'claude-3-5-sonnet@20240620')
    # Accept the dash-separated spelling and map it to the '@' form used here.
    if model == 'claude-3-5-sonnet-20240620':
        model = 'claude-3-5-sonnet@20240620'

    api_url = construct_api_url(location, model)

    # The model is carried in the URL, so it is dropped from the body.
    request_body.pop('model', None)
    request_body['anthropic_version'] = "vertex-2023-10-16"

    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json; charset=utf-8'
    }

    response = requests.post(api_url, headers=headers, json=request_body)
    # Upstream may omit Content-Type; None lets Flask apply its own default.
    return Response(
        response.content,
        status=response.status_code,
        content_type=response.headers.get('Content-Type'),
    )

@app.route('/', methods=['GET'])
def index():
    """Serve a plain-text banner identifying the service, with status 200."""
    banner = "Vertex Claude API Proxy"
    return banner, 200

def handle_options():
    """Build the reply to a CORS preflight request.

    Returns:
        A ``(body, status, headers)`` tuple: an empty body, status 204, and
        the Access-Control-Allow-* headers.
    """
    cors_headers = {}
    cors_headers['Access-Control-Allow-Origin'] = '*'
    cors_headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    cors_headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'

    empty_body, no_content = '', 204
    return empty_body, no_content, cors_headers

# Start Flask's built-in server only when this file is executed as a script.
if __name__ == '__main__':
    listen_port = 8080
    app.run(port=listen_port)