File size: 2,404 Bytes
aa732b3
 
 
 
 
259d504
 
 
aa732b3
 
 
 
259d504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa732b3
 
 
 
 
 
 
 
 
 
 
53afe5a
259d504
 
aa732b3
259d504
 
 
 
 
 
 
 
 
 
 
 
aa732b3
 
 
259d504
aa732b3
 
 
 
259d504
aa732b3
259d504
aa732b3
 
 
 
 
 
 
259d504
aa732b3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import base64
import io
import os
import secrets

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.security.api_key import APIKeyHeader
from PIL import Image
from pydantic import BaseModel
from starlette.status import HTTP_403_FORBIDDEN
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

app = FastAPI()

# Hugging Face checkpoint served by this API.
_MODEL_ID = 'allenai/Molmo-7B-D-0924'
# Loading options shared by the processor and the model:
# trust_remote_code lets transformers run the checkpoint's own code from the
# Hub; dtype and device placement are left to transformers ('auto').
_LOAD_OPTIONS = dict(trust_remote_code=True, torch_dtype='auto', device_map='auto')

# Load the processor and model once, at import time; the request handlers
# below reuse these module-level objects for every call.
processor = AutoProcessor.from_pretrained(_MODEL_ID, **_LOAD_OPTIONS)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, **_LOAD_OPTIONS)

# API key setup. The expected key is read from the API_KEY environment
# variable; clients send theirs in the "X-API-Key" request header.
# NOTE(review): if API_KEY is unset or empty, every request is rejected with 403.
API_KEY = os.environ.get("API_KEY")
# auto_error=False: a missing header is passed to get_api_key as None instead
# of FastAPI raising its own error, so get_api_key decides how to reject it.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

async def get_api_key(api_key_header: str = Depends(api_key_header)):
    """FastAPI dependency that authenticates a request by its X-API-Key header.

    Args:
        api_key_header: Value of the ``X-API-Key`` header, or ``None`` when the
            header is absent.

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 if the header is missing, the server has no
            ``API_KEY`` configured, or the supplied key does not match.
    """
    # Fail closed. Without the truthiness guards, a missing header (None)
    # would compare equal to an unset API_KEY (None) and authenticate.
    # compare_digest runs in constant time, so response timing does not leak
    # how many leading characters of a guess were correct. Encoding to bytes
    # avoids compare_digest's TypeError on non-ASCII str input.
    if API_KEY and api_key_header and secrets.compare_digest(
        api_key_header.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        return api_key_header
    raise HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )

def process_image_and_text(image, text, *, max_new_tokens=200):
    """Run the Molmo model on one image plus a text prompt.

    Args:
        image: A PIL image. Non-RGB images (palette, RGBA, grayscale) are
            converted to RGB first.
        text: The prompt sent to the model together with the image.
        max_new_tokens: Upper bound on the number of tokens to generate.
            Defaults to 200, the value previously hard-coded here.

    Returns:
        The generated completion as a string: prompt tokens are excluded and
        special tokens are stripped.
    """
    # The image preprocessing presumably expects 3-channel input; normalize
    # any other PIL mode so alpha/palette/grayscale images do not fail.
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor.process(images=[image], text=text)
    # Move every tensor to the model's device and add a leading batch
    # dimension of size 1 (a single example is generated at a time).
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    # Inference only: make sure autograd does not record the forward passes.
    with torch.no_grad():
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(
                max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"
            ),
            tokenizer=processor.tokenizer,
        )
    # The returned sequence begins with the prompt; keep only the new tokens.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

# JSON request body for POST /base64.
# Documented with comments rather than a docstring on purpose: pydantic would
# copy a class docstring into the generated OpenAPI schema.
class Base64Request(BaseModel):
    # Base64-encoded bytes of the image file; decoded with base64.b64decode.
    image: str
    # Prompt passed to the model together with the decoded image.
    text: str

def _load_image(raw):
    """Decode raw image-file bytes into a fully loaded RGB PIL image.

    ``Image.open`` is lazy and only parses the header. ``convert("RGB")``
    forces the full decode here, so corrupt or truncated data fails inside
    the caller's ``try``, not later inside the model call. It also normalizes
    palette/RGBA/grayscale input to 3-channel RGB.

    Raises:
        Whatever Pillow raises for undecodable data (e.g.
        ``PIL.UnidentifiedImageError``/``OSError``, ``ValueError``,
        ``Image.DecompressionBombError``); callers map this to HTTP 400.
    """
    with Image.open(io.BytesIO(raw)) as img:
        return img.convert("RGB")

@app.post("/upload")
async def upload_image(file: UploadFile = File(...), text: str = "", api_key: str = Depends(get_api_key)):
    """Answer ``text`` about an image sent as a multipart file upload.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 if the upload is not a decodable image (previously
            this surfaced as an unhandled 500).
    """
    contents = await file.read()
    try:
        image = _load_image(contents)
    # Pillow's format plugins raise a variety of exception types for
    # malformed input; at this request boundary, all of them mean
    # "bad client data".
    except Exception as err:
        raise HTTPException(status_code=400, detail="Invalid image file") from err
    # NOTE(review): inference is synchronous and runs on the event loop, so
    # it blocks other requests until it finishes (which also serializes model
    # access). Moving it to a thread pool would need explicit serialization.
    response = process_image_and_text(image, text)
    return {"response": response}

@app.post("/base64")
async def process_base64(request: Base64Request, api_key: str = Depends(get_api_key)):
    """Answer ``request.text`` about a base64-encoded image.

    ``request.image`` may be bare base64 or a data URL
    (``data:image/png;base64,...``); the data-URL prefix is stripped.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or does not
            decode to an image.
    """
    payload = request.image
    if payload.startswith("data:"):
        # Keep only the base64 part after the comma; a prefix without a comma
        # leaves an empty payload, which fails decoding below (400).
        payload = payload.partition(",")[2]
    try:
        image = _load_image(base64.b64decode(payload))
    # binascii.Error (a ValueError) from b64decode or any Pillow decode
    # error; unlike the previous bare except, BaseException subclasses such
    # as KeyboardInterrupt/SystemExit are no longer swallowed.
    except Exception as err:
        raise HTTPException(status_code=400, detail="Invalid base64 image") from err

    response = process_image_and_text(image, request.text)
    return {"response": response}