Spaces:
Runtime error
Runtime error
praysimanjuntak
commited on
Commit
•
2461c58
1
Parent(s):
c3286e1
Update llava/model/builder.py
Browse files- llava/model/builder.py +2 -1
llava/model/builder.py
CHANGED
@@ -16,13 +16,14 @@
|
|
16 |
import os
|
17 |
import warnings
|
18 |
import shutil
|
|
|
19 |
|
20 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
21 |
import torch
|
22 |
from llava.model import *
|
23 |
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
|
25 |
-
|
26 |
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
|
27 |
kwargs = {"device_map": device_map, **kwargs}
|
28 |
|
|
|
16 |
import os
|
17 |
import warnings
|
18 |
import shutil
|
19 |
+
import spaces
|
20 |
|
21 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
22 |
import torch
|
23 |
from llava.model import *
|
24 |
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
25 |
|
26 |
+
@spaces.GPU
|
27 |
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
|
28 |
kwargs = {"device_map": device_map, **kwargs}
|
29 |
|