request: Add flash attention 2.0 support for GPT2LMHeadModel

#75
by brresnic - opened
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    my_GPT2LMHeadModel_checkpoint,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

This throws the following error:

Error loading Flash_Model_2: GPT2LMHeadModel does not support Flash Attention 2.0 yet. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/new
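
A rough fallback sketch, assuming the failure surfaces as a ValueError: catch it and reload with the default attention implementation ("gpt2" below is only a stand-in for the actual checkpoint).

import torch
from transformers import AutoModelForCausalLM

# Rough sketch: try Flash Attention 2 first; if the architecture rejects it,
# fall back to the default attention implementation.
try:
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
except ValueError:
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.bfloat16,
    )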

Hi @brresnic
Thanks for your interest! There is an ongoing effort to add FA2 to GPT2 here: https://github.com/huggingface/transformers/pull/27479
Note however that since the model is relatively small, I don't expect very significant speedups from FA2 + gpt2
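
If the main goal is simply a faster attention kernel in the meantime, one option to try (assuming a transformers/torch combination where GPT-2 has an SDPA attention path) is PyTorch's scaled dot product attention:

import torch
from transformers import AutoModelForCausalLM

# Sketch of an interim alternative: PyTorch SDPA instead of FA2.
# Assumes a transformers/torch version where GPT-2 ships an SDPA attention path;
# "gpt2" is a stand-in for the actual checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)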
