Update modeling_GOT.py
Browse files- modeling_GOT.py +26 -26
modeling_GOT.py
CHANGED
@@ -12,7 +12,7 @@ from .got_vision_b import build_GOT_vit_b
|
|
12 |
from torchvision import transforms
|
13 |
from torchvision.transforms.functional import InterpolationMode
|
14 |
import dataclasses
|
15 |
-
|
16 |
|
17 |
DEFAULT_IMAGE_TOKEN = "<image>"
|
18 |
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
|
@@ -715,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
715 |
return processed_images
|
716 |
|
717 |
|
718 |
-
def
|
719 |
# Model
|
720 |
self.disable_torch_init()
|
721 |
|
@@ -805,36 +805,36 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
805 |
stopping_criteria=[stopping_criteria]
|
806 |
)
|
807 |
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
|
821 |
-
|
822 |
-
|
823 |
|
824 |
|
825 |
-
|
826 |
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
|
832 |
-
|
833 |
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
|
839 |
-
|
840 |
-
|
|
|
12 |
from torchvision import transforms
|
13 |
from torchvision.transforms.functional import InterpolationMode
|
14 |
import dataclasses
|
15 |
+
from megfile import smart_open
|
16 |
|
17 |
DEFAULT_IMAGE_TOKEN = "<image>"
|
18 |
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
|
|
|
715 |
return processed_images
|
716 |
|
717 |
|
718 |
+
def chat_plus(self, tokenizer, image_file, render=False, save_render_file=None, multi_page=False):
|
719 |
# Model
|
720 |
self.disable_torch_init()
|
721 |
|
|
|
805 |
stopping_criteria=[stopping_criteria]
|
806 |
)
|
807 |
|
808 |
+
if render:
|
809 |
+
print('==============rendering===============')
|
810 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
811 |
|
812 |
+
if outputs.endswith(stop_str):
|
813 |
+
outputs = outputs[:-len(stop_str)]
|
814 |
+
outputs = outputs.strip()
|
815 |
|
816 |
+
html_path = "./render_tools/" + "content-mmd-to-html.html"
|
817 |
+
html_path_2 = save_render_file
|
818 |
+
right_num = outputs.count('\\right')
|
819 |
+
left_num = outputs.count('\left')
|
820 |
|
821 |
+
if right_num != left_num:
|
822 |
+
outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
|
823 |
|
824 |
|
825 |
+
outputs = outputs.replace('"', '``').replace('$', '')
|
826 |
|
827 |
+
outputs_list = outputs.split('\n')
|
828 |
+
gt= ''
|
829 |
+
for out in outputs_list:
|
830 |
+
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
|
831 |
|
832 |
+
gt = gt[:-2]
|
833 |
|
834 |
+
with smart_open(html_path, 'r') as web_f:
|
835 |
+
lines = web_f.read()
|
836 |
+
lines = lines.split("const text =")
|
837 |
+
new_web = lines[0] + 'const text =' + gt + lines[1]
|
838 |
|
839 |
+
with smart_open(html_path_2, 'w') as web_f_new:
|
840 |
+
web_f_new.write(new_web)
|