Satandon1999 committed
Commit ed7de9a · 1 parent: f5a1a1b

Update triton_flash_blocksparse_attn.py


Add a change, similar to https://huggingface.co/THUDM/cogagent-chat-hf/blob/d519da3b191401234f4bd86ce1c287c61bc276a3/util.py#L210, that wraps the Triton kernel launch in `with torch.cuda.device(...)` to avoid the error
```ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)```

Files changed (1):
  1. triton_flash_blocksparse_attn.py (+25 −24)
triton_flash_blocksparse_attn.py CHANGED
```diff
@@ -611,30 +611,31 @@ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BL
     # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
     # {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
 
-    _fwd_kernel[grid](
-        q, k, v, sm_scale,
-        layout_crow_indices,
-        layout_col_indices,
-        layout_crow_indices.stride(0), layout_crow_indices.stride(1),
-        layout_col_indices.stride(0), layout_col_indices.stride(1),
-        tmp, L, m,
-        o,
-        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-        q.shape[0], q.shape[1], k.shape[2],
-        k.shape[2] - q.shape[2],
-        q_rounded_len,
-        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
-        EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
-        INFERENCE=inference,
-        NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel[grid](
+            q, k, v, sm_scale,
+            layout_crow_indices,
+            layout_col_indices,
+            layout_crow_indices.stride(0), layout_crow_indices.stride(1),
+            layout_col_indices.stride(0), layout_col_indices.stride(1),
+            tmp, L, m,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], k.shape[2],
+            k.shape[2] - q.shape[2],
+            q_rounded_len,
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
+            BLOCK_DMODEL=BLOCK_DMODEL,
+            EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
+            EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
+            INFERENCE=inference,
+            NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
     if inference:
         L, m = None, None
 
```
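For context: this `ValueError` can show up when the q/k/v tensors live on a CUDA device other than the one currently active in the process (e.g. `cuda:1` while `cuda:0` is current), so activating the tensors' own device around the Triton launch, as the diff above does, avoids it. Below is a minimal, self-contained sketch of the same pattern using a hypothetical copy kernel; the names `_copy_kernel` and `copy` are illustrative and not part of this repository.

```python
# Minimal sketch (hypothetical kernel/function names) of the pattern applied in
# this commit: make the *current* CUDA device match the device of the tensors
# before launching a Triton kernel, to avoid
#   ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance copies one BLOCK-sized slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)


def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    n = src.numel()
    grid = (triton.cdiv(n, 1024),)
    # The guard: activate the device the tensors actually live on (e.g. cuda:1)
    # for the duration of the launch, mirroring the change to _fwd_kernel above.
    with torch.cuda.device(src.device.index):
        _copy_kernel[grid](src, dst, n, BLOCK=1024)
    return dst


if __name__ == "__main__":
    # Use a second GPU when available to reproduce the multi-device situation.
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
    x = torch.randn(4096, device=dev)
    assert torch.equal(copy(x), x)
```

`torch.cuda.device(...)` is a context manager that switches the current device for the duration of the block and restores the previous one on exit, so wrapping the launch this way is also harmless when the tensors are already on the default GPU.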