sadjava commited on
Commit
1ec96d4
1 Parent(s): c41dec0

Fix utils.py

Browse files
Files changed (2) hide show
  1. app.py +13 -11
  2. utils.py +4 -0
app.py CHANGED
@@ -1,11 +1,9 @@
1
  import gradio as gr
2
- import torch
3
  import numpy as np
4
  from sklearn.metrics.pairwise import cosine_similarity
5
  import pandas as pd
6
- from PIL import Image, ImageDraw
7
  import matplotlib.pyplot as plt
8
- import matplotlib
9
 
10
  from pipelines.detection.yolo_v8 import Yolov8Pipeline
11
  from pipelines.detection.yolo_stamp import YoloStampPipeline
@@ -41,17 +39,21 @@ def doc_predict(image, det_choice, seg_choice, emb_choice):
41
  cropped_stamp = image.crop(box.tolist())
42
  segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)
43
 
44
- widths, heights = zip(*(i.size for i in segmented_stamps))
 
45
 
46
- total_width = sum(widths)
47
- max_height = max(heights)
48
 
49
- concatenated_stamps = Image.new('RGB', (total_width, max_height))
50
 
51
- x_offset = 0
52
- for im in segmented_stamps:
53
- concatenated_stamps.paste(im, (x_offset,0))
54
- x_offset += im.size[0]
 
 
 
55
 
56
  embeddings = []
57
  if emb_choice == 'vits8':
 
1
  import gradio as gr
 
2
  import numpy as np
3
  from sklearn.metrics.pairwise import cosine_similarity
4
  import pandas as pd
5
+ from PIL import Image
6
  import matplotlib.pyplot as plt
 
7
 
8
  from pipelines.detection.yolo_v8 import Yolov8Pipeline
9
  from pipelines.detection.yolo_stamp import YoloStampPipeline
 
39
  cropped_stamp = image.crop(box.tolist())
40
  segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)
41
 
42
+ if len(segmented_stamps) != 0:
43
+ widths, heights = zip(*(i.size for i in segmented_stamps))
44
 
45
+ total_width = sum(widths)
46
+ max_height = max(heights)
47
 
48
+ concatenated_stamps = Image.new('RGB', (total_width, max_height))
49
 
50
+ x_offset = 0
51
+ for im in segmented_stamps:
52
+ concatenated_stamps.paste(im, (x_offset,0))
53
+ x_offset += im.size[0]
54
+
55
+ else:
56
+ concatenated_stamps = Image.new('RGB', (0, 0))
57
 
58
  embeddings = []
59
  if emb_choice == 'vits8':
utils.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  def heatmap(data, row_labels, col_labels, ax=None,
2
  cbar_kw=None, cbarlabel="", **kwargs):
3
  """
 
1
+ import matplotlib
2
+ import matplotlib.pyplot as plt
3
+ from PIL import Image, ImageDraw
4
+
5
  def heatmap(data, row_labels, col_labels, ax=None,
6
  cbar_kw=None, cbarlabel="", **kwargs):
7
  """