Spaces:
Running
Running
import asyncio
import hmac
import json
import os
import time

import aiohttp
from aiohttp import web
# Configuration, read once at import time from environment variables.
# Google Cloud project whose Vertex AI endpoint requests are forwarded to.
PROJECT_ID = os.getenv('PROJECT_ID')
# OAuth client credentials and refresh token exchanged at TOKEN_URL for access tokens.
CLIENT_ID = os.getenv('CLIENT_ID')
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
# Shared secret that clients must send in the x-api-key header.
API_KEY = os.getenv('API_KEY')
# Model identifier placed in the Vertex AI URL path (".../models/{MODEL}:streamRawPredict").
# It was previously referenced but never defined, which raised NameError on every request.
MODEL = os.getenv('MODEL')
# Google OAuth 2.0 token endpoint used for the refresh_token grant.
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'
# Module-level OAuth access-token cache, shared by every call to get_access_token.
#   access_token    - most recently obtained token ('' before the first refresh)
#   expiry          - epoch seconds at which access_token expires (0 until known)
#   refresh_promise - awaitable for the token refresh currently in flight, or None
token_cache = dict(access_token='', expiry=0, refresh_promise=None)
async def _fetch_access_token():
    """Exchange REFRESH_TOKEN for a new access token and store it in token_cache.

    Raises:
        RuntimeError: if the token endpoint's JSON reply has no ``access_token``.
        aiohttp.ClientError: on transport failures or a non-JSON reply.
    """
    requested_at = time.time()
    async with aiohttp.ClientSession() as session:
        async with session.post(TOKEN_URL, json={
            'client_id': CLIENT_ID,
            'client_secret': CLIENT_SECRET,
            'refresh_token': REFRESH_TOKEN,
            'grant_type': 'refresh_token',
        }) as response:
            data = await response.json()
            status = response.status
    if 'access_token' not in data:
        # Cache nothing; report the endpoint's error payload instead of a bare KeyError.
        raise RuntimeError(f'OAuth token refresh failed (HTTP {status}): {data}')
    token_cache['access_token'] = data['access_token']
    # Measured from before the request, so the cached expiry errs on the early side.
    token_cache['expiry'] = requested_at + data['expires_in']


def _clear_refresh_task(task):
    """Done-callback: forget a finished refresh so a later expiry (or a retry after failure) starts a new one."""
    if token_cache['refresh_promise'] is task:
        token_cache['refresh_promise'] = None


async def get_access_token():
    """Return an OAuth access token for Vertex AI, refreshing it when needed.

    A cached token is reused until 120 seconds before its recorded expiry.
    Otherwise a single shared refresh task is started, or joined if one is
    already running, so concurrent callers cause only one token request.

    Raises:
        Whatever the refresh raises (RuntimeError, aiohttp.ClientError). The
        failed task is cleared, so the next call retries instead of staying broken.
    """
    now = time.time()
    if token_cache['access_token'] and now < token_cache['expiry'] - 120:
        return token_cache['access_token']
    refresh = token_cache['refresh_promise']
    if refresh is None:
        # A Task, unlike a bare coroutine, may be awaited by any number of callers.
        refresh = asyncio.create_task(_fetch_access_token())
        refresh.add_done_callback(_clear_refresh_task)
        token_cache['refresh_promise'] = refresh
    # shield(): one cancelled client request must not cancel the refresh that others share.
    await asyncio.shield(refresh)
    return token_cache['access_token']
def get_location():
    """Choose the Vertex AI region for the current request from the wall clock.

    Seconds 0-29 of each local-time minute map to ``europe-west1``; every other
    second value maps to ``us-east5``. (Presumably this spreads traffic across
    both regions — confirm intent.)
    """
    second_of_minute = time.localtime().tm_sec
    if second_of_minute >= 30:
        return 'us-east5'
    return 'europe-west1'
def construct_api_url(location, model=None, project_id=None):
    """Build the Vertex AI ``streamRawPredict`` URL for an Anthropic model.

    Args:
        location: Vertex AI region, used both as the host prefix and in the path.
        model: Model identifier for the URL path. Defaults to the ``MODEL``
            environment variable. Previously this read an undefined global,
            which raised NameError.
        project_id: Google Cloud project. Defaults to the module's PROJECT_ID.

    Returns:
        The fully qualified endpoint URL.

    Raises:
        ValueError: if no model is passed and ``MODEL`` is unset or empty.
    """
    if model is None:
        model = os.getenv('MODEL')
    if not model:
        raise ValueError('No model configured: pass `model` or set the MODEL environment variable.')
    if project_id is None:
        project_id = PROJECT_ID
    return (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}'
        f'/locations/{location}/publishers/anthropic/models/{model}:streamRawPredict'
    )
# Upstream response headers that must not be copied onto the proxied response.
# This covers hop-by-hop headers, plus framing headers that no longer describe
# the body: aiohttp has already de-chunked and decompressed it, and it
# recomputes Content-Length itself.
_UPSTREAM_HEADERS_TO_DROP = (
    'Connection',
    'Keep-Alive',
    'Proxy-Authenticate',
    'Proxy-Authorization',
    'TE',
    'Trailer',
    'Transfer-Encoding',
    'Upgrade',
    'Content-Length',
    'Content-Encoding',
)


def _add_cors_headers(response):
    """Attach the proxy's permissive CORS headers to *response* and return it."""
    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
    return response


def _error_response(status, error_type, message):
    """Build an Anthropic-style JSON error body with CORS headers."""
    return _add_cors_headers(web.Response(
        text=json.dumps({
            'type': 'error',
            'error': {
                'type': error_type,
                'message': message,
            },
        }),
        status=status,
        content_type='application/json',
    ))


def _is_authorized(request):
    """Return True when the request's x-api-key matches the configured API_KEY."""
    if not API_KEY:
        # Fail closed. Otherwise an unset API_KEY (None) would equal a missing header (None).
        return False
    supplied = request.headers.get('x-api-key', '')
    # Constant-time comparison avoids leaking the key through response timing.
    return hmac.compare_digest(supplied.encode(), API_KEY.encode())


async def handle_request(request):
    """Proxy an Anthropic Messages API request to Claude on Vertex AI.

    Steps:
    1. Answer CORS preflights.
    2. Authenticate the caller by ``x-api-key``.
    3. Rewrite the JSON body for Vertex: drop the client's ``model`` and
       ``anthropic_version``, then set ``anthropic_version`` to
       ``'vertex-2023-10-16'``.
    4. POST it to the regional ``streamRawPredict`` endpoint with a cached
       OAuth bearer token.
    5. Relay the upstream status, body and end-to-end headers, adding CORS
       headers.

    The upstream reply is read fully before it is returned.

    Error responses:
        403: missing or wrong API key, or no API_KEY configured.
        400: a body that is not a JSON object.
        502: the token or upstream request fails at the transport level.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    if not _is_authorized(request):
        return _error_response(
            403,
            'permission_error',
            'Your API key does not have permission to use the specified resource.',
        )

    try:
        request_body = await request.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return _error_response(400, 'invalid_request_error', 'Request body must be valid JSON.')
    if not isinstance(request_body, dict):
        return _error_response(400, 'invalid_request_error', 'Request body must be a JSON object.')

    # The model is taken from the URL (see construct_api_url), and Vertex gets
    # its own anthropic_version, so the client-supplied values are dropped.
    request_body.pop('anthropic_version', None)
    request_body.pop('model', None)
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    api_url = construct_api_url(get_location())
    try:
        access_token = await get_access_token()
        headers = {
            'Authorization': f'Bearer {access_token}',
            'Content-Type': 'application/json; charset=utf-8',
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(api_url, json=request_body, headers=headers) as upstream:
                response_body = await upstream.read()
                response_status = upstream.status
                # copy() yields a mutable CIMultiDict, so repeated headers are preserved.
                response_headers = upstream.headers.copy()
    except aiohttp.ClientError:
        return _error_response(502, 'api_error', 'Upstream request to Vertex AI failed.')

    for name in _UPSTREAM_HEADERS_TO_DROP:
        response_headers.popall(name, None)

    return _add_cors_headers(web.Response(
        body=response_body,
        status=response_status,
        headers=response_headers,
    ))
def handle_options():
    """Reply to a CORS preflight (OPTIONS) request.

    Returns an empty ``204 No Content`` response carrying the same
    Allow-Origin / Allow-Methods / Allow-Headers values that handle_request
    attaches to proxied responses.
    """
    preflight_headers = {}
    preflight_headers['Access-Control-Allow-Origin'] = '*'
    preflight_headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    preflight_headers['Access-Control-Allow-Headers'] = (
        'Content-Type, Authorization, x-api-key, anthropic-version, model'
    )
    return web.Response(headers=preflight_headers, status=204)
# aiohttp application: every HTTP method on "/" is routed to handle_request,
# which also answers OPTIONS preflights itself.
app = web.Application()
app.router.add_route('*', '/', handle_request)

if __name__ == '__main__':
    # Listen on $PORT when the hosting platform provides one; otherwise use 8080 as before.
    web.run_app(app, port=int(os.getenv('PORT', '8080')))