stamp2vec / app.py
sadjava's picture
changed to pipelines
fd52b7f
raw
history blame
No virus
3.45 kB
import gradio as gr
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib
from pipelines.detection.yolo_v8 import Yolov8Pipeline
from pipelines.detection.yolo_stamp import YoloStampPipeline
from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline
from pipelines.feature_extraction.vae import VaePipeline
from pipelines.feature_extraction.vits8 import Vits8Pipeline
from utils import *
# Instantiate every pipeline once, at import time, so that all requests served
# by doc_predict reuse the same loaded models.
# Detection pipeline (pipelines.detection) loaded from a local checkpoint path;
# NOTE(review): the path is relative, so it presumably resolves against the
# process working directory — confirm for the deployment.
yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt')
# The remaining pipelines are given (repo, filename) pairs — presumably Hugging
# Face Hub repo ids plus the weights file inside each repo; verify against the
# pipelines package.
# Alternative detection pipeline.
yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
# Feature-extraction (embedding) pipelines.
vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', 'weights.pt')
vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')
# Segmentation pipeline, applied to cropped stamps when the user enables it.
dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', 'weights.pt')
def doc_predict(image, det_choice, seg_choice, emb_choice):
    """Detect, optionally segment, and embed the stamps in a document image.

    Args:
        image: Document image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submitted without uploading one.
        det_choice: Detection model name, ``'yolov8'`` or ``'yolo-stamp'``.
        seg_choice: When truthy, each cropped stamp is passed through the
            DeepLabv3 pipeline before being embedded.
        emb_choice: Embedding model name, ``'vits8'`` or ``'vae-encoder'``.

    Returns:
        A 5-tuple in the order of the interface outputs: the document with
        bounding boxes drawn, a DataFrame of box coordinates
        (``x1, y1, x2, y2``), the (segmented) stamps pasted side by side, the
        stacked embeddings array, and a cosine-similarity heatmap figure.
        When no stamp is detected, the document image and an empty box table
        are returned and the last three items are ``None``.

    Raises:
        ValueError: if no image is given or a model name is not recognised.
    """
    if image is None:
        raise ValueError("Please provide a document image.")

    detectors = {'yolov8': yolov8, 'yolo-stamp': yolo_stamp}
    embedders = {'vits8': vits8, 'vae-encoder': vae}
    # Validate both choices before doing any model work. Previously an unknown
    # detector returned a bare None (not five outputs) and an unknown embedder
    # only failed later inside np.stack with an unrelated error.
    if det_choice not in detectors:
        raise ValueError(f"Unknown detection model: {det_choice!r}")
    if emb_choice not in embedders:
        raise ValueError(f"Unknown embedding model: {emb_choice!r}")

    image = image.convert('RGB')
    boxes = detectors[det_choice](image)

    # len() (not truthiness) so this works whether boxes is a list, an array
    # or a tensor. With no detections, zip(*...) and np.stack below would
    # raise, so return an empty result set instead.
    if len(boxes) == 0:
        empty_boxes = pd.DataFrame(columns=['x1', 'y1', 'x2', 'y2'])
        return image, empty_boxes, None, None, None

    image_with_boxes = visualize_bbox(image, boxes)

    stamps = []
    for box in boxes:
        cropped_stamp = image.crop(box.tolist())
        # Assumes dlv3 returns a PIL image: its .size is used and it is pasted below.
        stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)

    # Paste all stamps left to right on one canvas as tall as the tallest stamp.
    total_width = sum(stamp.size[0] for stamp in stamps)
    max_height = max(stamp.size[1] for stamp in stamps)
    concatenated_stamps = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for stamp in stamps:
        concatenated_stamps.paste(stamp, (x_offset, 0))
        x_offset += stamp.size[0]

    embed = embedders[emb_choice]
    embeddings = np.stack([embed(stamp) for stamp in stamps])
    similarities = cosine_similarity(embeddings)

    df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])

    # Stamps are labelled 1..n on both heatmap axes.
    labels = range(1, len(embeddings) + 1)
    fig, ax = plt.subplots()
    im, _cbar = heatmap(similarities, labels, labels, ax=ax,
                        cmap="YlGn", cbarlabel="Embeddings similarities")
    annotate_heatmap(im, valfmt="{x:.3f}")
    # Drop the figure from pyplot's global registry. Otherwise every request
    # leaks one figure in this long-running server (pyplot warns after 20).
    # The Figure object itself stays renderable for gr.Plot.
    plt.close(fig)

    return image_with_boxes, df_boxes, concatenated_stamps, embeddings, fig
# Example rows passed to gr.Interface as `examples`. Each row gives one value
# per entry of doc_inputs, in order: image path, detector, segmentation flag,
# embedder.
doc_examples = [
    ['examples/1.jpg', 'yolov8', True, 'vits8'],
    ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'],
    ['examples/3.jpg', 'yolov8', True, 'vits8'],
]

# Input widgets, ordered like the parameters of
# doc_predict(image, det_choice, seg_choice, emb_choice).
doc_inputs = [
    gr.Image(type="pil", label="Document image"),
    gr.Dropdown(label='Detection model', value='yolov8',
                choices=['yolov8', 'yolo-stamp']),
    gr.Checkbox(label="Use segmentation model"),
    gr.Dropdown(label='Embedding model', value='vits8',
                choices=['vits8', 'vae-encoder']),
]

# Output widgets, ordered like the 5-tuple returned by doc_predict.
doc_outputs = [
    gr.Image(type="pil", label="Document with bounding boxes"),
    gr.DataFrame(label="Bounding boxes", type='pandas'),
    gr.Image(type="pil", label="Segmented stamps"),
    gr.DataFrame(label="Embeddings", type='numpy'),
    gr.Plot(label="Cosine Similarities"),
]
# Assemble the UI: one tab wrapping the doc_predict interface, then start the
# server. inline=False asks Gradio not to display the app inline (relevant in
# notebook environments).
with gr.Blocks() as demo:
    # Fixed user-facing typo in the tab label ("Signle" -> "Single").
    with gr.Tab("Single document"):
        gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)
demo.launch(inline=False)