ucaslcl commited on
Commit
2035077
1 Parent(s): 0b6cdba

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +10 -4
modeling_GOT.py CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False):
488
 
489
  self.disable_torch_init()
490
 
@@ -495,7 +495,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
495
 
496
  image_token_len = 256
497
 
498
- image = self.load_image(image_file)
 
 
 
499
 
500
  w, h = image.size
501
 
@@ -713,7 +716,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
713
  return processed_images
714
 
715
 
716
- def chat_crop(self, tokenizer, image_file, render=False, save_render_file=None, print_prompt=False):
717
  # Model
718
  self.disable_torch_init()
719
  multi_page=False
@@ -749,7 +752,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
749
 
750
  else:
751
  qs = 'OCR with format upon the patch reference: '
752
- img = self.load_image(image_file)
 
 
 
753
  sub_images = self.dynamic_preprocess(img)
754
  ll = len(sub_images)
755
 
 
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False):
488
 
489
  self.disable_torch_init()
490
 
 
495
 
496
  image_token_len = 256
497
 
498
+ if gradio_input:
499
+ image = image_file.copy()
500
+ else:
501
+ image = self.load_image(image_file)
502
 
503
  w, h = image.size
504
 
 
716
  return processed_images
717
 
718
 
719
+ def chat_crop(self, tokenizer, image_file, render=False, save_render_file=None, print_prompt=False, gradio_input=False):
720
  # Model
721
  self.disable_torch_init()
722
  multi_page=False
 
752
 
753
  else:
754
  qs = 'OCR with format upon the patch reference: '
755
+ if gradio_input:
756
+ img = image_file.copy()
757
+ else:
758
+ img = self.load_image(image_file)
759
  sub_images = self.dynamic_preprocess(img)
760
  ll = len(sub_images)
761