import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from pipelines.detection.yolo_v8 import Yolov8Pipeline
from pipelines.detection.yolo_stamp import YoloStampPipeline
from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline
from pipelines.feature_extraction.vae import VaePipeline
from pipelines.feature_extraction.vits8 import Vits8Pipeline

from utils import visualize_bbox, heatmap, annotate_heatmap


# Load every pipeline once at import time so each request reuses the cached weights.
yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt')
yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', 'weights.pt')
vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')
dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', 'weights.pt')
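
# Each pipeline object is a plain callable over a PIL image. A quick standalone
# sanity check (hedged sketch, assuming the example images shipped with this
# repo) could look like:
#
#     img = Image.open('examples/1.jpg').convert('RGB')
#     boxes = yolo_stamp(img)                    # (N, 4) boxes: x1, y1, x2, y2
#     emb = vits8(img.crop(boxes[0].tolist()))   # 1-D embedding for one stamp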


def doc_predict(image, det_choice, seg_choice, emb_choice):
    """Detect stamps on a document, optionally segment them, embed each stamp,
    and plot the pairwise cosine similarities between the embeddings."""
    image = image.convert('RGB')

    # Detect stamp bounding boxes with the selected model.
    if det_choice == 'yolov8':
        boxes = yolov8(image)
    elif det_choice == 'yolo-stamp':
        boxes = yolo_stamp(image)
    else:
        return
    if len(boxes) == 0:
        return
    image_with_boxes = visualize_bbox(image, boxes)

    # Crop each detection; optionally clean it up with the segmentation model.
    segmented_stamps = []
    for box in boxes:
        cropped_stamp = image.crop(box.tolist())
        segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)

    # Paste all crops side by side into a single preview image.
    widths, heights = zip(*(stamp.size for stamp in segmented_stamps))
    concatenated_stamps = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for stamp in segmented_stamps:
        concatenated_stamps.paste(stamp, (x_offset, 0))
        x_offset += stamp.size[0]

    # Embed every stamp with the selected feature extractor.
    if emb_choice == 'vits8':
        embeddings = [vits8(stamp) for stamp in segmented_stamps]
    elif emb_choice == 'vae-encoder':
        embeddings = [vae(stamp) for stamp in segmented_stamps]
    else:
        return
    embeddings = np.stack(embeddings)

    similarities = cosine_similarity(embeddings)

    df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])

    # Render the similarity matrix as an annotated heatmap.
    fig, ax = plt.subplots()
    im, _ = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1),
                    ax=ax, cmap="YlGn", cbarlabel="Embeddings similarities")
    annotate_heatmap(im, valfmt="{x:.3f}")
    return image_with_boxes, df_boxes, concatenated_stamps, embeddings, fig
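
# For reference: cosine_similarity(embeddings) above returns a symmetric
# (N, N) matrix whose (i, j) entry is dot(e_i, e_j) / (||e_i|| * ||e_j||),
# so the diagonal is always 1.0. An equivalent hand-rolled version (sketch):
#
#     unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
#     similarities = unit @ unit.T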


doc_examples = [
    ['examples/1.jpg', 'yolov8', True, 'vits8'],
    ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'],
    ['examples/3.jpg', 'yolov8', True, 'vits8'],
]
doc_inputs = [
    gr.Image(label="Document image", type="pil"),
    gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
    gr.Checkbox(label="Use segmentation model"),
    gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
]
doc_outputs = [
    gr.Image(label="Document with bounding boxes", type="pil"),
    gr.DataFrame(type='pandas', label="Bounding boxes"),
    gr.Image(label="Segmented stamps", type="pil"),
    gr.DataFrame(type='numpy', label="Embeddings"),
    gr.Plot(label="Cosine Similarities")
]
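
# gr.Interface wires the four inputs above to doc_predict's positional
# arguments and the five outputs to its return tuple, in order.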

with gr.Blocks() as demo:
    with gr.Tab("Signle document"): 
        gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)

demo.launch(inline=False)