File size: 3,692 Bytes
479c88d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05aeeac
479c88d
 
 
 
 
 
 
85b4b37
479c88d
 
 
 
05aeeac
479c88d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6350200
479c88d
6350200
479c88d
 
 
 
 
 
 
 
 
 
 
 
 
2fa1032
479c88d
 
 
 
 
 
 
 
 
b16ff26
479c88d
 
a5018b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import numpy as np
from ultralytics import YOLO
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download
import torch
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from utils import *
from models import YOLOStamp, Encoder

# Run inference on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Fine-tuned YOLOv8 stamp detector: a TorchScript export fetched from the
# Hugging Face Hub and wrapped by ultralytics' YOLO class as a detection model.
yolov8 = YOLO(hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript'), task='detect')

# Custom YOLOStamp detector (defined in the local `models` module). The state
# dict is loaded onto the CPU first so the checkpoint opens on CPU-only hosts,
# then the model is moved to `device` and put in inference (eval) mode.
# NOTE(review): torch.load unpickles the downloaded file; only load checkpoints
# from trusted repos (newer torch versions also accept weights_only=True).
yolo_stamp = YOLOStamp()
yolo_stamp.load_state_dict(torch.load(hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth'), map_location='cpu'))
yolo_stamp = yolo_stamp.to(device)
yolo_stamp.eval()
# Preprocessing applied to images fed to `yolo_stamp`: normalize pixel values
# with albumentations' default statistics, then convert the HWC numpy array
# into a CHW torch tensor.
transform = A.Compose([
    A.Normalize(),
    ToTensorV2(p=1.0),
])

# Stamp embedding model, loaded from a TorchScript file (presumably a ViT-S/8
# backbone, judging by the repo name — confirm). Mapped to CPU, then moved to
# `device` and switched to eval mode.
vits8 = torch.jit.load(hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth'), map_location='cpu')
vits8 = vits8.to(device)
vits8.eval()

# Alternative embedding model: the `Encoder` from the local `models` module,
# weights from the 'stamps-labs/vae-encoder' repo (presumably the encoder half
# of a VAE). Same CPU-load / move-to-device / eval pattern as above.
# NOTE(review): torch.load unpickles the downloaded file; see note above.
encoder = Encoder()
encoder.load_state_dict(torch.load(hf_hub_download('stamps-labs/vae-encoder', filename='encoder.pth'), map_location='cpu'))
encoder = encoder.to(device)
encoder.eval()


def _detect_stamps(image, det_choice):
    """Run the selected stamp detector on *image*.

    Args:
        image: RGB PIL image at its original size.
        det_choice: ``'yolov8'`` or ``'yolo-stamp'``.

    Returns:
        ``(resized, boxes, coef)``: the image resized to the detector's square
        input size, a CPU tensor of xyxy boxes (one row per detection) in
        ``resized`` pixel coordinates, and a 4-element tensor that rescales
        those x1/y1/x2/y2 coordinates back to the original image size.

    Raises:
        ValueError: if *det_choice* is not a known detector.
    """
    # PIL reports (width, height); repeating it yields one scale per x1/y1/x2/y2.
    shape = torch.tensor(image.size)

    if det_choice == 'yolov8':
        side = 640
        resized = image.resize((side, side))
        boxes = yolov8(resized)[0].boxes.xyxy.cpu()
    elif det_choice == 'yolo-stamp':
        side = 448
        resized = image.resize((side, side))
        image_tensor = transform(image=np.array(resized))['image']
        with torch.no_grad():  # inference only; no autograd graph needed
            output = yolo_stamp(image_tensor.unsqueeze(0).to(device))
        boxes = nonmax_suppression(output_tensor_to_boxes(output[0].cpu()))
        if len(boxes) == 0:
            # torch.tensor([])[:, :4] raises on a 1-D empty tensor, so build
            # an explicit zero-row box tensor instead.
            boxes = torch.empty((0, 4))
        else:
            boxes = xywh2xyxy(torch.tensor(boxes)[:, :4])
    else:
        raise ValueError(f"Unknown detection model: {det_choice!r}")

    coef = torch.hstack((shape, shape)) / side
    return resized, boxes, coef


def _embed_stamps(image, boxes, emb_choice):
    """Embed every boxed stamp of *image*, stacked along a new first axis.

    Args:
        image: the PIL image whose pixel coordinates *boxes* are expressed in.
        boxes: non-empty tensor of xyxy boxes, one row per stamp.
        emb_choice: ``'vits8'`` or ``'vae-encoder'``.

    Returns:
        A numpy array with one embedding per box, in box order.

    Raises:
        ValueError: if *emb_choice* is not a known embedding model.
    """
    crops = [image.crop(box.tolist()) for box in boxes]
    with torch.no_grad():  # inference only; no autograd graph needed
        if emb_choice == 'vits8':
            vectors = [vits8(to_tensor(crop).unsqueeze(0).to(device))[0] for crop in crops]
        elif emb_choice == 'vae-encoder':
            # The encoder is fed crops resized to a fixed 118x118.
            vectors = [
                encoder(to_tensor(crop.resize((118, 118))).unsqueeze(0).to(device))[0][0]
                for crop in crops
            ]
        else:
            raise ValueError(f"Unknown embedding model: {emb_choice!r}")
    return np.stack([vector.cpu().numpy() for vector in vectors])


def predict(image, det_choice, emb_choice):
    """Detect stamps in *image*, embed each one and compare the embeddings.

    Args:
        image: input PIL image (converted to RGB before processing).
        det_choice: detector name, ``'yolov8'`` or ``'yolo-stamp'``.
        emb_choice: embedding model name, ``'vits8'`` or ``'vae-encoder'``.

    Returns:
        A 4-tuple matching the Gradio outputs:
        the resized image with boxes drawn;
        a DataFrame of boxes (x1, y1, x2, y2) rescaled to the original image size;
        the stacked embeddings;
        a heatmap figure of pairwise cosine similarities.
        When no stamp is detected, the last two are ``None`` instead of
        crashing on an empty stack.

    Raises:
        ValueError: if either model name is unknown.
    """
    if emb_choice not in ('vits8', 'vae-encoder'):
        raise ValueError(f"Unknown embedding model: {emb_choice!r}")

    image = image.convert('RGB')
    resized, boxes, coef = _detect_stamps(image, det_choice)
    image_with_boxes = visualize_bbox(resized, boxes)

    # Convert explicitly so pandas gets a plain float array, independent of
    # how a given pandas version treats torch tensors.
    df_boxes = pd.DataFrame((boxes * coef).numpy(), columns=['x1', 'y1', 'x2', 'y2'])
    if len(boxes) == 0:
        # Nothing to embed or compare: clear the embedding and plot outputs.
        return image_with_boxes, df_boxes, None, None

    embeddings = _embed_stamps(resized, boxes, emb_choice)
    similarities = cosine_similarity(embeddings)

    # `plt`, `heatmap` and `annotate_heatmap` presumably come from
    # `from utils import *` — confirm.
    labels = range(1, len(embeddings) + 1)
    fig, ax = plt.subplots()
    im, _ = heatmap(similarities, labels, labels, ax=ax,
                    cmap="YlGn", cbarlabel="Embeddings similarities")
    annotate_heatmap(im, valfmt="{x:.3f}")
    return image_with_boxes, df_boxes, embeddings, fig


# Example rows shown under the interface: [image path, detector, embedder].
examples = [
    ['./examples/1.jpg', 'yolov8', 'vits8'],
    ['./examples/2.jpg', 'yolov8', 'vae-encoder'],
    ['./examples/3.jpg', 'yolo-stamp', 'vits8'],
]

# Input widgets, positionally matched to predict(image, det_choice, emb_choice).
inputs = [
    gr.Image(type="pil"),
    gr.Dropdown(
        label='Detection model',
        choices=['yolov8', 'yolo-stamp'],
        value='yolov8',
    ),
    gr.Dropdown(
        label='Embedding model',
        choices=['vits8', 'vae-encoder'],
        value='vits8',
    ),
]

# Output widgets, in the order of predict()'s 4-tuple return value.
outputs = [
    gr.Image(type="pil"),
    gr.DataFrame(label="Bounding boxes", type='pandas'),
    gr.DataFrame(label="Embeddings", type='numpy'),
    gr.Plot(label="Cosine Similarities"),
]

app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
)
app.launch()