import torch
from triton_flash_atn import _attention

# Define dimensions
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Create random input tensors for Q, K, V with shape (batch, heads, seq, head_dim)
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')

# Define whether the attention is causal and the softmax scaling factor
causal = False
sm_scale = 1.0 / (head_dim ** 0.5)

# Apply flash attention via the autograd Function's apply method
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)
print(output)
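
# Optional sanity check -- a sketch, assuming PyTorch >= 2.1 (for the `scale`
# argument of scaled_dot_product_attention). Compares the Triton kernel's output
# against PyTorch's built-in attention; expect small numerical differences in
# float16 because the two implementations accumulate in different orders.
import torch.nn.functional as F

reference = F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
print(torch.allclose(output, reference, atol=1e-2, rtol=0))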