File size: 14,809 Bytes
16aad69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c97e98
16aad69
 
 
 
9c97e98
16aad69
 
02acfd5
16aad69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import streamlit as st
from predict import PaddleOCR
from pdf2image import convert_from_bytes
import cv2
import PIL
import numpy as np
import os
import tempfile
import random
import string
from ultralyticsplus import YOLO
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
import re
from dateutil.parser import parse
import datetime
from file_utils import (
    get_img,
    save_excel_file,
    concat_csv,
    convert_pdf_to_image,
    filter_color,
    plot,
    delete_file,
)
from process import (
    filter_columns,
    extract_text_of_col,
    prepare_cols,
    process_cols,
    finalize_data,
)


# Paths handed to YOLO for the two detectors used by this app
# (presumably trained weight files shipped next to the script -- confirm).
TABLE_WEIGHTS = "table.pt"
COLUMN_WEIGHTS = "columns.pt"

# Both models are created once, when the module is first executed.
table_model = YOLO(TABLE_WEIGHTS)
column_model = YOLO(COLUMN_WEIGHTS)

def remove_dots(string):
    """Normalise the dots of an OCR'd number so that at most one remains.

    Leading and trailing dots are stripped first.  Of the dots that are left,
    only the last one is kept (it is treated as the decimal point); every
    earlier dot is dropped.  Strings with zero, one or two inner dots give
    the same result as before; strings with three or more dots are now
    reduced to a single dot as well instead of being left with several.

    Args:
        string: The text to clean.  Must be a ``str``.

    Returns:
        The cleaned string, containing at most one ``'.'``.
    """
    string = string.strip('.')

    # rpartition splits on the LAST dot; when there is no dot at all it
    # returns ('', '', string), so the concatenation below is a no-op.
    whole, dot, fraction = string.rpartition('.')
    return whole.replace('.', '') + dot + fraction

def convert_df(df):
    """Serialise *df* as CSV without the index column and return UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')


def PIL_to_cv(pil_img):
    """Return *pil_img* as a NumPy array with its channels reordered RGB -> BGR."""
    rgb_pixels = np.array(pil_img)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    """Return a PIL image built from *cv_img* after reordering channels BGR -> RGB."""
    rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb_pixels)

def visualize_ocr(pil_img, ocr_result):
    """Draw every OCR box and its text on top of *pil_img* and return the picture.

    Args:
        pil_img: Image handed to ``plt.imshow``.
        ocr_result: Iterable of mappings, each with a ``'bbox'`` entry (read
            here as ``x0, y0, x1, y1``) and a ``'text'`` entry.

    Returns:
        A PIL image opened from the in-memory rendering of the figure.
    """
    plt.imshow(pil_img, interpolation='lanczos')
    figure = plt.gcf()
    figure.set_size_inches(20, 20)
    axes = plt.gca()

    for entry in ocr_result:
        box = entry['bbox']
        label = entry['text']
        left, top, right, bottom = box[0], box[1], box[2], box[3]

        # Red outline around the recognised region.
        outline = patches.Rectangle(
            box[:2],
            right - left,
            bottom - top,
            linewidth=2,
            edgecolor='red',
            facecolor='none',
            linestyle='-',
        )
        axes.add_patch(outline)

        # Recognised text, anchored at the box's first corner.
        axes.text(
            left,
            top,
            label,
            horizontalalignment='left',
            verticalalignment='bottom',
            color='blue',
            fontsize=7,
        )

    # Hide ticks and axes so only the annotated image is saved.
    plt.xticks([], [])
    plt.yticks([], [])

    figure.set_size_inches(10, 10)
    plt.axis('off')

    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', dpi=150)
    plt.close()

    return PIL.Image.open(buffer)

def filter_columns(columns: np.ndarray):
    """Merge neighbouring boxes that overlap horizontally by a large margin.

    Each row is read as ``[x0, y0, x1, y1, ...]`` and the rows are expected
    to be ordered by ``x0`` already (the callers in this file sort before
    calling).  Two neighbours are merged when the overlap ``col.x1 - nxt.x0``
    exceeds half of their average width.  The merged box is the union of the
    two boxes, and it is compared again with its new right-hand neighbour so
    that chains of overlapping boxes collapse into one.

    Note: this definition shadows the ``filter_columns`` name imported from
    ``process`` at the top of the file.

    Args:
        columns: 2-D array of boxes, one per row.  It is not modified.

    Returns:
        A new array with overlapping neighbours merged.
    """
    # Work on a copy so the in-place row updates never leak into the caller's data.
    columns = np.array(columns)

    idx = 0
    # An explicit index is required: np.delete returns a new array, so a
    # ``for`` loop would keep walking the stale one and compare wrong pairs.
    while idx < len(columns) - 1:
        col = columns[idx]
        nxt = columns[idx + 1]
        threshold = ((col[2] - col[0]) + (nxt[2] - nxt[0])) / 2
        if (col[2] - nxt[0]) > threshold * 0.5:
            # Union of both boxes; x0 stays because rows are sorted by x0.
            col[1] = min(col[1], nxt[1])
            col[2] = max(col[2], nxt[2])
            col[3] = max(col[3], nxt[3])
            columns = np.delete(columns, idx + 1, 0)
            # Do not advance: re-check the merged box against its new neighbour.
        else:
            idx += 1
    return columns

# ---- Page header -----------------------------------------------------------
st.title("Extract Data from Bank Statements")

# OCR wrapper instance (class imported from the local ``predict`` module).
model = PaddleOCR()

# ---- Input widgets ---------------------------------------------------------
# Extensions offered by the upload widget, in both lower and upper case.
_ACCEPTED_EXTENSIONS = ["png", "jpg", "jpeg", "PNG", "JPG", "JPEG", "pdf", "PDF"]

uploaded = st.file_uploader(
    label="Upload a bank statement pdf file",
    type=_ACCEPTED_EXTENSIONS,
)

# Year chosen by the user; later code appends it to each extracted date.
number = st.number_input(label='Select Year', value=2023, step=1)

# When ticked, each column crop goes through ``filter_color`` before OCR.
# NOTE(review): this name shadows the built-in ``filter``; it is kept because
# the processing code below reads it.
filter = st.checkbox(label="filter color")
# Characters removed from the amount columns before they are parsed as floats
# (the same set the original chained ``.str.replace`` calls removed).
_AMOUNT_NOISE_TABLE = str.maketrans('', '', 'iE:M?t+;g^m/#\'w"%r-v,·: *~V')


def _clean_amount(series):
    """Return *series* as floats after stripping noise characters and extra dots."""
    cleaned = series.astype(str).str.translate(_AMOUNT_NOISE_TABLE)
    return cleaned.apply(remove_dots).astype(float)


def _parse_statement_date(raw):
    """Parse an OCR'd date string, appending the year selected in the UI."""
    alnum_only = re.sub(r'[^a-zA-Z0-9 ]', '', raw)
    return parse(alnum_only + str(number), fuzzy=True)


# Everything below runs only on the script run triggered by the button click.
if st.button('Analyze image'):
    if uploaded is None:
        st.write('Please upload an image')
    else:
        tabs = st.tabs(
            ['Pages', 'Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
        if uploaded.type == "application/pdf":
            # One DataFrame per successfully extracted table, across all pages.
            frames = []
            pdf_pages = convert_from_bytes(uploaded.read(), 500)
            for page_enumeration, page in enumerate(pdf_pages, start=1):
                with tabs[0]:
                    st.header('Pages : ' + str(page_enumeration))
                    st.image(page)

                page_img = np.asarray(page)
                # NOTE(review): detection is called through attributes on the
                # PaddleOCR class, not the module-level ``table_model`` /
                # ``column_model`` created earlier in this file -- confirm that
                # PaddleOCR really exposes them.
                tables = PaddleOCR.table_model(page_img, conf=0.75)
                tabel_datas = tables[0].boxes.data.cpu().numpy()
                tables = tables[0].boxes.xyxy.cpu().numpy()

                with tabs[1]:
                    st.header('Table Detection Page :' + str(page_enumeration))
                    str_cols = st.columns(4)
                    str_cols[0].subheader('Table image')
                    str_cols[1].subheader('Columns')
                    str_cols[2].subheader('Structure result')
                    str_cols[3].subheader('Cells result')

                for table in tables:
                    # ---- Step 1: locate the columns inside this table ----
                    try:
                        tabel_data = np.array(
                            sorted(tabel_datas, key=lambda x: x[0]), dtype=np.ndarray
                        )
                        tabel_data = filter_columns(tabel_data)
                        str_cols[0].image(plot(page_img, tabel_data), channels="RGB")

                        # * crop the table as an image from the original image
                        sub_img = page_img[
                            int(table[1].item()): int(table[3].item()),
                            int(table[0].item()): int(table[2].item()),
                        ]

                        columns_detect = PaddleOCR.column_model(sub_img, conf=0.75)
                        cols_data = columns_detect[0].boxes.data.cpu().numpy()

                        # * Sort columns according to the x coordinate
                        cols_data = np.array(
                            sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray
                        )

                        # * merge the duplicated columns
                        cols_data = filter_columns(cols_data)
                        str_cols[1].image(plot(sub_img, cols_data), channels="RGB")
                    except Exception as e:
                        print(e)
                        st.warning("No Detection")
                        # Without fresh column boxes there is nothing to read;
                        # skipping avoids reusing the previous table's boxes.
                        continue

                    # ---- Step 2: OCR each column and assemble the rows ----
                    try:
                        columns = cols_data[:, 0:4]
                        # Every crop uses the first column's vertical extent.
                        top = int(columns[0][1])
                        bottom = int(columns[0][3])

                        # * Create list of cropped images for each column
                        sub_imgs = [
                            sub_img[top:bottom, int(column[0]): int(column[2])]
                            for column in columns
                        ]

                        cols = []
                        thr = 0
                        for image in sub_imgs:
                            if filter:
                                # * keep only black color in the image
                                image = filter_color(image)

                            # * extract text of each column and get the length threshold
                            res, threshold, ocr_res = extract_text_of_col(image)
                            thr += threshold

                            # * arrange the rows of each column with respect to row length threshold
                            cols.append(prepare_cols(res, threshold * 0.6))

                        thr = thr / len(sub_imgs)

                        # * append each element in each column to its right place in the dataframe
                        data = process_cols(cols, thr * 0.6)

                        # * merge the related rows together
                        data = finalize_data(data, page_enumeration)
                        frames.append(data)
                        with tabs[2]:
                            st.header('Extracted Table(s)')
                            st.dataframe(data)
                    except Exception as e:
                        # ``Exception`` (not a bare ``except``) so that
                        # KeyboardInterrupt / SystemExit still propagate.
                        print(e)
                        st.warning("Text Extraction Failed")
                        continue

            # Concatenate once at the end instead of growing a frame per table.
            final_csv = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

            with tabs[3]:
                st.dataframe(final_csv)
                rough_csv = convert_df(final_csv)
                st.download_button(
                    "rough-csv",
                    rough_csv,
                    "file.csv",
                    "text/csv",
                    key='rough-csv'
                )

            if final_csv.empty:
                # Assigning the seven column names to an empty frame would raise.
                st.warning("No table data could be extracted from this document")
            else:
                # The extracted frame is expected to have exactly these seven
                # columns; a different count raises here.
                final_csv.columns = [
                    'page', 'Date', 'Transaction_Details', 'Three',
                    'Deposit', 'Withdrawal', 'Balance',
                ]
                final_csv['Date'] = final_csv['Date'].astype(str)
                st.dataframe(final_csv)

                # Drop rows whose date cell is a (possibly mis-OCR'd) header.
                is_header = final_csv['Date'].str.contains('Date|日期|口期')
                final_csv = final_csv[~is_header].copy()

                final_csv['Date'] = final_csv['Date'].apply(_parse_statement_date)
                final_csv['*Date'] = pd.to_datetime(final_csv['Date']).dt.strftime('%d-%m-%Y')

                # Withdrawals become negative amounts, deposits stay positive.
                final_csv['Withdrawal'] = _clean_amount(final_csv['Withdrawal']) * -1
                final_csv['Deposit'] = _clean_amount(final_csv['Deposit'])
                final_csv['*Amount'] = final_csv['Withdrawal'].fillna(0) + final_csv['Deposit'].fillna(0)
                final_csv = final_csv.drop(['Withdrawal', 'Deposit'], axis=1)

                final_csv['Payee'] = ''
                final_csv['Description'] = final_csv['Transaction_Details']
                # Append the third text column to the description where present.
                final_csv.loc[final_csv['Three'].notnull(), 'Description'] += " " + final_csv['Three']
                final_csv = final_csv.drop(['Transaction_Details', 'Three'], axis=1)
                final_csv['Reference'] = ''
                final_csv['Check Number'] = ''

                df = final_csv[['*Date', '*Amount', 'Payee', 'Description', 'Reference', 'Check Number']]
                # Rows without any amount carry no transaction.
                df = df[df['*Amount'] != 0]
                csv = convert_df(df)
                st.dataframe(df)
                st.download_button(
                    "Press to Download",
                    csv,
                    "file.csv",
                    "text/csv",
                    key='download-csv'
                )
        else:
            # The uploader also accepts image files, but only PDFs are processed.
            st.warning("Only PDF files can be analyzed at the moment")

            #success = st.button("Extract", on_click=model, args=[uploaded, filter])