Handling `flash_attn` Dependency for Non-GPU Environments

#4
by giacomopedemonte - opened

Discussion: Handling flash_attn Dependency for Non-GPU Environments

I encountered an issue when running the code from the repo without a GPU. The original code imports the flash_attn module, which only works with CUDA, so it breaks for users who don't have a GPU. The problem is also tied to the version of the transformers library, which may try to import this package when loading the model's remote code. By dynamically removing flash_attn from the detected imports, I was able to get the code working fine.

Here is the original repo code:

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings

# disable some warnings
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

# set device
torch.set_default_device('cuda')  # or 'cpu'

model_name = 'qnguyen3/nanoLLaVA-1.5'

# create model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True)

# text prompt
prompt = 'Describe this image in detail'

messages = [
    {"role": "user", "content": f'<image>\n{prompt}'}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

print(text)

text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)

# image, sample images can be found in images folder
image = Image.open('/path/to/image.png')
image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)

# generate
output_ids = model.generate(
    input_ids,
    images=image_tensor,
    max_new_tokens=2048,
    use_cache=True)[0]

print(tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip())

This code, especially the AutoModelForCausalLM part, assumes a CUDA environment and pulls in dependencies that may not be available in a CPU-only setup. To resolve this, I modified the code to dynamically remove flash_attn from the detected imports when CUDA is not available. Here is the updated code:

from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings
import os
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import torch

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
    imports = get_imports(filename)
    if not torch.cuda.is_available() and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

model_name = 'qnguyen3/nanoLLaVA-1.5'

# create model
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map='auto' if torch.cuda.is_available() else None,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )

# text prompt
prompt = 'Describe this image in detail'

messages = [
    {"role": "user", "content": f'<image>\n{prompt}'}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

print(text)

text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)

# image, sample images can be found in images folder
image = Image.open('images/people_crossing.jpg')
image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)

# generate
output_ids = model.generate(
    input_ids.to(device),
    images=image_tensor.to(device),
    max_new_tokens=2048,
    use_cache=True
)[0]

print(tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip())

By applying this change, I was able to run the code successfully on a non-GPU machine, and it may be useful for other users facing similar issues. The problem stems from a hard package dependency declared in the model's remote code, which is not always easy to resolve. Dynamically filtering the imports based on the available hardware is a practical workaround.
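
If you want to reuse this in several scripts, you can wrap the patch in a small helper. Below is just a minimal sketch of the same idea; the helper name cpu_safe_imports is my own, not part of transformers:

import os
from contextlib import contextmanager
from unittest.mock import patch

import torch
from transformers.dynamic_module_utils import get_imports

def _no_flash_attn_imports(filename: str | os.PathLike) -> list[str]:
    # Same workaround as above: drop flash_attn when CUDA is unavailable.
    imports = get_imports(filename)
    if not torch.cuda.is_available() and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

@contextmanager
def cpu_safe_imports():
    # Patch transformers' import scanner only while the model is being loaded.
    with patch("transformers.dynamic_module_utils.get_imports", _no_flash_attn_imports):
        yield

# usage:
# with cpu_safe_imports():
#     model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)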

I hope this helps other users who might be in the same situation.

Thank you! I just posted about this problem (https://discuss.huggingface.co/t/best-practices-to-use-models-requiring-flash-attn-on-apple-silicon-macs-or-non-cuda/97562/2) asking for a general solution. This is fantastic. I really wish this didn't require monkey patching to solve such a simple use case.

@kerrmetric I'm glad you found my solution helpful. Handling dependencies dynamically is crucial for keeping code compatible across different hardware setups, but it can be very frustrating. I tested my solution on Windows, and while I believe it should also work on a Mac, I cannot guarantee it. Thanks again for joining the discussion!
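
For Mac users, one thing worth trying (untested on my side, just a sketch) is extending the device selection to cover Apple's MPS backend, which torch exposes through torch.backends.mps; the fixed_get_imports workaround above still applies there, since CUDA is unavailable:

import torch

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
# MPS availability also depends on your torch build and macOS version.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")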

Thanks a lot for sharing, but this solution is not working for the microsoft/Florence-2-large model.

@muaz1236 Thanks for pointing that out! I understand your concern, and it's crucial to ensure our solutions are robust and adaptable across various models. Let's focus on making it work specifically for the microsoft/florence-* models.

The issue seems to stem from the naming of the file in which the import function is being used. If the file name conflicts with any package retrieval processes, it might not correctly bypass the flash_attn dependency. To resolve this, we need to ensure that our import function accurately identifies and cleans the imports based on the file name.

Let's adjust our approach to use a specific file name, like main.py, to see if that resolves the issue:

import os
from unittest.mock import patch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if os.path.basename(filename) != "main.py":
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

However, if you continue to encounter the flash_attn import error, try pointing the check at the file that transformers actually downloads for Florence-2, modeling_florence2.py:

import os
from unittest.mock import patch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if os.path.basename(filename) != "modeling_florence2.py":
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
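
As with the earlier example, the function only takes effect when it is applied with patch around from_pretrained. Continuing from the snippet above, a minimal loading sketch (the same pattern that appears later in this thread) looks like this:

with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-large", trust_remote_code=True).to(device).eval()
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-large", trust_remote_code=True)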

Using modeling_florence2.py, the execution should proceed successfully, as demonstrated below:

2024-07-18 17:39:04.126830: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
...
{'<DESCRIPTION>': 'woman and dog on beach'}

The key takeaway here is that the file name check plays a crucial role: it has to match the remote code file that transformers retrieves for the model. Make sure the check targets the right modeling file so that the custom import function works as intended.

So you can also simply use the version that I shared above, which removes flash_attn from any file whenever CUDA is not available:

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    imports = get_imports(filename)
    if not torch.cuda.is_available() and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

Thanks again for highlighting this, and I hope this helps others facing similar issues. If you have any more questions or suggestions, feel free to share!

Super helpful but holy LLM written answer batman

"Thanks for pointing that out! I understand your concern, and it's crucial to ensure our solutions are robust and adaptable across various models. Let's focus on making it work specifically for the microsoft/florence-* models.

The issue seems to stem from the naming of the file in which the import function is being used. If the file name conflicts with any package retrieval processes, it might not correctly bypass the flash_attn dependency. To resolve this, we need to ensure that our import function accurately identifies and cleans the imports based on the file name.

Let's adjust our approach to use a specific file name, like main.py, to see if that resolves the issue:"

And why is that wrong? I did write it with the help of LLMs, but in a way that makes it clear why we get the error when using main.py instead of other file names. I described all the steps I personally took, so every change and the reason behind it should be clear.

It's not bad, but I'd love for the norm to be a clear demarcation between what is thoughtful, original, human content and what is LLM-generated text. Blurring the lines makes it hard to appreciate and consider your helpful work.

I'll absolutely keep your comment in mind. I use LLMs just for a touch of inspiration, but in this case what you see is very close to what I wrote myself; maybe that's because I'm interacting with LLMs all the time 😂

Thank you for the code. Honestly, I just completely forgot that the flash attention stuff only works on NVIDIA/AMD GPUs.

giacomopedemonte changed discussion status to closed
giacomopedemonte changed discussion status to open

I'll leave this open for further discussion on the topic for the moment; other errors related to GPU dependencies may still come up.

Thank you, the tutorial worked for Florence-2, I guess (the code is still running without throwing errors).
It just threw this error:
ValueError: Florence2ForConditionalGeneration does not support device_map='auto'. To implement support, the model class needs to implement the _no_split_modules attribute.

I guess the flash_attn issue is solved; this is a new, unrelated error.

In my case I don't use device_map at all, and I think it's unnecessary if you call .to(device) explicitly, like this:

import os
from unittest.mock import patch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    imports = get_imports(filename)
    if not torch.cuda.is_available() and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True).to(device).eval()
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

prompt = "<OD>"

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")

generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    num_beams=3,
    do_sample=False
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))

print(parsed_answer)

Another option, if you still need to use device_map, is to also check your version of transformers. With pip freeze I get these requirements (there's a lot of noise, sorry about that):

absl-py==2.1.0
accelerate==0.32.1
addict==2.4.0
aiohttp==3.9.5
aiosignal==1.3.1
aliyun-python-sdk-core==2.15.1
aliyun-python-sdk-kms==2.16.3
annotated-types==0.7.0
antlr4-python3-runtime==4.8
anyio==4.4.0
asttokens==2.4.1
astunparse==1.6.3
attrs==23.2.0
audioread==3.0.1
beautifulsoup4==4.12.3
bert-score==0.3.13
bitarray==2.9.2
blinker==1.8.2
blis==0.7.11
boto3==1.34.145
botocore==1.34.145
cachetools==5.3.3
cairocffi==1.7.1
CairoSVG==2.7.1
catalogue==2.0.10
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
chex==0.1.86
click==8.1.7
cloudpathlib==0.18.1
cloudpickle==2.2.1
colorama==0.4.6
comm==0.2.2
confection==0.1.5
contourpy==1.2.0
crcmod==1.7
cryptography==42.0.8
cssselect2==0.7.0
cycler==0.12.1
cymem==2.0.8
Cython==3.0.10
datasets==2.18.0
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
diagrams==0.23.4
dill==0.3.8
distro==1.9.0
dnspython==2.6.1
docker==7.1.0
einops==0.8.0
email_validator==2.1.1
etils==1.9.1
executing==2.0.1
fairseq==0.12.2
fastapi==0.111.0
fastapi-cli==0.0.4
fastjsonschema==2.20.0
filelock==3.14.0
Flask==3.0.3
Flask-Cors==4.0.1
flatbuffers==24.3.25
flax==0.8.4
fonttools==4.47.2
frozenlist==1.4.1
fsspec==2024.2.0
ftfy==6.2.0
gast==0.5.4
gensim==4.3.3
google==3.0.0
google-ai-generativelanguage==0.6.6
google-api-core==2.19.1
google-api-python-client==2.136.0
google-auth==2.31.0
google-auth-httplib2==0.2.0
google-crc32c==1.5.0
google-generativeai==0.7.1
google-pasta==0.2.0
googleapis-common-protos==1.63.2
graphviz==0.20.3
grpcio==1.64.1
grpcio-status==1.62.2
h11==0.14.0
h5py==3.11.0
httpcore==1.0.5
httplib2==0.22.0
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.0
hydra-core==1.0.7
idna==3.6
imageio==2.34.1
importlib-metadata==6.11.0
importlib_resources==6.4.0
intel-openmp==2021.4.0
ipykernel==6.29.4
ipython==8.24.0
itsdangerous==2.1.2
jax==0.4.28
jaxlib==0.4.28
jedi==0.19.1
Jinja2==3.1.4
jmespath==0.10.0
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client==8.6.1
jupyter_core==5.7.2
kagglehub==0.2.5
keras==3.3.3
kiwisolver==1.4.5
langcodes==3.4.0
language_data==1.2.0
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
llvmlite==0.42.0
lxml==5.2.2
marisa-trie==1.2.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mkl==2021.4.0
ml-dtypes==0.3.2
modelscope==1.15.0
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.5
multiprocess==0.70.16
murmurhash==1.0.10
namex==0.0.8
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nltk==3.8.1
numba==0.59.1
numpy==1.26.3
omegaconf==2.0.6
open-clip-torch==2.24.0
openai==1.30.5
opencv-python==4.10.0.82
opt-einsum==3.3.0
optax==0.2.2
optree==0.11.0
orbax-checkpoint==0.5.15
orjson==3.10.4
oss2==2.18.5
packaging==23.2
pandas==2.2.0
parso==0.8.4
pathos==0.3.2
pillow==10.2.0
platformdirs==4.2.2
pooch==1.8.2
portalocker==2.8.2
pox==0.3.4
ppft==1.7.6.8
preshed==3.0.9
prompt-toolkit==3.0.43
proto-plus==1.24.0
protobuf==4.25.1
psutil==5.9.8
pure-eval==0.2.2
pyarrow==16.1.0
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1_modules==0.4.0
pycocoevalcap==1.2
pycocotools==2.0.8
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.7.2
pydantic_core==2.18.3
Pygments==2.18.0
PyJWT==2.8.0
pyparsing==3.1.1
python-dateutil==2.8.2
python-dotenv==1.0.1
python-multipart==0.0.9
pytz==2023.3.post1
pywin32==306
PyYAML==6.0.1
pyzmq==26.0.3
referencing==0.35.1
regex==2024.5.15
requests==2.31.0
rich==13.7.1
rouge_score==0.1.2
rpds-py==0.19.0
rsa==4.9
s3transfer==0.10.2
sacrebleu==2.4.2
sacremoses==0.1.1
safetensors==0.4.3
sagemaker==2.226.1
schema==0.7.7
scikit-image==0.23.2
scikit-learn==1.5.0
scipy==1.13.0
seaborn==0.13.2
sentence-transformers==3.0.1
sentencepiece==0.2.0
setuptools==69.5.1
shellingham==1.5.4
simplejson==3.19.2
six==1.16.0
smart-open==7.0.4
smdebug-rulesconfig==1.0.1
sniffio==1.3.1
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.5
soxr==0.3.7
spacy==3.7.5
spacy-legacy==3.0.12
spacy-loggers==1.0.5
srsly==2.4.8
stack-data==0.6.3
starlette==0.37.2
sympy==1.12
tabulate==0.9.0
tbb==2021.12.0
tblib==3.0.0
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-hub==0.16.1
tensorflow-intel==2.16.1
tensorstore==0.1.60
termcolor==2.4.0
tf-slim==1.1.0
tf_keras==2.16.0
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.5.22
timm==1.0.3
tinycss2==1.3.0
tokenizers==0.19.1
tomli==2.0.1
toolz==0.12.1
torch==2.3.1
torchaudio==2.3.0
torchvision==0.18.1
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
transformers==4.41.2
typed_ast==1.5.5
typer==0.12.3
typing_extensions==4.9.0
tzdata==2023.4
ujson==5.10.0
uritemplate==4.1.1
urllib3==2.1.0
uvicorn==0.30.1
wasabi==1.1.3
watchdog==3.0.0
watchfiles==0.22.0
wcwidth==0.2.13
weasel==0.4.1
webencodings==0.5.1
websockets==12.0
Werkzeug==3.0.3
wheel==0.43.0
wrapt==1.16.0
xxhash==3.4.1
yapf==0.40.2
yarl==1.9.4
zipp==3.19.2

However, pinning to a specific version could lead to other dependency errors that keep the wheels from building, so I think you need to look carefully at which versions your code actually requires.
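
If you just want to compare your environment with mine before pinning anything, a quick check like this prints the packages that mattered most for me (the list of names is my own pick, based on the freeze above):

# Print the installed versions of the packages most relevant to this thread,
# so you can compare them against the pip freeze output above.
import importlib.metadata as md

for pkg in ("transformers", "torch", "accelerate", "timm", "einops"):
    try:
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "not installed")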

Let me know if this helped you!

Thanks for the guidance. In my case, on Windows 10, ONLY the following worked, where the CUDA-availability condition is removed:

import os
from unittest.mock import patch
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    # Unconditionally strip flash_attn from Florence-2's modeling file.
    if os.path.basename(filename) != "modeling_florence2.py":
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports

with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)

I think that is because AutoModelForCausalLM, when it detects Microsoft Florence, downloads the file modeling_florence2.py, and with that code you make sure the flash_attn import is removed from that file.
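
If you are unsure which file transformers is actually inspecting for a given model, you can log it while keeping the workaround in place. This is just a debugging sketch of the same patch idea, not something from the transformers API itself:

import os
from unittest.mock import patch

from transformers import AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

def logging_fixed_get_imports(filename):
    # Print every remote-code file that transformers scans for imports,
    # then strip flash_attn so CPU-only machines can keep loading.
    imports = get_imports(filename)
    print(os.path.basename(filename), "->", imports)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

with patch("transformers.dynamic_module_utils.get_imports", logging_fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-large", trust_remote_code=True)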

Thanks for this man!

I've adapted it to work with openbmb/MiniCPM-V-2_6 ( https://huggingface.co/openbmb/MiniCPM-V-2_6 ) :

from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import os
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import torch

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
    imports = get_imports(filename)
    if not torch.cuda.is_available() and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

model_name = 'openbmb/MiniCPM-V-2_6'

# create model
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map='auto' if torch.cuda.is_available() else None,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )


image = Image.open('1.png').convert('RGB')
question = 'What is in the image?'
msgs = [{'role': 'user', 'content': [image, question]}]

res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer
)
print(res)

.. hope this saves the next guy a few hours :)

Cheers!
