import torch
import pytest

import sdpa_flash


def create_cu_seqlens(seq_lengths):
    """Create cumulative sequence lengths tensor."""
    cu_seqlens = [0]
    for length in seq_lengths:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_single_sequence(dtype, head_dim):
    """Test Flash Attention with a single sequence."""
    torch.manual_seed(42)

    # Single sequence
    seq_len = 32
    num_heads = 4

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])

    # Create input tensors in Flash Attention format
    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth
    # Flash Attention computes attention separately for each head
    expected = torch.zeros_like(out)
    for h in range(num_heads):
        q_h = query[:, h, :]  # [seq_len, head_dim]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for bfloat16 and float16)
    if dtype == torch.bfloat16:
        # Higher tolerance for head_dim=128 with bfloat16
        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_variable_lengths(dtype, head_dim):
    """Test Flash Attention with variable-length sequences."""
    torch.manual_seed(42)

    # Variable sequence lengths
    seq_lengths_q = [8, 16, 12]
    seq_lengths_k = [10, 20, 15]
    batch_size = len(seq_lengths_q)
    num_heads = 4

    # Create cumulative sequence lengths
    cu_seqlens_q = create_cu_seqlens(seq_lengths_q)
    cu_seqlens_k = create_cu_seqlens(seq_lengths_k)

    total_q = sum(seq_lengths_q)
    total_k = sum(seq_lengths_k)
    max_seqlen_q = max(seq_lengths_q)
    max_seqlen_k = max(seq_lengths_k)

    # Create input tensors
    query = torch.randn(total_q, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth for each sequence
    expected = torch.zeros_like(out)
    for i in range(batch_size):
        q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()

        q_i = query[q_start:q_end]
        k_i = key[k_start:k_end]
        v_i = value[k_start:k_end]

        # Compute attention for each head separately
        for h in range(num_heads):
            q_h = q_i[:, h, :]  # [seq_len, head_dim]
            k_h = k_i[:, h, :]
            v_h = v_i[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[q_start:q_end, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for bfloat16 and float16)
    if dtype == torch.bfloat16:
        # Higher tolerance for bfloat16 with variable length sequences
        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_causal(dtype, head_dim):
    """Test Flash Attention with causal masking."""
    torch.manual_seed(42)

    # Test dimensions
    seq_lengths = [16, 24]
    batch_size = len(seq_lengths)
    num_heads = 4

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    # Create input tensors
    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention with causal mask
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        do_causal=True,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth with causal mask
    expected = torch.zeros_like(out)
    for i in range(batch_size):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        seq_len = end - start

        q_i = query[start:end]
        k_i = key[start:end]
        v_i = value[start:end]

        # Compute attention for each head separately
        for h in range(num_heads):
            q_h = q_i[:, h, :]  # [seq_len, head_dim]
            k_h = k_i[:, h, :]
            v_h = v_i[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale

            # Apply causal mask
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device="mps"), diagonal=1
            ).bool()
            scores.masked_fill_(causal_mask, float("-inf"))

            attn_weights = torch.softmax(scores, dim=-1)
            expected[start:end, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for bfloat16 and float16)
    if dtype == torch.bfloat16:
        # Higher tolerance for head_dim=128 with bfloat16
        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_gqa(dtype, head_dim):
    """Test Flash Attention with Grouped Query Attention."""
    torch.manual_seed(42)

    # Test dimensions
    seq_len = 32
    num_heads = 8
    num_kv_heads = 2  # GQA with 4:1 ratio

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])

    # Create input tensors
    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth with GQA
    # Each query head attends to its corresponding kv head (with repetition)
    expected = torch.zeros_like(query)
    gqa_factor = num_heads // num_kv_heads
    for h in range(num_heads):
        kv_h = h // gqa_factor

        q_h = query[:, h, :]  # [seq_len, head_dim]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for bfloat16 and float16)
    if dtype == torch.bfloat16:
        # Higher tolerance for head_dim=128 with bfloat16
        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_head_dimensions(head_dim):
    """Test Flash Attention with different supported head dimensions."""
    torch.manual_seed(42)

    # Test dimensions
    seq_len = 16
    num_heads = 4

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])

    # Create input tensors
    query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Basic check that output is not zeros
    assert out.abs().max().item() > 0


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_head_dim(dtype):
    """Test Flash Attention with head_dim=128 specifically."""
    torch.manual_seed(42)

    # Test dimensions with head_dim=128
    seq_lengths = [32, 64]
    batch_size = len(seq_lengths)
    num_heads = 8
    head_dim = 128

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    # Create input tensors
    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth
    expected = torch.zeros_like(out)
    for i in range(batch_size):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()

        q_i = query[start:end]
        k_i = key[start:end]
        v_i = value[start:end]

        # Compute attention for each head separately
        for h in range(num_heads):
            q_h = q_i[:, h, :]  # [seq_len, head_dim]
            k_h = k_i[:, h, :]
            v_h = v_i[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[start:end, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for bfloat16 with head_dim=128)
    if dtype == torch.bfloat16:
        # bfloat16 with head_dim=128 has known precision issues
        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_head_dim_causal(dtype):
    """Test Flash Attention with head_dim=128 and causal masking."""
    torch.manual_seed(42)

    # Test dimensions
    seq_len = 48
    num_heads = 4
    head_dim = 128

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])

    # Create input tensors
    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention with causal mask
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=True,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth with causal mask
    expected = torch.zeros_like(out)
    for h in range(num_heads):
        q_h = query[:, h, :]  # [seq_len, head_dim]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale

        # Apply causal mask
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device="mps"), diagonal=1
        ).bool()
        scores.masked_fill_(causal_mask, float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for bfloat16 with head_dim=128)
    if dtype == torch.bfloat16:
        # bfloat16 with head_dim=128 has known precision issues
        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-3, 2e-3
    else:
        rtol, atol = 1e-3, 1e-3

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


def test_flash_attention_large_head_dim_gqa():
    """Test Flash Attention with head_dim=128 and GQA."""
    torch.manual_seed(42)

    # Test dimensions
    seq_len = 32
    num_heads = 16
    num_kv_heads = 4  # GQA with 4:1 ratio
    head_dim = 128

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])

    # Create input tensors
    query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth with GQA
    expected = torch.zeros_like(query)
    gqa_factor = num_heads // num_kv_heads
    for h in range(num_heads):
        kv_h = h // gqa_factor

        q_h = query[:, h, :]  # [seq_len, head_dim]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Check results
    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)


def test_flash_attention_edge_cases():
    """Test Flash Attention edge cases."""
    torch.manual_seed(42)

    # Test 1: Single token sequence
    query = torch.randn(1, 1, 64, device="mps")
    key = torch.randn(1, 1, 64, device="mps")
    value = torch.randn(1, 1, 64, device="mps")
    cu_seqlens = create_cu_seqlens([1])

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=1,
        max_seqlen_k=1,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )

    # With a single token, output should equal value
    torch.testing.assert_close(out, value, rtol=1e-5, atol=1e-5)

    # Test 2: Empty sequence in batch
    seq_lengths = [8, 0, 12]  # Middle sequence is empty
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)

    query = torch.randn(total_tokens, 4, 64, device="mps")
    key = torch.randn(total_tokens, 4, 64, device="mps")
    value = torch.randn(total_tokens, 4, 64, device="mps")

    out = torch.empty_like(query)
    # This should handle empty sequences gracefully
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lengths) if seq_lengths else 0,
        max_seqlen_k=max(seq_lengths) if seq_lengths else 0,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )


def test_flash_attention_unsupported_cases():
    """Test that unsupported cases raise appropriate errors."""
    # Test 1: Unsupported head dimension
    query = torch.randn(16, 4, 48, device="mps")  # head_dim = 48 (not supported)
    key = torch.randn(16, 4, 48, device="mps")
    value = torch.randn(16, 4, 48, device="mps")
    cu_seqlens = create_cu_seqlens([16])

    out = torch.empty_like(query)
    with pytest.raises(RuntimeError, match="Head dimension .* is not supported"):
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=16,
            max_seqlen_k=16,
            do_causal=False,
            scale=0.144,
            softcapping=1.0,
        )

    # Test 2: Calling function with wrong number of arguments
    query = torch.randn(16, 4, 64, device="mps")
    key = torch.randn(16, 4, 64, device="mps")
    value = torch.randn(16, 4, 64, device="mps")
    mask = torch.randn(1, 1, 16, 16, device="mps")
    cu_seqlens = create_cu_seqlens([16])

    out = torch.empty_like(query)
    # The function signature no longer accepts a mask parameter
    with pytest.raises(TypeError):
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=16,
            max_seqlen_k=16,
            mask=mask,  # This parameter doesn't exist anymore
            do_causal=False,
            scale=0.125,
            softcapping=1.0,
        )

    # Test 3: Wrong dtype for cu_seqlens (should be int32)
    cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")

    # This will silently fail (output will be unchanged)
    # We can detect this by initializing output to a known value
    out = torch.full_like(query, -999.0)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_wrong,
        cu_seqlens_k=cu_seqlens_wrong,
        max_seqlen_q=16,
        max_seqlen_k=16,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )

    # Check that output wasn't modified (kernel didn't run)
    assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_small_sequences(dtype, head_dim):
    """Test Flash Attention with small sequence lengths (2-8)."""
    torch.manual_seed(42)

    # Test different small sequence lengths
    for seq_len in [2, 4, 6, 8]:
        num_heads = 4

        # Create cumulative sequence lengths
        cu_seqlens = create_cu_seqlens([seq_len])

        # Create input tensors
        query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
        key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
        value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

        # Scale factor
        scale = 1.0 / (head_dim ** 0.5)

        # Call Flash Attention
        out = torch.empty_like(query)
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            do_causal=False,
            scale=scale,
            softcapping=1.0,
        )

        # Compute ground truth
        expected = torch.zeros_like(out)
        for h in range(num_heads):
            q_h = query[:, h, :]  # [seq_len, head_dim]
            k_h = key[:, h, :]
            v_h = value[:, h, :]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[:, h, :] = torch.matmul(attn_weights, v_h)

        # Check results (higher tolerance for bfloat16)
        if dtype == torch.bfloat16:
            rtol, atol = 2e-2, 2e-2
        elif dtype == torch.float16:
            rtol, atol = 2e-3, 2e-3
        else:
            rtol, atol = 1e-3, 1e-3

        torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_cross_attention(dtype, head_dim):
    """Test Flash Attention with different q_seq and k_seq (cross-attention)."""
    torch.manual_seed(42)

    # Test various q_seq, k_seq combinations
    test_cases = [
        (16, 32),  # q_seq < k_seq
        (32, 16),  # q_seq > k_seq
        (8, 128),  # large difference
        (1, 64),   # single query token
    ]

    for q_seq, k_seq in test_cases:
        num_heads = 4

        # Create cumulative sequence lengths
        cu_seqlens_q = create_cu_seqlens([q_seq])
        cu_seqlens_k = create_cu_seqlens([k_seq])

        # Create input tensors
        query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
        key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
        value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

        # Scale factor
        scale = 1.0 / (head_dim ** 0.5)

        # Call Flash Attention
        out = torch.empty_like(query)
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=q_seq,
            max_seqlen_k=k_seq,
            do_causal=False,
            scale=scale,
            softcapping=1.0,
        )

        # Compute ground truth
        expected = torch.zeros_like(out)
        for h in range(num_heads):
            q_h = query[:, h, :]  # [q_seq, head_dim]
            k_h = key[:, h, :]    # [k_seq, head_dim]
            v_h = value[:, h, :]  # [k_seq, head_dim]

            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
            attn_weights = torch.softmax(scores, dim=-1)
            expected[:, h, :] = torch.matmul(attn_weights, v_h)

        # Check results (higher tolerance for bfloat16)
        if dtype == torch.bfloat16:
            rtol, atol = 2e-2, 2e-2
        elif dtype == torch.float16:
            rtol, atol = 2e-3, 2e-3
        else:
            rtol, atol = 1e-3, 1e-3

        torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_large_sequences(dtype):
    """Test Flash Attention with large k_seq (>= 1024)."""
    torch.manual_seed(42)

    # Test dimensions - large k_seq to test 2-pass algorithms
    q_seq = 32
    k_seq = 2048  # Large k_seq
    num_heads = 4
    head_dim = 64  # Use smaller head_dim to avoid memory issues

    # Create cumulative sequence lengths
    cu_seqlens_q = create_cu_seqlens([q_seq])
    cu_seqlens_k = create_cu_seqlens([k_seq])

    # Create input tensors
    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=q_seq,
        max_seqlen_k=k_seq,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth
    expected = torch.zeros_like(out)
    for h in range(num_heads):
        q_h = query[:, h, :]  # [q_seq, head_dim]
        k_h = key[:, h, :]    # [k_seq, head_dim]
        v_h = value[:, h, :]  # [k_seq, head_dim]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for large sequences)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 3e-2
    elif dtype == torch.float16:
        rtol, atol = 5e-3, 5e-3
    else:
        rtol, atol = 2e-3, 2e-3

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("gqa_ratio", [2, 4, 8])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128])
def test_flash_attention_gqa_ratios(gqa_ratio, head_dim):
    """Test Flash Attention with different GQA ratios."""
    torch.manual_seed(42)

    # Test dimensions
    seq_len = 32
    num_heads = 16
    num_kv_heads = num_heads // gqa_ratio
    dtype = torch.float32

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])

    # Create input tensors
    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth with GQA
    expected = torch.zeros_like(query)
    gqa_factor = num_heads // num_kv_heads
    for h in range(num_heads):
        kv_h = h // gqa_factor

        q_h = query[:, h, :]  # [seq_len, head_dim]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Check results
    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)


def test_flash_attention_single_query_token():
    """Test Flash Attention with a single query token (q_seq = 1)."""
    torch.manual_seed(42)

    # Test dimensions - single query token
    q_seq = 1
    k_seq = 64
    num_heads = 8
    head_dim = 64
    dtype = torch.float32

    # Create cumulative sequence lengths
    cu_seqlens_q = create_cu_seqlens([q_seq])
    cu_seqlens_k = create_cu_seqlens([k_seq])

    # Create input tensors
    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=q_seq,
        max_seqlen_k=k_seq,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # With a single query token, the output should be a weighted average of the values
    expected = torch.zeros_like(out)
    for h in range(num_heads):
        q_h = query[:, h, :]  # [1, head_dim]
        k_h = key[:, h, :]    # [k_seq, head_dim]
        v_h = value[:, h, :]  # [k_seq, head_dim]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)


def test_flash_attn_varlen_func():
    """Test the flash_attn_varlen_func compatibility function."""
    torch.manual_seed(42)

    # Test dimensions
    seq_lengths = [8, 12]
    num_heads = 4
    head_dim = 64

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    # Create input tensors
    q = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    k = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    v = torch.randn(total_tokens, num_heads, head_dim, device="mps")

    # Call the compatibility function
    out = sdpa_flash.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=None,  # Will use 1/sqrt(head_dim)
        causal=False,
    )

    # Check that output has correct shape and is not zeros
    assert out.shape == q.shape
    assert out.abs().max().item() > 0

    # Test with causal
    out_causal = sdpa_flash.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=0.125,
        causal=True,
    )

    assert out_causal.shape == q.shape
    assert out_causal.abs().max().item() > 0


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
def test_flash_attention_softcapping(dtype, head_dim):
    """Test Flash Attention with softcapping."""
    torch.manual_seed(42)

    # Test dimensions
    seq_lengths = [32, 24]
    num_heads = 4
    softcapping = 50.0

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    # Create input tensors
    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

    # Call Flash Attention with softcapping
    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        do_causal=False,
        scale=scale,
        softcapping=softcapping,
    )

    # Compute ground truth with softcapping
    # The kernel applies: softmax(tanh(q @ k^T * scale / cap) * cap) @ v
    expected = torch.zeros_like(query)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        q_seq = query[start:end]
        k_seq = key[start:end]
        v_seq = value[start:end]

        for h in range(num_heads):
            q_h = q_seq[:, h, :]
            k_h = k_seq[:, h, :]
            v_h = v_seq[:, h, :]

            # Apply softcapping formula
            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * (scale / softcapping)
            scores = torch.tanh(scores) * softcapping
            attn_weights = torch.softmax(scores, dim=-1)
            expected[start:end, h, :] = torch.matmul(attn_weights, v_h)

    # Check results (higher tolerance for bfloat16 and softcapping)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 3e-2
    elif dtype == torch.float16:
        rtol, atol = 2e-2, 2e-2
    else:
        rtol, atol = 1e-2, 1e-2

    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_flash_attention_softcapping_edge_cases(dtype):
    """Test Flash Attention softcapping with edge cases."""
    torch.manual_seed(42)

    # Test with softcapping = 1.0 (no softcapping)
    seq_len = 16
    num_heads = 2
    head_dim = 64

    cu_seqlens = create_cu_seqlens([seq_len])

    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    # With softcapping = 1.0 (no effect)
    out_no_cap = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out_no_cap,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # Regular computation without softcapping
    expected = torch.zeros_like(query)
    for h in range(num_heads):
        q_h = query[:, h, :]
        k_h = key[:, h, :]
        v_h = value[:, h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    # Should be identical when softcapping = 1.0
    rtol, atol = (2e-2, 2e-2) if dtype != torch.float32 else (1e-3, 1e-3)
    torch.testing.assert_close(out_no_cap, expected, rtol=rtol, atol=atol)

    # Test with very large softcapping value
    out_large_cap = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out_large_cap,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1000.0,
    )

    # With very large softcapping, should be close to no softcapping
    torch.testing.assert_close(out_large_cap, expected, rtol=rtol, atol=atol)
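

# Optional cross-check sketch (not part of the kernel's documented contract):
# for a single dense sequence, the varlen kernel should agree with PyTorch's
# built-in scaled_dot_product_attention. This assumes the installed torch build
# provides torch.nn.functional.scaled_dot_product_attention on MPS and uses its
# default scale of 1/sqrt(head_dim), which matches the scale passed below.
def test_flash_attention_matches_torch_sdpa():
    """Sanity-check flash_attention_varlen against torch's reference SDPA (sketch)."""
    torch.manual_seed(42)

    seq_len = 64
    num_heads = 4
    head_dim = 64

    cu_seqlens = create_cu_seqlens([seq_len])

    query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    key = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
    value = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    # torch SDPA takes [..., seq, head_dim]; move heads in front as the batch dim,
    # then move them back so the reference matches the [seq, heads, head_dim] layout.
    ref = torch.nn.functional.scaled_dot_product_attention(
        query.transpose(0, 1),
        key.transpose(0, 1),
        value.transpose(0, 1),
    ).transpose(0, 1)

    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)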