import pytest
import torch

import sdpa_flash
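
# These tests require the MPS backend (every tensor below is created with device="mps");
# a module-level guard like this skips the whole file when MPS is unavailable.
pytestmark = pytest.mark.skipif(
    not torch.backends.mps.is_available(),
    reason="sdpa_flash flash-attention tests require the MPS backend",
)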


def create_cu_seqlens(seq_lengths):
    """Create cumulative sequence lengths tensor."""
    cu_seqlens = [0]
    for length in seq_lengths:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")
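
# For example, create_cu_seqlens([8, 16, 12]) returns tensor([0, 8, 24, 36],
# dtype=torch.int32) on the MPS device; entries i and i + 1 bound sequence i in the
# packed token dimension.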


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_single_sequence(dtype, head_dim):
    """Test Flash Attention with a single sequence."""
    torch.manual_seed(42)

    seq_len = 32
    num_heads = 4

    cu_seqlens = create_cu_seqlens([seq_len])

    # Packed layout: (total_tokens, num_heads, head_dim); with a single sequence,
    # total_tokens == seq_len.
    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    # Run the fused varlen kernel (softcapping=1.0 leaves the logits uncapped).
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Naive per-head reference: softmax(Q K^T * scale) V.
    expected = torch.zeros_like(out)
    for h in range(num_heads):
        q_h = query[:, h, :]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Looser tolerances for reduced-precision dtypes (and for the larger head
    # dimensions under bfloat16, where accumulation error grows).
    if dtype == torch.bfloat16:
        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
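

# An equivalent, vectorized way to write the per-head reference used throughout this
# file, kept purely as an illustrative sketch: the tests keep their explicit loops and
# do not call this helper.
def _reference_attention_sketch(q, k, v, scale):
    """Naive attention over packed (tokens, heads, head_dim) tensors, all heads at once."""
    scores = torch.einsum("qhd,khd->hqk", q, k) * scale
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", weights, v)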


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_variable_lengths(dtype, head_dim):
    """Test Flash Attention with variable-length sequences."""
    torch.manual_seed(42)

    # Three sequences, with query and key lengths that differ per sequence.
    seq_lengths_q = [8, 16, 12]
    seq_lengths_k = [10, 20, 15]
    batch_size = len(seq_lengths_q)
    num_heads = 4

    cu_seqlens_q = create_cu_seqlens(seq_lengths_q)
    cu_seqlens_k = create_cu_seqlens(seq_lengths_k)

    total_q = sum(seq_lengths_q)
    total_k = sum(seq_lengths_k)
    max_seqlen_q = max(seq_lengths_q)
    max_seqlen_k = max(seq_lengths_k)

    query = torch.randn(total_q, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Reference: slice each sequence out of the packed tensors and attend per head.
    expected = torch.zeros_like(out)
    for i in range(batch_size):
        q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()

        q_i = query[q_start:q_end]
        k_i = key[k_start:k_end]
        v_i = value[k_start:k_end]

        for h in range(num_heads):
            q_h = q_i[:, h, :]
            k_h = k_i[:, h, :]
            v_h = v_i[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[q_start:q_end, h, :] = torch.matmul(attn_weights, v_h)

    if dtype == torch.bfloat16:
        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_causal(dtype, head_dim):
    """Test Flash Attention with causal masking."""
    torch.manual_seed(42)

    seq_lengths = [16, 24]
    batch_size = len(seq_lengths)
    num_heads = 4

    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        do_causal=True,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(out)
    for i in range(batch_size):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        seq_len = end - start

        q_i = query[start:end]
        k_i = key[start:end]
        v_i = value[start:end]

        for h in range(num_heads):
            q_h = q_i[:, h, :]
            k_h = k_i[:, h, :]
            v_h = v_i[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale

            # Causal mask: each query position may only attend to keys at the same or
            # earlier positions, so everything above the diagonal is masked out.
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
            scores.masked_fill_(causal_mask, float("-inf"))

            attn_weights = torch.softmax(scores, dim=-1)
            expected[start:end, h, :] = torch.matmul(attn_weights, v_h)

    if dtype == torch.bfloat16:
        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_gqa(dtype, head_dim):
    """Test Flash Attention with Grouped Query Attention."""
    torch.manual_seed(42)

    seq_len = 32
    num_heads = 8
    num_kv_heads = 2

    cu_seqlens = create_cu_seqlens([seq_len])

    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # GQA reference: each group of gqa_factor consecutive query heads shares
    # KV head h // gqa_factor.
    expected = torch.zeros_like(query)
    gqa_factor = num_heads // num_kv_heads

    for h in range(num_heads):
        kv_h = h // gqa_factor
        q_h = query[:, h, :]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    if dtype == torch.bfloat16:
        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_head_dimensions(head_dim):
    """Test Flash Attention with different supported head dimensions."""
    torch.manual_seed(42)

    seq_len = 16
    num_heads = 4

    cu_seqlens = create_cu_seqlens([seq_len])

    query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Smoke check only: the kernel ran and wrote non-zero output for this head_dim;
    # numerical accuracy is covered by the parametrized tests above.
    assert out.abs().max().item() > 0


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_head_dim(dtype):
    """Test Flash Attention with head_dim=128 specifically."""
    torch.manual_seed(42)

    seq_lengths = [32, 64]
    batch_size = len(seq_lengths)
    num_heads = 8
    head_dim = 128

    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(out)
    for i in range(batch_size):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()

        q_i = query[start:end]
        k_i = key[start:end]
        v_i = value[start:end]

        for h in range(num_heads):
            q_h = q_i[:, h, :]
            k_h = k_i[:, h, :]
            v_h = v_i[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[start:end, h, :] = torch.matmul(attn_weights, v_h)

    if dtype == torch.bfloat16:
        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_head_dim_causal(dtype):
    """Test Flash Attention with head_dim=128 and causal masking."""
    torch.manual_seed(42)

    seq_len = 48
    num_heads = 4
    head_dim = 128

    cu_seqlens = create_cu_seqlens([seq_len])

    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=True,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(out)

    for h in range(num_heads):
        q_h = query[:, h, :]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale

        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
        scores.masked_fill_(causal_mask, float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    if dtype == torch.bfloat16:
        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


def test_flash_attention_large_head_dim_gqa():
    """Test Flash Attention with head_dim=128 and GQA."""
    torch.manual_seed(42)

    seq_len = 32
    num_heads = 16
    num_kv_heads = 4
    head_dim = 128

    cu_seqlens = create_cu_seqlens([seq_len])

    query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(query)
    gqa_factor = num_heads // num_kv_heads

    for h in range(num_heads):
        kv_h = h // gqa_factor
        q_h = query[:, h, :]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)


def test_flash_attention_edge_cases():
    """Test Flash Attention edge cases."""
    torch.manual_seed(42)

    # Single query token attending to a single key token: the softmax weight is
    # exactly 1, so the output must equal `value`.
    query = torch.randn(1, 1, 64, device="mps")
    key = torch.randn(1, 1, 64, device="mps")
    value = torch.randn(1, 1, 64, device="mps")
    cu_seqlens = create_cu_seqlens([1])
    out = torch.empty_like(query)

    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=1,
        max_seqlen_k=1,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )

    torch.testing.assert_close(out, value, rtol=1e-5, atol=1e-5)

    # A zero-length sequence in the middle of the batch should be handled without error.
    seq_lengths = [8, 0, 12]
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)

    query = torch.randn(total_tokens, 4, 64, device="mps")
    key = torch.randn(total_tokens, 4, 64, device="mps")
    value = torch.randn(total_tokens, 4, 64, device="mps")
    out = torch.empty_like(query)

    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lengths) if seq_lengths else 0,
        max_seqlen_k=max(seq_lengths) if seq_lengths else 0,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )


def test_flash_attention_unsupported_cases():
    """Test that unsupported cases raise appropriate errors."""
    # head_dim=48 is not among the supported head dimensions and should be rejected.
    query = torch.randn(16, 4, 48, device="mps")
    key = torch.randn(16, 4, 48, device="mps")
    value = torch.randn(16, 4, 48, device="mps")
    cu_seqlens = create_cu_seqlens([16])
    out = torch.empty_like(query)

    with pytest.raises(RuntimeError, match="Head dimension .* is not supported"):
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=16,
            max_seqlen_k=16,
            do_causal=False,
            scale=0.144,
            softcapping=1.0,
        )

    # The varlen API takes no attention-mask argument; passing one must raise TypeError.
    query = torch.randn(16, 4, 64, device="mps")
    key = torch.randn(16, 4, 64, device="mps")
    value = torch.randn(16, 4, 64, device="mps")
    mask = torch.randn(1, 1, 16, 16, device="mps")
    cu_seqlens = create_cu_seqlens([16])
    out = torch.empty_like(query)

    with pytest.raises(TypeError):
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=16,
            max_seqlen_k=16,
            mask=mask,
            do_causal=False,
            scale=0.125,
            softcapping=1.0,
        )

    # cu_seqlens must be int32; with int64 the kernel is expected not to run, leaving
    # the sentinel values in `out` untouched.
    cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")

    out = torch.full_like(query, -999.0)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_wrong,
        cu_seqlens_k=cu_seqlens_wrong,
        max_seqlen_q=16,
        max_seqlen_k=16,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )

    assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_small_sequences(dtype, head_dim):
    """Test Flash Attention with small sequence lengths (2-8)."""
    torch.manual_seed(42)

    for seq_len in [2, 4, 6, 8]:
        num_heads = 4

        cu_seqlens = create_cu_seqlens([seq_len])

        query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
        key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
        value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

        scale = 1.0 / (head_dim ** 0.5)

        out = torch.empty_like(query)
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            do_causal=False,
            scale=scale,
            softcapping=1.0,
        )

        expected = torch.zeros_like(out)
        for h in range(num_heads):
            q_h = query[:, h, :]
            k_h = key[:, h, :]
            v_h = value[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[:, h, :] = torch.matmul(attn_weights, v_h)

        if dtype == torch.bfloat16:
            rtol, atol = 2e-2, 2e-2
        elif dtype == torch.float16:
            rtol, atol = 2e-3, 2e-3
        else:
            rtol, atol = 1e-3, 1e-3
        torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_cross_attention(dtype, head_dim):
    """Test Flash Attention with different q_seq and k_seq (cross-attention)."""
    torch.manual_seed(42)

    # (q_seq, k_seq) pairs covering q < k, q > k, and a single query token.
    test_cases = [
        (16, 32),
        (32, 16),
        (8, 128),
        (1, 64),
    ]

    for q_seq, k_seq in test_cases:
        num_heads = 4

        cu_seqlens_q = create_cu_seqlens([q_seq])
        cu_seqlens_k = create_cu_seqlens([k_seq])

        query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
        key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
        value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

        scale = 1.0 / (head_dim ** 0.5)

        out = torch.empty_like(query)
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=q_seq,
            max_seqlen_k=k_seq,
            do_causal=False,
            scale=scale,
            softcapping=1.0,
        )

        expected = torch.zeros_like(out)
        for h in range(num_heads):
            q_h = query[:, h, :]
            k_h = key[:, h, :]
            v_h = value[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[:, h, :] = torch.matmul(attn_weights, v_h)

        if dtype == torch.bfloat16:
            rtol, atol = 2e-2, 2e-2
        elif dtype == torch.float16:
            rtol, atol = 2e-3, 2e-3
        else:
            rtol, atol = 1e-3, 1e-3
        torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_sequences(dtype):
    """Test Flash Attention with large k_seq (>= 1024)."""
    torch.manual_seed(42)

    q_seq = 32
    k_seq = 2048
    num_heads = 4
    head_dim = 64

    cu_seqlens_q = create_cu_seqlens([q_seq])
    cu_seqlens_k = create_cu_seqlens([k_seq])

    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=q_seq,
        max_seqlen_k=k_seq,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(out)
    for h in range(num_heads):
        q_h = query[:, h, :]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Reductions over 2048 keys accumulate more rounding error than the short
    # sequences above, hence the looser tolerances.
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 3e-2
    elif dtype == torch.float16:
        rtol, atol = 5e-3, 5e-3
    else:
        rtol, atol = 2e-3, 2e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("gqa_ratio", [2, 4, 8])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128])
def test_flash_attention_gqa_ratios(gqa_ratio, head_dim):
    """Test Flash Attention with different GQA ratios."""
    torch.manual_seed(42)

    seq_len = 32
    num_heads = 16
    num_kv_heads = num_heads // gqa_ratio
    dtype = torch.float32

    cu_seqlens = create_cu_seqlens([seq_len])

    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(query)
    gqa_factor = num_heads // num_kv_heads

    for h in range(num_heads):
        kv_h = h // gqa_factor
        q_h = query[:, h, :]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)


def test_flash_attention_single_query_token():
    """Test Flash Attention with single query token (q_seq = 1)."""
    torch.manual_seed(42)

    q_seq = 1
    k_seq = 64
    num_heads = 8
    head_dim = 64
    dtype = torch.float32

    cu_seqlens_q = create_cu_seqlens([q_seq])
    cu_seqlens_k = create_cu_seqlens([k_seq])

    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=q_seq,
        max_seqlen_k=k_seq,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(out)
    for h in range(num_heads):
        q_h = query[:, h, :]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)


def test_flash_attn_varlen_func():
    """Test the flash_attn_varlen_func compatibility function."""
    torch.manual_seed(42)

    seq_lengths = [8, 12]
    num_heads = 4
    head_dim = 64

    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    q = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    k = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    v = torch.randn(total_tokens, num_heads, head_dim, device="mps")

    # softmax_scale=None lets the wrapper pick its own default scale; this test only
    # checks shapes and that the output is non-degenerate (accuracy is covered above).
    out = sdpa_flash.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
    )

    assert out.shape == q.shape
    assert out.abs().max().item() > 0

    out_causal = sdpa_flash.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=0.125,
        causal=True,
    )

    assert out_causal.shape == q.shape
    assert out_causal.abs().max().item() > 0


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_softcapping(dtype, head_dim):
    """Test Flash Attention with softcapping."""
    torch.manual_seed(42)

    seq_lengths = [32, 24]
    num_heads = 4
    softcapping = 50.0

    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        do_causal=False,
        scale=scale,
        softcapping=softcapping,
    )

    expected = torch.zeros_like(query)

    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        q_seq = query[start:end]
        k_seq = key[start:end]
        v_seq = value[start:end]

        for h in range(num_heads):
            q_h = q_seq[:, h, :]
            k_h = k_seq[:, h, :]
            v_h = v_seq[:, h, :]

            # Logit softcapping: cap * tanh(scaled_scores / cap), i.e. the scaled
            # scores are squashed into (-cap, cap) before the softmax.
            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * (scale / softcapping)
            scores = torch.tanh(scores) * softcapping
            attn_weights = torch.softmax(scores, dim=-1)
            expected[start:end, h, :] = torch.matmul(attn_weights, v_h)

    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 3e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-2, 2e-2
    else:
        rtol, atol = 1e-2, 1e-2
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_softcapping_edge_cases(dtype):
    """Test Flash Attention softcapping with edge cases."""
    torch.manual_seed(42)

    seq_len = 16
    num_heads = 2
    head_dim = 64

    cu_seqlens = create_cu_seqlens([seq_len])
    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    # softcapping=1.0 is treated as "no capping" (every uncapped test above passes 1.0),
    # so the result must match plain attention.
    out_no_cap = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out_no_cap,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    expected = torch.zeros_like(query)
    for h in range(num_heads):
        q_h = query[:, h, :]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    rtol, atol = (2e-2, 2e-2) if dtype != torch.float32 else (1e-3, 1e-3)
    torch.testing.assert_close(out_no_cap, expected, rtol=rtol, atol=atol)

    # A very large cap is numerically a no-op (1000 * tanh(x / 1000) ~= x for these
    # score magnitudes), so it should also match the uncapped reference.
    out_large_cap = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out_large_cap,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1000.0,
    )

    torch.testing.assert_close(out_large_cap, expected, rtol=rtol, atol=atol)
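

# A conventional convenience entry point so the file can also be run directly;
# pytest.main returns an exit code, which is handed back to the shell.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))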