from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import cv2
import numpy as np
from base64 import b64decode
from sklearn.preprocessing import Normalizer
from sklearn.metrics.pairwise import cosine_similarity
import torch  # PyTorch backbone for the face embeddings
from torchvision import transforms
from PIL import Image

app = FastAPI()


class ImageData(BaseModel):
    image1: str
    image2: str


# Load the embedding backbone. Note: this is a generic ImageNet-pretrained
# ResNet-101 standing in for a true ArcFace network; for reliable face
# verification, swap in a model actually trained with an ArcFace loss
# (e.g. from the insightface project).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
# Remove the last layer (classification head) so the network outputs a feature vector
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.to(device)
model.eval()  # Set to evaluation mode


def preprocess_image(image):
    """Preprocesses a BGR face crop for the embedding model."""
    transform = transforms.Compose([
        transforms.Resize((112, 112)),  # Typical ArcFace input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    return transform(pil_image).unsqueeze(0).to(device)


def get_face_embedding(face):
    """Generates a face embedding from a BGR face crop."""
    with torch.no_grad():
        face_tensor = preprocess_image(face)
        embedding = model(face_tensor)
    return embedding.cpu().numpy().flatten()


# Load the face detection model (Haar cascade bundled with OpenCV)
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')


def extract_faces(image):
    """
    Detects and extracts faces from an image.

    Args:
        image: The input BGR image.

    Returns:
        A list of tuples, where each tuple contains:
        - The extracted face image as a NumPy array.
        - The bounding box coordinates of the face (x, y, w, h).
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
    face_list = []
    for (x, y, w, h) in faces:
        face_img = image[y:y+h, x:x+w]
        face_list.append((face_img, (x, y, w, h)))
    return face_list


@app.post("/verify")
async def verify(data: ImageData):
    try:
        # Strip the data-URL prefix (e.g. "data:image/jpeg;base64,") if present;
        # split(",")[-1] also accepts plain Base64 strings without a prefix
        image1 = data.image1.split(",")[-1]
        image2 = data.image2.split(",")[-1]

        # Decode the Base64 strings into OpenCV BGR images; IMREAD_COLOR
        # guarantees 3 channels even for PNGs with an alpha channel
        id_document_image = cv2.imdecode(np.frombuffer(b64decode(image1), np.uint8), cv2.IMREAD_COLOR)
        live_image = cv2.imdecode(np.frombuffer(b64decode(image2), np.uint8), cv2.IMREAD_COLOR)
        if id_document_image is None or live_image is None:
            raise HTTPException(status_code=400, detail="Could not decode one of the images.")

        # Extract the face from the ID document
        id_faces = extract_faces(id_document_image)
        if len(id_faces) == 0:
            return {"message": "No face detected in the ID document."}
        id_face, _ = id_faces[0]
        id_face_embedding = get_face_embedding(id_face)

        # Detect the face in the live image, reusing the same detector settings
        live_faces = extract_faces(live_image)
        if len(live_faces) == 0:
            return {"message": "No face detected in the live image."}

        # Process only the first detected face in the live image
        live_face, _ = live_faces[0]
        live_face_embedding = get_face_embedding(live_face)

        # L2-normalize the embeddings (Normalizer is stateless, so no fit is
        # needed; cosine similarity normalizes internally anyway, but the
        # explicit step keeps the intent clear)
        normalizer = Normalizer()
        id_face_embedding_normalized = normalizer.transform(id_face_embedding.reshape(1, -1))
        live_face_embedding_normalized = normalizer.transform(live_face_embedding.reshape(1, -1))

        # Compute cosine similarity between the ID face and the live face
        similarity = cosine_similarity(id_face_embedding_normalized, live_face_embedding_normalized)

        # Convert the similarity to a standard Python float for JSON serialization
        similarity_float = float(similarity[0][0])

        if similarity_float > 0.6:  # Threshold may need tuning for your data
            return {"message": "Face matched!", "similarity": similarity_float}
        else:
            return {"message": "Face did not match.", "similarity": similarity_float}
    except HTTPException as e:
        raise e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
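

# ----------------------------------------------------------------------------
# Example client call: a minimal sketch of how the /verify endpoint might be
# exercised, not part of the server itself. It assumes the server is running
# locally (e.g. `uvicorn main:app --port 8000`), that the `requests` package
# is installed, and that "id_card.jpg" and "selfie.jpg" exist on disk; the
# module name, port, and file names are illustrative assumptions only.
# Running `python main.py` executes this client, while `uvicorn main:app`
# imports the module and skips it.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import base64
    import requests

    def _encode(path: str) -> str:
        # Build a data URL, matching the format the /verify endpoint accepts
        with open(path, "rb") as f:
            return "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()

    response = requests.post(
        "http://127.0.0.1:8000/verify",
        json={"image1": _encode("id_card.jpg"), "image2": _encode("selfie.jpg")},
    )
    print(response.json())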