CheckpointError in `triton_flash_blocksparse_attn.py` while finetuning

#18
by FremyCompany - opened

While trying to finetune this model, I encountered an error with the backward pass:

  File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/triton_flash_blocksparse_attn.py", line 904, in backward
    return _backward(ctx, do, *backward_layout)[:4]
  File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/triton_flash_blocksparse_attn.py", line 655, in _backward
    q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 1118, in unpack_hook
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Unpack is being triggered for a tensor that was already unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do so only once. Otherwise please open an issue with details on your use case.

Any idea how I could fix this issue?

Never mind. After digging through the other issues in the Phi-3 repositories, I found that the error is caused by setting use_reentrant=False in the Trainer's gradient-checkpointing configuration, while use_reentrant=True is apparently required for Phi-3-small; a sketch of the change is below.
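For anyone hitting the same traceback, here is a minimal sketch of the fix, assuming the standard transformers Trainer API (gradient_checkpointing_kwargs is available in recent transformers releases; the output directory name is a placeholder):

```python
from transformers import TrainingArguments

# Minimal sketch: enable gradient checkpointing with the reentrant
# implementation, which the blocksparse attention backward pass expects.
training_args = TrainingArguments(
    output_dir="phi3-small-finetune",  # placeholder output directory
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
)

# Pass training_args to your existing Trainer(...) call as usual.
```

Equivalently, if you enable checkpointing on the model directly, calling model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) should have the same effect in recent transformers versions.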
