Oysiyl committed
Commit e32677f
1 Parent(s): e16347c

Update app.py

Files changed (1): app.py (+4, -1)
app.py CHANGED
@@ -1,8 +1,11 @@
 import os
 import gradio as gr
+import torch
 from PIL import Image
 from transformers import AutoProcessor, AutoModelForCausalLM
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 #workaround for unnecessary flash_attn requirement
 from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
@@ -16,7 +19,7 @@ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
     return imports
 
 with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement
-    model = AutoModelForCausalLM.from_pretrained("Oysiyl/Florence-2-FT-OCR-Cauldron-IAM", attn_implementation="sdpa", trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained("Oysiyl/Florence-2-FT-OCR-Cauldron-IAM", attn_implementation="sdpa", trust_remote_code=True).to(device)
 
 processor = AutoProcessor.from_pretrained("Oysiyl/Florence-2-FT-OCR-Cauldron-IAM", trust_remote_code=True)
 
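For reference, a minimal runnable sketch of app.py as it plausibly stands after this commit, plus a short inference call showing why the model needs to live on the same torch.device as its inputs. Only the lines visible in the diff above come from the commit itself; the body of fixed_get_imports (the usual flash_attn workaround for Florence-2), the sample image path, the "<OCR>" task prompt, and the generation parameters are assumptions based on the standard Florence-2 usage pattern.

import os
from unittest.mock import patch

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    # Assumed body: only the signature and the final `return imports`
    # appear in the diff context. The usual workaround drops flash_attn
    # from Florence-2's dynamically resolved imports so the model loads
    # on machines where flash_attn is not installed.
    imports = get_imports(filename)
    if str(filename).endswith("modeling_florence2.py") and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    # The change in this commit: place the model on the GPU when available.
    model = AutoModelForCausalLM.from_pretrained(
        "Oysiyl/Florence-2-FT-OCR-Cauldron-IAM",
        attn_implementation="sdpa",
        trust_remote_code=True,
    ).to(device)

processor = AutoProcessor.from_pretrained(
    "Oysiyl/Florence-2-FT-OCR-Cauldron-IAM", trust_remote_code=True
)

# Inference sketch: inputs must be moved to the same device as the model,
# which is why .to(device) was added above.
image = Image.open("sample.png").convert("RGB")  # hypothetical input file
inputs = processor(text="<OCR>", images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=256,
)
raw = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
print(processor.post_process_generation(raw, task="<OCR>", image_size=image.size))

Without the new .to(device) call, inputs moved to the GPU would sit there while the model's weights stayed on the CPU, and generate would fail with a device-mismatch error; moving both to the same torch.device is what makes the +4/-1 change above work end to end.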