import torch
from triton_flash_atn import _attention

# Define dimensions
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Create random input tensors for Q, K, V
q = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda')

# Define whether the attention is causal and the softmax scaling factor
causal = False
sm_scale = 1.0 / (head_dim ** 0.5)

# Apply flash attention via the autograd Function's `apply`
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)
print(output)
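
# Optional sanity check (a minimal sketch, assuming PyTorch >= 2.1 so that
# scaled_dot_product_attention accepts a `scale` argument): compare the Triton
# output against PyTorch's reference attention. The tolerance is a rough
# choice, since fp16 kernels accumulate rounding error differently.
import torch.nn.functional as F

ref = F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
torch.testing.assert_close(output, ref, atol=2e-2, rtol=0)
print("Triton flash attention output matches the PyTorch reference.")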