ucaslcl committed on
Commit
504cd76
1 Parent(s): b482c60

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +72 -74
modeling_GOT.py CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False):
488
 
489
  self.disable_torch_init()
490
 
@@ -575,87 +575,86 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
575
  )
576
 
577
 
578
- # if render:
579
- # print('==============rendering===============')
 
580
 
581
- # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
582
 
583
- # if outputs.endswith(stop_str):
584
- # outputs = outputs[:-len(stop_str)]
585
- # outputs = outputs.strip()
586
-
587
- # if '**kern' in outputs:
588
- # import verovio
589
- # from cairosvg import svg2png
590
- # import cv2
591
- # import numpy as np
592
- # tk = verovio.toolkit()
593
- # tk.loadData(outputs)
594
- # tk.setOptions({"pageWidth": 2100, "footer": 'none',
595
- # 'barLineWidth': 0.5, 'beamMaxSlope': 15,
596
- # 'staffLineWidth': 0.2, 'spacingStaff': 6})
597
- # tk.getPageCount()
598
- # svg = tk.renderToSVG()
599
- # svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
600
-
601
- # svg_to_html(svg, "./results/demo.html")
602
-
603
- # if ocr_type == 'format' and '**kern' not in outputs:
604
 
605
 
606
- # if '\\begin{tikzpicture}' not in outputs:
607
- # html_path = "./render_tools/" + "/content-mmd-to-html.html"
608
- # html_path_2 = "./results/demo.html"
609
- # right_num = outputs.count('\\right')
610
- # left_num = outputs.count('\left')
611
 
612
- # if right_num != left_num:
613
- # outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
614
 
615
 
616
- # outputs = outputs.replace('"', '``').replace('$', '')
617
 
618
- # outputs_list = outputs.split('\n')
619
- # gt= ''
620
- # for out in outputs_list:
621
- # gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
622
 
623
- # gt = gt[:-2]
624
-
625
- # with open(html_path, 'r') as web_f:
626
- # lines = web_f.read()
627
- # lines = lines.split("const text =")
628
- # new_web = lines[0] + 'const text =' + gt + lines[1]
629
- # else:
630
- # html_path = "./render_tools/" + "/tikz.html"
631
- # html_path_2 = "./results/demo.html"
632
- # outputs = outputs.translate(translation_table)
633
- # outputs_list = outputs.split('\n')
634
- # gt= ''
635
- # for out in outputs_list:
636
- # if out:
637
- # if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
638
- # while out[-1] == ' ':
639
- # out = out[:-1]
640
- # if out is None:
641
- # break
642
 
643
- # if out:
644
- # if out[-1] != ';':
645
- # gt += out[:-1] + ';\n'
646
- # else:
647
- # gt += out + '\n'
648
- # else:
649
- # gt += out + '\n'
650
 
651
 
652
- # with open(html_path, 'r') as web_f:
653
- # lines = web_f.read()
654
- # lines = lines.split("const text =")
655
- # new_web = lines[0] + gt + lines[1]
656
 
657
- # with open(html_path_2, 'w') as web_f_new:
658
- # web_f_new.write(new_web)
659
 
660
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
661
 
@@ -807,13 +806,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
807
 
808
  if render:
809
  print('==============rendering===============')
 
810
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
811
 
812
  if outputs.endswith(stop_str):
813
  outputs = outputs[:-len(stop_str)]
814
  outputs = outputs.strip()
815
 
816
- html_path = "./render_tools/" + "content-mmd-to-html.html"
817
  html_path_2 = save_render_file
818
  right_num = outputs.count('\\right')
819
  left_num = outputs.count('\left')
@@ -831,10 +830,9 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
831
 
832
  gt = gt[:-2]
833
 
834
- with smart_open(html_path, 'r') as web_f:
835
- lines = web_f.read()
836
- lines = lines.split("const text =")
837
- new_web = lines[0] + 'const text =' + gt + lines[1]
838
 
839
  with smart_open(html_path_2, 'w') as web_f_new:
840
  web_f_new.write(new_web)
 
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None):
488
 
489
  self.disable_torch_init()
490
 
 
575
  )
576
 
577
 
578
+ if render:
579
+ print('==============rendering===============')
580
+ from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
581
 
582
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
583
 
584
+ if outputs.endswith(stop_str):
585
+ outputs = outputs[:-len(stop_str)]
586
+ outputs = outputs.strip()
587
+
588
+ if '**kern' in outputs:
589
+ import verovio
590
+ from cairosvg import svg2png
591
+ import cv2
592
+ import numpy as np
593
+ tk = verovio.toolkit()
594
+ tk.loadData(outputs)
595
+ tk.setOptions({"pageWidth": 2100, "footer": 'none',
596
+ 'barLineWidth': 0.5, 'beamMaxSlope': 15,
597
+ 'staffLineWidth': 0.2, 'spacingStaff': 6})
598
+ tk.getPageCount()
599
+ svg = tk.renderToSVG()
600
+ svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
601
+
602
+ svg_to_html(svg, save_render_file)
603
+
604
+ if ocr_type == 'format' and '**kern' not in outputs:
605
 
606
 
607
+ if '\\begin{tikzpicture}' not in outputs:
608
+ html_path_2 = save_render_file
609
+ right_num = outputs.count('\\right')
610
+ left_num = outputs.count('\left')
 
611
 
612
+ if right_num != left_num:
613
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
614
 
615
 
616
+ outputs = outputs.replace('"', '``').replace('$', '')
617
 
618
+ outputs_list = outputs.split('\n')
619
+ gt= ''
620
+ for out in outputs_list:
621
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
622
 
623
+ gt = gt[:-2]
624
+
625
+
626
+ lines = content_mmd_to_html
627
+ lines = lines.split("const text =")
628
+ new_web = lines[0] + 'const text =' + gt + lines[1]
629
+
630
+ else:
631
+ html_path_2 = save_render_file
632
+ outputs = outputs.translate(translation_table)
633
+ outputs_list = outputs.split('\n')
634
+ gt= ''
635
+ for out in outputs_list:
636
+ if out:
637
+ if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
638
+ while out[-1] == ' ':
639
+ out = out[:-1]
640
+ if out is None:
641
+ break
642
 
643
+ if out:
644
+ if out[-1] != ';':
645
+ gt += out[:-1] + ';\n'
646
+ else:
647
+ gt += out + '\n'
648
+ else:
649
+ gt += out + '\n'
650
 
651
 
652
+ lines = tik_html
653
+ lines = lines.split("const text =")
654
+ new_web = lines[0] + gt + lines[1]
 
655
 
656
+ with smart_open(html_path_2, 'w') as web_f_new:
657
+ web_f_new.write(new_web)
658
 
659
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
660
 
 
806
 
807
  if render:
808
  print('==============rendering===============')
809
+ from .render_tools import content_mmd_to_html
810
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
811
 
812
  if outputs.endswith(stop_str):
813
  outputs = outputs[:-len(stop_str)]
814
  outputs = outputs.strip()
815
 
 
816
  html_path_2 = save_render_file
817
  right_num = outputs.count('\\right')
818
  left_num = outputs.count('\left')
 
830
 
831
  gt = gt[:-2]
832
 
833
+ lines = content_mmd_to_html
834
+ lines = lines.split("const text =")
835
+ new_web = lines[0] + 'const text =' + gt + lines[1]
 
836
 
837
  with smart_open(html_path_2, 'w') as web_f_new:
838
  web_f_new.write(new_web)