vision / app.py
Edmond7's picture
Update app.py
aa732b3 verified
raw
history blame contribute delete
No virus
2.4 kB
import os
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.security.api_key import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import base64
import io
# Application instance that the route decorators further down attach to.
app = FastAPI()

# Load the processor and model.
# Both loaders take the same checkpoint id and the same options, so each is
# written once and shared between the two calls.
_MODEL_ID = 'allenai/Molmo-7B-D-0924'
_LOAD_OPTIONS = {
    'trust_remote_code': True,
    'torch_dtype': 'auto',
    'device_map': 'auto',
}
processor = AutoProcessor.from_pretrained(_MODEL_ID, **_LOAD_OPTIONS)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, **_LOAD_OPTIONS)
# API key setup.
# The expected key comes from the API_KEY environment variable (None when the
# variable is not set); clients present theirs in the X-API-Key header.
API_KEY = os.getenv("API_KEY")
# NOTE(review): auto_error=False presumably makes a missing header arrive in
# get_api_key as None instead of being rejected by the scheme itself - confirm
# against the FastAPI security reference.
api_key_header = APIKeyHeader(auto_error=False, name="X-API-Key")
async def get_api_key(api_key_header: str = Depends(api_key_header)):
    """Validate the API key supplied in the X-API-Key request header.

    Args:
        api_key_header: Header value resolved by the ``api_key_header``
            security dependency; may be ``None`` when the client sent no
            header.

    Returns:
        The supplied key when it matches the configured ``API_KEY``.

    Raises:
        HTTPException: 403 when no key is configured on the server, when the
            client sent no key, or when the two keys differ.
    """
    # Function-scope import keeps this block self-contained; the module's
    # import section does not provide ``secrets``.
    import secrets

    # Fail closed: with a plain ``==`` an unset API_KEY (None) compared equal
    # to a missing header (None), so unauthenticated requests were accepted
    # whenever the environment variable was absent.
    if API_KEY and api_key_header:
        # Compare as bytes in constant time: compare_digest rejects non-ASCII
        # str arguments, and a timing-safe comparison avoids leaking how many
        # leading characters of the key were correct.
        supplied = api_key_header.encode("utf-8")
        expected = API_KEY.encode("utf-8")
        if secrets.compare_digest(supplied, expected):
            return api_key_header

    raise HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )
def process_image_and_text(image, text):
    """Run the model on one image plus one text prompt and return its answer.

    Args:
        image: The picture to feed to the processor; the endpoints in this
            file pass a PIL image.
        text: Prompt text handed to the processor together with the image.

    Returns:
        The decoded text of the tokens generated after the prompt, with
        special tokens skipped.
    """
    # The Molmo model card asks for RGB input and recommends
    # ``image.convert("RGB")`` for anything else (palette, grayscale, RGBA,
    # CMYK, ...). Uploaded files arrive in whatever mode PIL opened them in.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor.process(
        images=[image],
        text=text
    )
    # Move every processor output to the model's device and add a leading
    # dimension via unsqueeze(0) before generation.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer
    )
    # Drop the prompt: keep only positions after the input_ids length.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return generated_text
# Request body accepted by the POST /base64 endpoint.
class Base64Request(BaseModel):
    # Image bytes encoded as a base64 string; the handler decodes it with
    # base64.b64decode and opens the result with PIL.
    image: str
    # Prompt text forwarded to process_image_and_text together with the image.
    text: str
@app.post("/upload")
async def upload_image(file: UploadFile = File(...), text: str = "", api_key: str = Depends(get_api_key)):
    """Answer a text prompt about an uploaded image file.

    Args:
        file: The uploaded image.
        text: Prompt to send to the model with the image; defaults to "".
        api_key: Result of the ``get_api_key`` dependency; unused in the body,
            it is declared so the request is authenticated first.

    Returns:
        A dict with the generated text under the "response" key.

    Raises:
        HTTPException: 400 when the upload cannot be opened or decoded as an
            image (previously this surfaced as an unhandled 500, unlike the
            /base64 endpoint).
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy; load() forces decoding here so truncated or
        # corrupt pixel data is reported as a client error too.
        image.load()
    except Exception as exc:
        # Pillow decoders raise several unrelated exception types on malformed
        # data; ``Exception`` covers them while still letting task
        # cancellation and interpreter-exit signals propagate.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    # NOTE: process_image_and_text is a synchronous call inside this coroutine.
    response = process_image_and_text(image, text)
    return {"response": response}
@app.post("/base64")
async def process_base64(request: Base64Request, api_key: str = Depends(get_api_key)):
    """Answer a text prompt about an image sent as a base64 string.

    Args:
        request: JSON body carrying the base64-encoded ``image`` and the
            ``text`` prompt.
        api_key: Result of the ``get_api_key`` dependency; unused in the body,
            it is declared so the request is authenticated first.

    Returns:
        A dict with the generated text under the "response" key.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or does not
            decode to a readable image.
    """
    try:
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy; load() forces decoding here so truncated or
        # corrupt pixel data is reported as a client error too.
        image.load()
    except Exception as exc:
        # A bare ``except:`` also swallowed BaseException subclasses such as
        # task cancellation, KeyboardInterrupt and SystemExit. ``Exception``
        # still covers the varied errors raised by base64 and Pillow, and the
        # original cause is kept on the chained exception.
        raise HTTPException(status_code=400, detail="Invalid base64 image") from exc
    # NOTE: process_image_and_text is a synchronous call inside this coroutine.
    response = process_image_and_text(image, request.text)
    return {"response": response}