File size: 4,339 Bytes
3fdf5e9
 
f21f2de
 
 
 
3fdf5e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f21f2de
3fdf5e9
 
 
 
 
 
f21f2de
3fdf5e9
 
f21f2de
 
 
3fdf5e9
 
 
 
f21f2de
 
 
 
 
 
 
 
3fdf5e9
 
 
 
 
 
 
 
 
f21f2de
3fdf5e9
 
 
 
 
f21f2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fdf5e9
 
 
f21f2de
 
3fdf5e9
 
 
 
f21f2de
 
3fdf5e9
 
 
 
 
 
f21f2de
 
 
 
 
 
 
 
 
 
 
 
 
 
3fdf5e9
f21f2de
3fdf5e9
 
 
 
 
 
 
f21f2de
 
 
 
3fdf5e9
 
f21f2de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import asyncio
import hmac
import json
import os
import time

import aiohttp
from aiohttp import web

# Configuration, read once at import time from environment variables.
PROJECT_ID = os.getenv('PROJECT_ID')
# OAuth client credentials and long-lived refresh token used to mint
# short-lived Google access tokens.
CLIENT_ID = os.getenv('CLIENT_ID')
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
# Shared secret clients must present in the ``x-api-key`` request header.
API_KEY = os.getenv('API_KEY')
# Vertex AI model ID placed in the request URL path. It was previously
# referenced without ever being defined, which raised NameError per request.
MODEL = os.getenv('MODEL')

TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Process-wide OAuth token cache shared by all requests.
token_cache = {
    'access_token': '',       # current bearer token; '' until first refresh
    'expiry': 0,              # epoch seconds (time.time() base) when it expires
    'refresh_promise': None,  # awaitable for an in-flight refresh, else None
}

async def get_access_token():
    """Return a Google OAuth2 access token, refreshing it when needed.

    The token lives in ``token_cache`` and is reused until 120 seconds before
    its recorded expiry. When a refresh is needed, only one is issued: it runs
    as an ``asyncio`` Task stored in ``token_cache['refresh_promise']``, and
    every concurrent caller awaits that same Task (a Task, unlike a bare
    coroutine, may be awaited by any number of callers).

    Returns:
        The cached or freshly obtained access token string.

    Raises:
        aiohttp.ClientError: the token endpoint was unreachable or answered
            with an error status (shared by all callers of that refresh).
        KeyError: the token response lacked ``access_token``/``expires_in``.
    """
    now = time.time()

    # Fast path: cached token still valid with a 2-minute safety margin.
    if token_cache['access_token'] and now < token_cache['expiry'] - 120:
        return token_cache['access_token']

    async def _refresh():
        issued_at = time.time()
        form = {
            'client_id': CLIENT_ID,
            'client_secret': CLIENT_SECRET,
            'refresh_token': REFRESH_TOKEN,
            'grant_type': 'refresh_token',
        }
        async with aiohttp.ClientSession() as session:
            # The OAuth 2.0 refresh request is form-encoded (RFC 6749 §6).
            async with session.post(TOKEN_URL, data=form) as response:
                # Fail with a clear HTTP error instead of a KeyError below.
                response.raise_for_status()
                payload = await response.json()
        token_cache['access_token'] = payload['access_token']
        token_cache['expiry'] = issued_at + payload['expires_in']

    pending = token_cache['refresh_promise']
    if pending is None:
        pending = asyncio.create_task(_refresh())
        token_cache['refresh_promise'] = pending

        def _clear_slot(done):
            # Runs once the refresh settles, success or failure, so a failed
            # refresh is retried by the next caller instead of being re-awaited.
            if token_cache['refresh_promise'] is done:
                token_cache['refresh_promise'] = None

        pending.add_done_callback(_clear_slot)

    # shield(): one caller being cancelled (e.g. its client disconnected) must
    # not cancel the refresh that other callers are waiting on.
    await asyncio.shield(pending)
    return token_cache['access_token']

def get_location():
    """Choose the Vertex AI region for this request from the wall clock.

    Seconds 0-29 of each local minute map to ``'europe-west1'``; every other
    second maps to ``'us-east5'`` (presumably to spread traffic across two
    regions).
    """
    second_of_minute = time.localtime().tm_sec
    if second_of_minute >= 30:
        return 'us-east5'
    return 'europe-west1'

def construct_api_url(location, model=None):
    """Build the Vertex AI ``streamRawPredict`` URL for an Anthropic model.

    Args:
        location: Vertex AI region; used both as the regional host prefix and
            in the resource path.
        model: Vertex model ID. When omitted, the ``MODEL`` environment
            variable is used.

    Returns:
        The fully-qualified endpoint URL.

    Raises:
        ValueError: no ``model`` was passed and ``MODEL`` is unset or empty.
    """
    if model is None:
        model = os.getenv('MODEL')
    if not model:
        # Fail with a clear configuration error instead of a NameError or a
        # URL containing "None".
        raise ValueError('No model configured: pass `model` or set the MODEL environment variable.')
    return (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}'
        f'/locations/{location}/publishers/anthropic/models/{model}:streamRawPredict'
    )

# Upstream response headers that describe the upstream hop rather than the
# payload. ``upstream.read()`` yields the de-chunked, decompressed body, so
# relaying these would give our client a wrong Content-Length /
# Content-Encoding / Transfer-Encoding for the bytes actually sent.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'te',
    'trailer',
    'transfer-encoding',
    'upgrade',
    'content-length',
    'content-encoding',
})

# CORS headers attached to every non-preflight response, errors included, so
# browsers can read error bodies as well.
_CORS_HEADERS = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model',
}


def _error_response(status, error_type, message):
    """Return an Anthropic-style JSON error response carrying the CORS headers."""
    return web.Response(
        text=json.dumps({
            'type': 'error',
            'error': {
                'type': error_type,
                'message': message,
            },
        }),
        status=status,
        content_type='application/json',
        headers=_CORS_HEADERS,
    )


def _is_authorized(request):
    """Return True iff the request's ``x-api-key`` header matches ``API_KEY``.

    Fails closed when ``API_KEY`` is unset or empty: a plain ``!=`` check
    would accept a request with no header at all, since ``None == None``.
    """
    supplied = request.headers.get('x-api-key')
    if not API_KEY or supplied is None:
        return False
    # Constant-time comparison avoids leaking the key through response timing.
    return hmac.compare_digest(supplied.encode('utf-8'), API_KEY.encode('utf-8'))


async def handle_request(request):
    """Proxy an Anthropic Messages API request to Claude on Vertex AI.

    OPTIONS requests are answered by ``handle_options``. Any other request
    must meet two conditions:

    * It carries a valid ``x-api-key``; otherwise the reply is 403
      ``permission_error``.
    * Its body is a JSON object; otherwise the reply is 400
      ``invalid_request_error``.

    The body's ``model`` is dropped and its ``anthropic_version`` is replaced
    with ``'vertex-2023-10-16'``. The body is then POSTed, with an OAuth
    bearer token, to the endpoint from ``construct_api_url(get_location())``.
    Network failures or timeouts talking to Google yield a 502 ``api_error``.

    The upstream status, body and payload headers are relayed with CORS
    headers added. The upstream body is read in full before replying, so a
    streamed (SSE) reply reaches the client in one piece, not incrementally.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    if not _is_authorized(request):
        return _error_response(
            403,
            'permission_error',
            'Your API key does not have permission to use the specified resource.',
        )

    try:
        request_body = await request.json()
    except ValueError:  # JSONDecodeError (incl. empty body) or undecodable bytes
        request_body = None
    if not isinstance(request_body, dict):
        return _error_response(400, 'invalid_request_error', 'Request body must be a JSON object.')

    # Vertex takes the model from the URL and requires its own version tag.
    request_body.pop('model', None)
    request_body.pop('anthropic_version', None)
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    try:
        access_token = await get_access_token()
        api_url = construct_api_url(get_location())
        upstream_request_headers = {
            'Authorization': f'Bearer {access_token}',
            'Content-Type': 'application/json; charset=utf-8',
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(api_url, json=request_body, headers=upstream_request_headers) as upstream:
                response_body = await upstream.read()
                response_status = upstream.status
                relayed_headers = [
                    (name, value)
                    for name, value in upstream.headers.items()
                    if name.lower() not in _HOP_BY_HOP_HEADERS
                ]
    except (aiohttp.ClientError, asyncio.TimeoutError):
        return _error_response(502, 'api_error', 'Upstream request to Vertex AI failed.')

    response = web.Response(body=response_body, status=response_status, headers=relayed_headers)
    for name, value in _CORS_HEADERS.items():
        response.headers[name] = value
    return response

def handle_options():
    """Answer a CORS preflight request.

    Returns an empty 204 response that allows any origin, the POST/GET/OPTIONS
    methods, and the request headers this proxy reads.
    """
    preflight = web.Response(status=204)
    preflight.headers['Access-Control-Allow-Origin'] = '*'
    preflight.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    preflight.headers['Access-Control-Allow-Headers'] = (
        'Content-Type, Authorization, x-api-key, anthropic-version, model'
    )
    return preflight

# A single catch-all route: every HTTP method on "/" goes to handle_request,
# which itself dispatches OPTIONS preflights to handle_options.
app = web.Application()
app.router.add_route('*', '/', handle_request)

if __name__ == '__main__':
    # Honour $PORT when the hosting platform sets it; default to 8080.
    web.run_app(app, port=int(os.getenv('PORT', '8080')))