ucaslcl committed on
Commit
462ad59
1 Parent(s): e7c2934

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +26 -26
modeling_GOT.py CHANGED
@@ -12,7 +12,7 @@ from .got_vision_b import build_GOT_vit_b
12
  from torchvision import transforms
13
  from torchvision.transforms.functional import InterpolationMode
14
  import dataclasses
15
-
16
 
17
  DEFAULT_IMAGE_TOKEN = "<image>"
18
  DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
@@ -715,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
715
  return processed_images
716
 
717
 
718
- def chat_crop(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, multi_page=False):
719
  # Model
720
  self.disable_torch_init()
721
 
@@ -805,36 +805,36 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
805
  stopping_criteria=[stopping_criteria]
806
  )
807
 
808
- # if render:
809
- # print('==============rendering===============')
810
- # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
811
 
812
- # if outputs.endswith(stop_str):
813
- # outputs = outputs[:-len(stop_str)]
814
- # outputs = outputs.strip()
815
 
816
- # html_path = "./render_tools/" + "/content-mmd-to-html.html"
817
- # html_path_2 = "./results/demo.html"
818
- # right_num = outputs.count('\\right')
819
- # left_num = outputs.count('\left')
820
 
821
- # if right_num != left_num:
822
- # outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
823
 
824
 
825
- # outputs = outputs.replace('"', '``').replace('$', '')
826
 
827
- # outputs_list = outputs.split('\n')
828
- # gt= ''
829
- # for out in outputs_list:
830
- # gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
831
 
832
- # gt = gt[:-2]
833
 
834
- # with open(html_path, 'r') as web_f:
835
- # lines = web_f.read()
836
- # lines = lines.split("const text =")
837
- # new_web = lines[0] + 'const text =' + gt + lines[1]
838
 
839
- # with open(html_path_2, 'w') as web_f_new:
840
- # web_f_new.write(new_web)
 
12
  from torchvision import transforms
13
  from torchvision.transforms.functional import InterpolationMode
14
  import dataclasses
15
+ from megfile import smart_open
16
 
17
  DEFAULT_IMAGE_TOKEN = "<image>"
18
  DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
 
715
  return processed_images
716
 
717
 
718
+ def chat_plus(self, tokenizer, image_file, render=False, save_render_file=None, multi_page=False):
719
  # Model
720
  self.disable_torch_init()
721
 
 
805
  stopping_criteria=[stopping_criteria]
806
  )
807
 
808
+ if render:
809
+ print('==============rendering===============')
810
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
811
 
812
+ if outputs.endswith(stop_str):
813
+ outputs = outputs[:-len(stop_str)]
814
+ outputs = outputs.strip()
815
 
816
+ html_path = "./render_tools/" + "content-mmd-to-html.html"
817
+ html_path_2 = save_render_file
818
+ right_num = outputs.count('\\right')
819
+ left_num = outputs.count('\left')
820
 
821
+ if right_num != left_num:
822
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
823
 
824
 
825
+ outputs = outputs.replace('"', '``').replace('$', '')
826
 
827
+ outputs_list = outputs.split('\n')
828
+ gt= ''
829
+ for out in outputs_list:
830
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
831
 
832
+ gt = gt[:-2]
833
 
834
+ with smart_open(html_path, 'r') as web_f:
835
+ lines = web_f.read()
836
+ lines = lines.split("const text =")
837
+ new_web = lines[0] + 'const text =' + gt + lines[1]
838
 
839
+ with smart_open(html_path_2, 'w') as web_f_new:
840
+ web_f_new.write(new_web)