"""Gradio demo: detect stamps on a document, optionally segment them, embed
each stamp, and show the pairwise cosine similarities of the embeddings."""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from pipelines.detection.yolo_v8 import Yolov8Pipeline
from pipelines.detection.yolo_stamp import YoloStampPipeline
from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline
from pipelines.feature_extraction.vae import VaePipeline
from pipelines.feature_extraction.vits8 import Vits8Pipeline
# NOTE(review): visualize_bbox, heatmap and annotate_heatmap are not imported
# explicitly anywhere in this file, so they presumably come from this star
# import. Kept as-is (and kept last) because other names may rely on it.
from utils import *

# Models are loaded once at import time and shared by every request.
yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt')
yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', 'weights.pt')
vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')
dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', 'weights.pt')

# Single source of truth for the dropdown choices and the model dispatch.
_DETECTORS = {'yolov8': yolov8, 'yolo-stamp': yolo_stamp}
_EMBEDDERS = {'vits8': vits8, 'vae-encoder': vae}
_BOX_COLUMNS = ['x1', 'y1', 'x2', 'y2']


def _concat_horizontally(images):
    """Paste the given PIL images side by side, left to right, on one RGB canvas.

    The canvas is as wide as the sum of the image widths and as tall as the
    tallest image. ``images`` must be non-empty.
    """
    widths, heights = zip(*(img.size for img in images))
    canvas = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in images:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return canvas


def _similarity_figure(embeddings):
    """Draw the pairwise cosine-similarity heatmap of ``embeddings``.

    Rows and columns are labelled 1..N in the order the embeddings were given.
    """
    similarities = cosine_similarity(embeddings)
    labels = range(1, len(embeddings) + 1)
    fig, ax = plt.subplots()
    im, _cbar = heatmap(similarities, labels, labels, ax=ax,
                        cmap="YlGn", cbarlabel="Embeddings similarities")
    annotate_heatmap(im, valfmt="{x:.3f}")
    # pyplot keeps a reference to every figure it creates until it is closed;
    # in a long-running server that leaks one figure per request. Closing only
    # unregisters the figure from pyplot; the Figure object is still returned.
    plt.close(fig)
    return fig


def doc_predict(image, det_choice, seg_choice, emb_choice):
    """Run the stamp pipeline on one document image.

    Args:
        image: PIL image of the document (converted to RGB before use).
        det_choice: Detection model name, one of the keys of ``_DETECTORS``.
        seg_choice: Truthy to pass every cropped stamp through ``dlv3``.
        emb_choice: Embedding model name, one of the keys of ``_EMBEDDERS``.

    Returns:
        A 5-tuple matching ``doc_outputs``: the image produced by
        ``visualize_bbox``, a DataFrame of boxes with columns x1/y1/x2/y2,
        the stamps pasted side by side, the stacked embeddings array, and the
        similarity heatmap figure. When the detector yields no boxes, the box
        table is empty and the last three items are ``None``.

    Raises:
        ValueError: If ``det_choice`` or ``emb_choice`` is not a known model.
    """
    # Fail fast with a clear message instead of returning None (wrong output
    # count for Gradio) or hitting an unrelated np.stack error later.
    if det_choice not in _DETECTORS:
        raise ValueError(f"Unknown detection model: {det_choice!r}")
    if emb_choice not in _EMBEDDERS:
        raise ValueError(f"Unknown embedding model: {emb_choice!r}")

    image = image.convert('RGB')
    boxes = _DETECTORS[det_choice](image)
    image_with_boxes = visualize_bbox(image, boxes)

    # NOTE(review): assumes each box exposes .tolist() giving (x1, y1, x2, y2)
    # in the order PIL's crop() expects - confirm against the pipelines.
    stamps = []
    for box in boxes:
        cropped = image.crop(box.tolist())
        stamps.append(dlv3(cropped) if seg_choice else cropped)

    if not stamps:
        # Nothing detected: np.stack([]) would raise, so return empty outputs.
        empty_boxes = pd.DataFrame(columns=_BOX_COLUMNS)
        return image_with_boxes, empty_boxes, None, None, None

    embedder = _EMBEDDERS[emb_choice]
    embeddings = np.stack([embedder(stamp) for stamp in stamps])

    df_boxes = pd.DataFrame(boxes, columns=_BOX_COLUMNS)
    concatenated_stamps = _concat_horizontally(stamps)
    fig = _similarity_figure(embeddings)
    return image_with_boxes, df_boxes, concatenated_stamps, embeddings, fig


doc_examples = [
    ['examples/1.jpg', 'yolov8', True, 'vits8'],
    ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'],
    ['examples/3.jpg', 'yolov8', True, 'vits8'],
]

doc_inputs = [
    gr.Image(label="Document image", type="pil"),
    gr.Dropdown(choices=list(_DETECTORS), value='yolov8', label='Detection model'),
    gr.Checkbox(label="Use segmentation model"),
    gr.Dropdown(choices=list(_EMBEDDERS), value='vits8', label='Embedding model'),
]

doc_outputs = [
    gr.Image(label="Document with bounding boxes", type="pil"),
    gr.DataFrame(type='pandas', label="Bounding boxes"),
    gr.Image(label="Segmented stamps", type="pil"),
    gr.DataFrame(type='numpy', label="Embeddings"),
    gr.Plot(label="Cosine Similarities"),
]

with gr.Blocks() as demo:
    with gr.Tab("Single document"):
        gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)

demo.launch(inline=False)