RuntimeError: FlashAttention only support fp16 and bf16 data type

#15
by Satandon1999 - opened

I am not able to resolve this error while trying to fine-tune this model.
I have loaded the model as bf16 by passing torch_dtype=torch.bfloat16 to from_pretrained, and I have also added an explicit cast with model = model.to(torch.bfloat16) (a minimal sketch of this setup is included below).
The exact same code works flawlessly for the 'mini' version of the model, but not for this one.
Any guidance would be greatly appreciated.
Thanks.
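
For reference, this is roughly how I am loading the model. The model id is a placeholder, and passing attn_implementation="flash_attention_2" is my assumption about how FlashAttention gets enabled for this checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder id -- substitute the actual checkpoint being fine-tuned.
model_id = "your-org/your-model"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Load directly in bf16; FlashAttention only accepts fp16/bf16 inputs,
# so the compute dtype has to stay bf16 throughout training.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # assumption: FA2 is selected this way
    trust_remote_code=True,
)

# Explicit cast (should be redundant when torch_dtype is already bf16).
model = model.to(torch.bfloat16)
```

I am also wondering whether the training arguments matter here (e.g. setting bf16=True in TrainingArguments so activations stay in bf16), since the error suggests some tensor is reaching FlashAttention in fp32.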
