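"""Tests for the variable-length Flash Attention (SDPA) Metal kernel in metal-flash-sdpa.

Each test packs one or more sequences into a single [total_tokens, num_heads, head_dim]
tensor, describes the packing with int32 cumulative sequence lengths (cu_seqlens), and
checks sdpa_flash.flash_attention_varlen against a per-head PyTorch reference. Coverage
includes causal masking, grouped-query attention (GQA), cross-attention, softcapping,
small and large sequence lengths, and the supported head dimensions and dtypes.
"""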
import torch
import pytest
import sdpa_flash
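# Optional guard: every test allocates tensors on the "mps" device, so skip the whole
# module when no MPS device is available rather than erroring out.
pytestmark = pytest.mark.skipif(
    not torch.backends.mps.is_available(), reason="requires an MPS device"
)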
def create_cu_seqlens(seq_lengths):
"""Create cumulative sequence lengths tensor."""
cu_seqlens = [0]
for length in seq_lengths:
cu_seqlens.append(cu_seqlens[-1] + length)
return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_single_sequence(dtype, head_dim):
"""Test Flash Attention with a single sequence."""
torch.manual_seed(42)
# Single sequence
seq_len = 32
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors in Flash Attention format
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth
# Flash Attention computes attention separately for each head
expected = torch.zeros_like(out)
for h in range(num_heads):
q_h = query[:, h, :] # [seq_len, head_dim]
k_h = key[:, h, :]
v_h = value[:, h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16 and float16)
if dtype == torch.bfloat16:
        # Looser tolerance for bfloat16 at larger head dimensions (head_dim >= 96)
rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_variable_lengths(dtype, head_dim):
"""Test Flash Attention with variable-length sequences."""
torch.manual_seed(42)
# Variable sequence lengths
seq_lengths_q = [8, 16, 12]
seq_lengths_k = [10, 20, 15]
batch_size = len(seq_lengths_q)
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens_q = create_cu_seqlens(seq_lengths_q)
cu_seqlens_k = create_cu_seqlens(seq_lengths_k)
total_q = sum(seq_lengths_q)
total_k = sum(seq_lengths_k)
max_seqlen_q = max(seq_lengths_q)
max_seqlen_k = max(seq_lengths_k)
# Create input tensors
query = torch.randn(total_q, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth for each sequence
expected = torch.zeros_like(out)
for i in range(batch_size):
q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i+1].item()
k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i+1].item()
q_i = query[q_start:q_end]
k_i = key[k_start:k_end]
v_i = value[k_start:k_end]
# Compute attention for each head separately
for h in range(num_heads):
q_h = q_i[:, h, :] # [seq_len, head_dim]
k_h = k_i[:, h, :]
v_h = v_i[:, h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[q_start:q_end, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16 and float16)
if dtype == torch.bfloat16:
# Higher tolerance for bfloat16 with variable length sequences
rtol, atol = 2e-2, 2e-2
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_causal(dtype, head_dim):
"""Test Flash Attention with causal masking."""
torch.manual_seed(42)
# Test dimensions
seq_lengths = [16, 24]
batch_size = len(seq_lengths)
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens(seq_lengths)
total_tokens = sum(seq_lengths)
max_seqlen = max(seq_lengths)
# Create input tensors
query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention with causal mask
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
do_causal=True,
scale=scale,
softcapping=1.0,
)
# Compute ground truth with causal mask
expected = torch.zeros_like(out)
for i in range(batch_size):
start, end = cu_seqlens[i].item(), cu_seqlens[i+1].item()
seq_len = end - start
q_i = query[start:end]
k_i = key[start:end]
v_i = value[start:end]
# Compute attention for each head separately
for h in range(num_heads):
q_h = q_i[:, h, :] # [seq_len, head_dim]
k_h = k_i[:, h, :]
v_h = v_i[:, h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
# Apply causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
scores.masked_fill_(causal_mask, float("-inf"))
attn_weights = torch.softmax(scores, dim=-1)
expected[start:end, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16 and float16)
if dtype == torch.bfloat16:
        # Looser tolerance for bfloat16 at larger head dimensions (head_dim >= 96)
rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_gqa(dtype, head_dim):
"""Test Flash Attention with Grouped Query Attention."""
torch.manual_seed(42)
# Test dimensions
seq_len = 32
num_heads = 8
num_kv_heads = 2 # GQA with 4:1 ratio
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth with GQA
# Each query head attends to its corresponding kv head (with repetition)
expected = torch.zeros_like(query)
gqa_factor = num_heads // num_kv_heads
for h in range(num_heads):
kv_h = h // gqa_factor
q_h = query[:, h, :] # [seq_len, head_dim]
k_h = key[:, kv_h, :]
v_h = value[:, kv_h, :]
scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16 and float16)
if dtype == torch.bfloat16:
        # Looser tolerance for bfloat16 at larger head dimensions (head_dim >= 96)
rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_head_dimensions(head_dim):
"""Test Flash Attention with different supported head dimensions."""
torch.manual_seed(42)
# Test dimensions
seq_len = 16
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors
query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
key = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
value = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Basic check that output is not zeros
assert out.abs().max().item() > 0
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_head_dim(dtype):
"""Test Flash Attention with head_dim=128 specifically."""
torch.manual_seed(42)
# Test dimensions with head_dim=128
seq_lengths = [32, 64]
batch_size = len(seq_lengths)
num_heads = 8
head_dim = 128
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens(seq_lengths)
total_tokens = sum(seq_lengths)
max_seqlen = max(seq_lengths)
# Create input tensors
query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth
expected = torch.zeros_like(out)
for i in range(batch_size):
start, end = cu_seqlens[i].item(), cu_seqlens[i+1].item()
q_i = query[start:end]
k_i = key[start:end]
v_i = value[start:end]
# Compute attention for each head separately
for h in range(num_heads):
q_h = q_i[:, h, :] # [seq_len, head_dim]
k_h = k_i[:, h, :]
v_h = v_i[:, h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[start:end, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16 with head_dim=128)
if dtype == torch.bfloat16:
# bfloat16 with head_dim=128 has known precision issues
rtol, atol = 2e-2, 2e-2
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_head_dim_causal(dtype):
"""Test Flash Attention with head_dim=128 and causal masking."""
torch.manual_seed(42)
# Test dimensions
seq_len = 48
num_heads = 4
head_dim = 128
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention with causal mask
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=True,
scale=scale,
softcapping=1.0,
)
# Compute ground truth with causal mask
expected = torch.zeros_like(out)
for h in range(num_heads):
q_h = query[:, h, :] # [seq_len, head_dim]
k_h = key[:, h, :]
v_h = value[:, h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
# Apply causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
scores.masked_fill_(causal_mask, float("-inf"))
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16 with head_dim=128)
if dtype == torch.bfloat16:
# bfloat16 with head_dim=128 has known precision issues
rtol, atol = 2e-2, 2e-2
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
def test_flash_attention_large_head_dim_gqa():
"""Test Flash Attention with head_dim=128 and GQA."""
torch.manual_seed(42)
# Test dimensions
seq_len = 32
num_heads = 16
num_kv_heads = 4 # GQA with 4:1 ratio
head_dim = 128
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors
query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth with GQA
expected = torch.zeros_like(query)
gqa_factor = num_heads // num_kv_heads
for h in range(num_heads):
kv_h = h // gqa_factor
q_h = query[:, h, :] # [seq_len, head_dim]
k_h = key[:, kv_h, :]
v_h = value[:, kv_h, :]
scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results
torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
def test_flash_attention_edge_cases():
"""Test Flash Attention edge cases."""
torch.manual_seed(42)
# Test 1: Single token sequence
query = torch.randn(1, 1, 64, device="mps")
key = torch.randn(1, 1, 64, device="mps")
value = torch.randn(1, 1, 64, device="mps")
cu_seqlens = create_cu_seqlens([1])
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=1,
max_seqlen_k=1,
do_causal=False,
scale=0.125,
softcapping=1.0,
)
# With single token, output should equal value
torch.testing.assert_close(out, value, rtol=1e-5, atol=1e-5)
# Test 2: Empty sequence in batch
seq_lengths = [8, 0, 12] # Middle sequence is empty
cu_seqlens = create_cu_seqlens(seq_lengths)
total_tokens = sum(seq_lengths)
query = torch.randn(total_tokens, 4, 64, device="mps")
key = torch.randn(total_tokens, 4, 64, device="mps")
value = torch.randn(total_tokens, 4, 64, device="mps")
out = torch.empty_like(query)
# This should handle empty sequences gracefully
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max(seq_lengths) if seq_lengths else 0,
max_seqlen_k=max(seq_lengths) if seq_lengths else 0,
do_causal=False,
scale=0.125,
softcapping=1.0,
)
def test_flash_attention_unsupported_cases():
"""Test that unsupported cases raise appropriate errors."""
# Test 1: Unsupported head dimension
query = torch.randn(16, 4, 48, device="mps") # head_dim = 48 (not supported)
key = torch.randn(16, 4, 48, device="mps")
value = torch.randn(16, 4, 48, device="mps")
cu_seqlens = create_cu_seqlens([16])
out = torch.empty_like(query)
with pytest.raises(RuntimeError, match="Head dimension .* is not supported"):
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=16,
max_seqlen_k=16,
do_causal=False,
scale=0.144,
softcapping=1.0,
)
# Test 2: Calling function with wrong number of arguments
query = torch.randn(16, 4, 64, device="mps")
key = torch.randn(16, 4, 64, device="mps")
value = torch.randn(16, 4, 64, device="mps")
mask = torch.randn(1, 1, 16, 16, device="mps")
cu_seqlens = create_cu_seqlens([16])
out = torch.empty_like(query)
    # The function signature no longer accepts a mask parameter
with pytest.raises(TypeError):
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=16,
max_seqlen_k=16,
mask=mask, # This parameter doesn't exist anymore
do_causal=False,
scale=0.125,
softcapping=1.0,
)
# Test 3: Wrong dtype for cu_seqlens (should be int32)
cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")
# This will silently fail (output will be unchanged)
# We can detect this by initializing output to a known value
out = torch.full_like(query, -999.0)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens_wrong,
cu_seqlens_k=cu_seqlens_wrong,
max_seqlen_q=16,
max_seqlen_k=16,
do_causal=False,
scale=0.125,
softcapping=1.0,
)
# Check that output wasn't modified (kernel didn't run)
assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_small_sequences(dtype, head_dim):
"""Test Flash Attention with small sequence lengths (2-8)."""
torch.manual_seed(42)
# Test different small sequence lengths
for seq_len in [2, 4, 6, 8]:
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth
expected = torch.zeros_like(out)
for h in range(num_heads):
q_h = query[:, h, :] # [seq_len, head_dim]
k_h = key[:, h, :]
v_h = value[:, h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16)
if dtype == torch.bfloat16:
rtol, atol = 2e-2, 2e-2
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_cross_attention(dtype, head_dim):
"""Test Flash Attention with different q_seq and k_seq (cross-attention)."""
torch.manual_seed(42)
# Test various q_seq, k_seq combinations
test_cases = [
(16, 32), # q_seq < k_seq
(32, 16), # q_seq > k_seq
(8, 128), # large difference
(1, 64), # single query token
]
for q_seq, k_seq in test_cases:
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens_q = create_cu_seqlens([q_seq])
cu_seqlens_k = create_cu_seqlens([k_seq])
# Create input tensors
query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_seq,
max_seqlen_k=k_seq,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth
expected = torch.zeros_like(out)
for h in range(num_heads):
q_h = query[:, h, :] # [q_seq, head_dim]
k_h = key[:, h, :] # [k_seq, head_dim]
v_h = value[:, h, :] # [k_seq, head_dim]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16)
if dtype == torch.bfloat16:
rtol, atol = 2e-2, 2e-2
elif dtype == torch.float16:
rtol, atol = 2e-3, 2e-3
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_sequences(dtype):
"""Test Flash Attention with large k_seq (>= 1024)."""
torch.manual_seed(42)
# Test dimensions - large k_seq to test 2-pass algorithms
q_seq = 32
k_seq = 2048 # Large k_seq
num_heads = 4
head_dim = 64 # Use smaller head_dim to avoid memory issues
# Create cumulative sequence lengths
cu_seqlens_q = create_cu_seqlens([q_seq])
cu_seqlens_k = create_cu_seqlens([k_seq])
# Create input tensors
query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_seq,
max_seqlen_k=k_seq,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth
expected = torch.zeros_like(out)
for h in range(num_heads):
q_h = query[:, h, :] # [q_seq, head_dim]
k_h = key[:, h, :] # [k_seq, head_dim]
v_h = value[:, h, :] # [k_seq, head_dim]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for large sequences)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 3e-2
elif dtype == torch.float16:
rtol, atol = 5e-3, 5e-3
else:
rtol, atol = 2e-3, 2e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("gqa_ratio", [2, 4, 8])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128])
def test_flash_attention_gqa_ratios(gqa_ratio, head_dim):
"""Test Flash Attention with different GQA ratios."""
torch.manual_seed(42)
# Test dimensions
seq_len = 32
num_heads = 16
num_kv_heads = num_heads // gqa_ratio
dtype = torch.float32
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth with GQA
expected = torch.zeros_like(query)
gqa_factor = num_heads // num_kv_heads
for h in range(num_heads):
kv_h = h // gqa_factor
q_h = query[:, h, :] # [seq_len, head_dim]
k_h = key[:, kv_h, :]
v_h = value[:, kv_h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Check results
torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
def test_flash_attention_single_query_token():
"""Test Flash Attention with single query token (q_seq = 1)."""
torch.manual_seed(42)
# Test dimensions - single query token
q_seq = 1
k_seq = 64
num_heads = 8
head_dim = 64
dtype = torch.float32
# Create cumulative sequence lengths
cu_seqlens_q = create_cu_seqlens([q_seq])
cu_seqlens_k = create_cu_seqlens([k_seq])
# Create input tensors
query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_seq,
max_seqlen_k=k_seq,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# With single token, output should be weighted average of values
expected = torch.zeros_like(out)
for h in range(num_heads):
q_h = query[:, h, :] # [1, head_dim]
k_h = key[:, h, :] # [k_seq, head_dim]
v_h = value[:, h, :] # [k_seq, head_dim]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
def test_flash_attn_varlen_func():
"""Test the flash_attn_varlen_func compatibility function."""
torch.manual_seed(42)
# Test dimensions
seq_lengths = [8, 12]
num_heads = 4
head_dim = 64
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens(seq_lengths)
total_tokens = sum(seq_lengths)
max_seqlen = max(seq_lengths)
# Create input tensors
q = torch.randn(total_tokens, num_heads, head_dim, device="mps")
k = torch.randn(total_tokens, num_heads, head_dim, device="mps")
v = torch.randn(total_tokens, num_heads, head_dim, device="mps")
# Call the compatibility function
out = sdpa_flash.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
softmax_scale=None, # Will use 1/sqrt(head_dim)
causal=False,
)
# Check that output has correct shape and is not zeros
assert out.shape == q.shape
assert out.abs().max().item() > 0
# Test with causal
out_causal = sdpa_flash.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
softmax_scale=0.125,
causal=True,
)
assert out_causal.shape == q.shape
assert out_causal.abs().max().item() > 0
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_softcapping(dtype, head_dim):
"""Test Flash Attention with softcapping."""
torch.manual_seed(42)
# Test dimensions
seq_lengths = [32, 24]
num_heads = 4
softcapping = 50.0
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens(seq_lengths)
total_tokens = sum(seq_lengths)
max_seqlen = max(seq_lengths)
# Create input tensors
query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention with softcapping
out = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
do_causal=False,
scale=scale,
softcapping=softcapping,
)
# Compute ground truth with softcapping
# The kernel applies: softmax(tanh(qk^T*scale/cap)*cap)v
expected = torch.zeros_like(query)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
q_seq = query[start:end]
k_seq = key[start:end]
v_seq = value[start:end]
for h in range(num_heads):
q_h = q_seq[:, h, :]
k_h = k_seq[:, h, :]
v_h = v_seq[:, h, :]
# Apply softcapping formula
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * (scale / softcapping)
scores = torch.tanh(scores) * softcapping
attn_weights = torch.softmax(scores, dim=-1)
expected[start:end, h, :] = torch.matmul(attn_weights, v_h)
# Check results (higher tolerance for bfloat16 and softcapping)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 3e-2
elif dtype == torch.float16:
rtol, atol = 2e-2, 2e-2
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_softcapping_edge_cases(dtype):
"""Test Flash Attention softcapping with edge cases."""
torch.manual_seed(42)
# Test with softcapping = 1.0 (no softcapping)
seq_len = 16
num_heads = 2
head_dim = 64
cu_seqlens = create_cu_seqlens([seq_len])
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
scale = 1.0 / (head_dim ** 0.5)
# With softcapping = 1.0 (no effect)
out_no_cap = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out_no_cap,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Regular computation without softcapping
expected = torch.zeros_like(query)
for h in range(num_heads):
q_h = query[:, h, :]
k_h = key[:, h, :]
v_h = value[:, h, :]
scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
attn_weights = torch.softmax(scores, dim=-1)
expected[:, h, :] = torch.matmul(attn_weights, v_h)
# Should be identical when softcapping = 1.0
rtol, atol = (2e-2, 2e-2) if dtype != torch.float32 else (1e-3, 1e-3)
torch.testing.assert_close(out_no_cap, expected, rtol=rtol, atol=atol)
# Test with very large softcapping value
out_large_cap = torch.empty_like(query)
sdpa_flash.flash_attention_varlen(
out=out_large_cap,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1000.0,
)
# With very large softcapping, should be close to no softcapping
torch.testing.assert_close(out_large_cap, expected, rtol=rtol, atol=atol)