praysimanjuntak commited on
Commit
2461c58
1 Parent(s): c3286e1

Update llava/model/builder.py

Browse files
Files changed (1) hide show
  1. llava/model/builder.py +2 -1
llava/model/builder.py CHANGED
@@ -16,13 +16,14 @@
16
  import os
17
  import warnings
18
  import shutil
 
19
 
20
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
  import torch
22
  from llava.model import *
23
  from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
 
25
-
26
  def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27
  kwargs = {"device_map": device_map, **kwargs}
28
 
 
16
  import os
17
  import warnings
18
  import shutil
19
+ import spaces
20
 
21
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
22
  import torch
23
  from llava.model import *
24
  from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
 
26
+ @spaces.GPU
27
  def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
28
  kwargs = {"device_map": device_map, **kwargs}
29