# vertex2api / app.py — Hugging Face Space by smgc
# Last commit: f21f2de "Update app.py" (verified); malware scan: no virus; size: 4.34 kB
import asyncio
import hmac
import json
import os
import time

import aiohttp
from aiohttp import web
# Configuration, read once from the environment at import time.
PROJECT_ID = os.getenv('PROJECT_ID')        # Google Cloud project used in the Vertex AI URL
CLIENT_ID = os.getenv('CLIENT_ID')          # OAuth2 client credentials that, together with
CLIENT_SECRET = os.getenv('CLIENT_SECRET')  # REFRESH_TOKEN, are exchanged at TOKEN_URL
REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')  # for short-lived access tokens
API_KEY = os.getenv('API_KEY')              # shared secret clients must send as `x-api-key`
# Vertex AI model ID interpolated into the endpoint URL by construct_api_url.
# It must be defined here, otherwise building the URL raises NameError.
# Override it with the MODEL environment variable; an empty value falls back
# to the default.
MODEL = os.getenv('MODEL') or 'claude-3-5-sonnet@20240620'
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Process-wide OAuth2 access-token cache shared by all requests.
token_cache = {
    'access_token': '',       # last token obtained ('' until the first refresh)
    'expiry': 0,              # absolute expiry, as a time.time() timestamp
    'refresh_promise': None,  # awaitable for an in-flight refresh, else None
}
async def get_access_token():
    """Return a valid OAuth2 access token, refreshing it when needed.

    The cached token is reused until 120 seconds before its recorded expiry.
    Callers that find it stale share a single refresh. The first caller
    starts an asyncio Task and stores it in ``token_cache['refresh_promise']``,
    and every concurrent caller awaits that same Task.

    Returns:
        The token stored in ``token_cache['access_token']``.

    Raises:
        RuntimeError: the token endpoint did not return an access token.
        ValueError: the token endpoint's response body was not JSON.
        aiohttp.ClientError: network failure while contacting TOKEN_URL.
    """
    if token_cache['access_token'] and time.time() < token_cache['expiry'] - 120:
        return token_cache['access_token']

    refresh_task = token_cache['refresh_promise']
    if refresh_task is None or refresh_task.done():
        # A Task, unlike a bare coroutine, may be awaited by any number of
        # callers, so concurrent requests share one token round-trip.
        refresh_task = asyncio.ensure_future(_refresh_access_token())
        token_cache['refresh_promise'] = refresh_task
    # shield(): one cancelled caller (e.g. a disconnected client) must not
    # cancel the refresh the other callers are waiting on.
    await asyncio.shield(refresh_task)
    return token_cache['access_token']


async def _refresh_access_token():
    """Exchange REFRESH_TOKEN for a new access token and store it in token_cache."""
    requested_at = time.time()
    try:
        async with aiohttp.ClientSession() as session:
            # RFC 6749 section 6: refresh-token requests are form-encoded.
            async with session.post(TOKEN_URL, data={
                'client_id': CLIENT_ID,
                'client_secret': CLIENT_SECRET,
                'refresh_token': REFRESH_TOKEN,
                'grant_type': 'refresh_token',
            }) as response:
                status = response.status
                payload = await response.json(content_type=None)
        if status != 200 or not isinstance(payload, dict) or 'access_token' not in payload:
            error = payload.get('error') if isinstance(payload, dict) else None
            raise RuntimeError(
                f'OAuth token refresh failed: HTTP {status}, error={error!r}')
        token_cache['access_token'] = payload['access_token']
        # Measured from before the request, so the expiry errs on the early side.
        token_cache['expiry'] = requested_at + payload['expires_in']
    finally:
        # Always clear the marker, even on failure, so the next caller retries
        # instead of awaiting a dead refresh forever.
        token_cache['refresh_promise'] = None
def get_location():
    """Choose the Vertex AI region for this request from the local clock.

    Requests made in the first half of a minute (seconds 0-29) go to
    'europe-west1'. All others go to 'us-east5'. This roughly splits
    traffic between the two regions.
    """
    second_of_minute = time.localtime().tm_sec
    if second_of_minute >= 30:
        return 'us-east5'
    return 'europe-west1'
def construct_api_url(location):
    """Return the Vertex AI ``streamRawPredict`` URL for the Anthropic MODEL in *location*.

    *location* selects both the regional hostname and the ``locations/``
    path segment. PROJECT_ID and MODEL come from module configuration.
    """
    host = f'https://{location}-aiplatform.googleapis.com'
    model_path = (
        f'projects/{PROJECT_ID}/locations/{location}'
        f'/publishers/anthropic/models/{MODEL}'
    )
    return f'{host}/v1/{model_path}:streamRawPredict'
# Upstream response headers that describe the upstream connection or body
# encoding rather than the payload itself. aiohttp transparently decompresses
# the upstream body and frames the downstream response on its own, so
# forwarding these would mislabel what the client actually receives.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'transfer-encoding', 'upgrade',
    'content-length', 'content-encoding',
})


def _apply_cors(response):
    """Set the permissive CORS headers this proxy sends on every response; return *response*."""
    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
    return response


def _json_error(status, error_type, message):
    """Build an Anthropic-style JSON error response, with CORS headers."""
    return _apply_cors(web.Response(
        text=json.dumps({
            'type': 'error',
            'error': {
                'type': error_type,
                'message': message,
            },
        }),
        status=status,
        content_type='application/json',
    ))


def _api_key_is_valid(provided):
    """Return True iff *provided* matches API_KEY.

    Fails closed: if API_KEY is not configured, every key is rejected.
    Without this check, a missing header (None) would match an unset
    API_KEY (None) and let any caller through. The comparison runs in
    constant time.
    """
    if not API_KEY or provided is None:
        return False
    return hmac.compare_digest(
        provided.encode('utf-8', 'surrogateescape'),
        API_KEY.encode('utf-8', 'surrogateescape'),
    )


async def handle_request(request):
    """Proxy an Anthropic Messages API request to Claude on Vertex AI.

    OPTIONS requests are answered by ``handle_options``. Every other request
    must carry an ``x-api-key`` header equal to API_KEY; otherwise it gets a
    403 ``permission_error``.

    The JSON body must be an object, otherwise the request gets a 400
    ``invalid_request_error``. Its ``model`` field is dropped and
    ``anthropic_version`` is forced to ``'vertex-2023-10-16'``. The body is
    then POSTed to the ``streamRawPredict`` endpoint for the region picked by
    ``get_location``.

    The upstream status, headers (minus hop-by-hop/encoding headers) and body
    are relayed chunk by chunk as they arrive, so server-sent-event streams
    reach the client incrementally instead of after the whole reply.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    if not _api_key_is_valid(request.headers.get('x-api-key')):
        return _json_error(
            403, 'permission_error',
            'Your API key does not have permission to use the specified resource.')

    try:
        request_body = await request.json()
    except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
        return _json_error(400, 'invalid_request_error', 'Request body must be valid JSON.')
    if not isinstance(request_body, dict):
        return _json_error(400, 'invalid_request_error', 'Request body must be a JSON object.')

    # Vertex selects the model via the URL and requires its own version tag.
    request_body.pop('anthropic_version', None)
    request_body.pop('model', None)
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    access_token = await get_access_token()
    api_url = construct_api_url(get_location())
    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json; charset=utf-8',
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(api_url, json=request_body, headers=headers) as upstream:
            proxied = web.StreamResponse(
                status=upstream.status,
                headers=[
                    (name, value) for name, value in upstream.headers.items()
                    if name.lower() not in _HOP_BY_HOP_HEADERS
                ],
            )
            _apply_cors(proxied)
            await proxied.prepare(request)
            async for chunk in upstream.content.iter_any():
                await proxied.write(chunk)
            await proxied.write_eof()
    return proxied
def handle_options():
    """Answer a CORS preflight request with ``204 No Content``.

    Allows any origin, the POST/GET/OPTIONS methods, and the request headers
    this proxy's clients send (including ``x-api-key`` and
    ``anthropic-version``).
    """
    preflight_headers = dict([
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'POST, GET, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, x-api-key, anthropic-version, model'),
    ])
    return web.Response(headers=preflight_headers, status=204)
# A single catch-all route: every HTTP method on "/" is handled by
# handle_request, which itself dispatches OPTIONS preflights.
app = web.Application()
app.add_routes([web.route('*', '/', handle_request)])

if __name__ == '__main__':
    # Serve the application on port 8080.
    web.run_app(app, port=8080)