#!/usr/bin/env python3
"""Benchmark script for metal-sdpa-flash (Flash SDPA)"""

import torch
import time
import sdpa_flash
from typing import List, Tuple
import numpy as np


def create_cu_seqlens(seq_lengths: List[int]) -> torch.Tensor:
    """Create cumulative sequence lengths tensor."""
    cu_seqlens = [0]
    for length in seq_lengths:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")


def warmup(func, *args, num_warmup=10):
    """Warmup the GPU by running the function multiple times"""
    for _ in range(num_warmup):
        func(*args)
    torch.mps.synchronize()


def benchmark_flash_sdpa(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    dtype: torch.dtype,
    causal: bool = False,
    num_iterations: int = 100,
) -> float:
    """Benchmark Flash SDPA with given parameters"""
    # Create sequence lengths (all equal for fair comparison)
    seq_lengths = [seq_len] * batch_size
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)

    # Create input tensors in Flash format (total_tokens, num_heads, head_dim)
    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    out = torch.empty_like(query)

    scale = 1.0 / (head_dim ** 0.5)

    # Define the function to benchmark
    def run_flash_sdpa():
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            mask=None,
            do_causal=causal,
            scale=scale,
            softcapping=1.0,
        )

    # Warmup
    warmup(run_flash_sdpa, num_warmup=10)

    # Benchmark
    torch.mps.synchronize()
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        run_flash_sdpa()
    torch.mps.synchronize()
    end_time = time.perf_counter()

    avg_time_ms = (end_time - start_time) * 1000 / num_iterations
    return avg_time_ms


def benchmark_flash_gqa(
    batch_size: int,
    num_heads_q: int,
    num_heads_kv: int,
    seq_len: int,
    head_dim: int,
    dtype: torch.dtype,
    causal: bool = False,
    num_iterations: int = 100,
) -> float:
    """Benchmark Flash Attention with Grouped Query Attention"""
    # Create sequence lengths
    seq_lengths = [seq_len] * batch_size
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)

    # Create input tensors with different head counts
    query = torch.randn(total_tokens, num_heads_q, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads_kv, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads_kv, head_dim, dtype=dtype, device="mps")
    out = torch.empty_like(query)

    scale = 1.0 / (head_dim ** 0.5)

    # Define the function to benchmark
    def run_flash_gqa():
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            mask=None,
            do_causal=causal,
            scale=scale,
            softcapping=1.0,
        )

    # Warmup
    warmup(run_flash_gqa, num_warmup=10)

    # Benchmark
    torch.mps.synchronize()
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        run_flash_gqa()
    torch.mps.synchronize()
    end_time = time.perf_counter()

    avg_time_ms = (end_time - start_time) * 1000 / num_iterations
    return avg_time_ms
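

# --- Optional correctness cross-check (sketch, not used by the benchmarks) ---
# A minimal sanity check, assuming equal-length sequences, matched Q/KV head
# counts, and PyTorch >= 2.1 (for the `scale` keyword), that compares the varlen
# Flash output against torch.nn.functional.scaled_dot_product_attention. The
# helper name and its default arguments are illustrative, not part of the
# sdpa_flash API.
def check_against_torch_sdpa(
    batch_size: int = 2,
    num_heads: int = 8,
    seq_len: int = 256,
    head_dim: int = 64,
    dtype: torch.dtype = torch.float16,
    causal: bool = False,
) -> float:
    """Return the max absolute difference between Flash SDPA and torch SDPA."""
    seq_lengths = [seq_len] * batch_size
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)

    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    out = torch.empty_like(query)
    scale = 1.0 / (head_dim ** 0.5)

    sdpa_flash.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        mask=None,
        do_causal=causal,
        scale=scale,
        softcapping=1.0,
    )

    def to_bhsd(t: torch.Tensor) -> torch.Tensor:
        # Packed (batch*seq_len, heads, dim) -> (batch, heads, seq_len, dim).
        return t.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

    ref = torch.nn.functional.scaled_dot_product_attention(
        to_bhsd(query), to_bhsd(key), to_bhsd(value), is_causal=causal, scale=scale
    )
    # Expect only small half-precision differences (accumulation order varies).
    return (to_bhsd(out) - ref).abs().max().item()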


def benchmark_variable_length(
    seq_lengths: List[int],
    num_heads: int,
    head_dim: int,
    dtype: torch.dtype,
    causal: bool = False,
    num_iterations: int = 100,
) -> float:
    """Benchmark Flash Attention with variable sequence lengths"""
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    # Create input tensors
    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
    out = torch.empty_like(query)

    scale = 1.0 / (head_dim ** 0.5)

    # Define the function to benchmark
    def run_varlen():
        sdpa_flash.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            mask=None,
            do_causal=causal,
            scale=scale,
            softcapping=1.0,
        )

    # Warmup
    warmup(run_varlen, num_warmup=10)

    # Benchmark
    torch.mps.synchronize()
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        run_varlen()
    torch.mps.synchronize()
    end_time = time.perf_counter()

    avg_time_ms = (end_time - start_time) * 1000 / num_iterations
    return avg_time_ms
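

# --- Example: benchmarking a sampled length distribution (sketch) -----------
# A small usage sketch, not called from main(), showing how
# benchmark_variable_length can be driven with randomly sampled lengths, which
# is closer to real serving batches than the fixed lists used below. The
# sampling range, batch size, and seed are illustrative assumptions.
def benchmark_sampled_lengths(
    num_sequences: int = 8,
    min_len: int = 128,
    max_len: int = 4096,
    num_heads: int = 32,
    head_dim: int = 64,
    seed: int = 0,
) -> float:
    """Benchmark Flash Attention on randomly sampled sequence lengths."""
    generator = torch.Generator().manual_seed(seed)
    seq_lengths = torch.randint(
        min_len, max_len + 1, (num_sequences,), generator=generator
    ).tolist()
    return benchmark_variable_length(
        seq_lengths, num_heads, head_dim, torch.float16, causal=False
    )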


def main():
    print("=" * 80)
    print("Metal Flash SDPA Benchmark")
    print("=" * 80)

    # Test configurations (matching the plain SDPA benchmark)
    configs = [
        # (batch_size, num_heads, seq_len, head_dim, dtype, causal, name)
        (1, 32, 512, 64, torch.float32, False, "Small seq, float32"),
        (1, 32, 512, 64, torch.float16, False, "Small seq, float16"),
        (1, 32, 512, 64, torch.bfloat16, False, "Small seq, bfloat16"),
        (4, 32, 2048, 64, torch.float16, False, "Medium seq, float16"),
        (4, 32, 2048, 64, torch.float16, True, "Medium seq, float16, causal"),
        (2, 32, 4096, 64, torch.float16, False, "Large seq, float16"),
        (2, 32, 4096, 64, torch.float16, True, "Large seq, float16, causal"),
        # Different head dimensions
        (2, 32, 2048, 32, torch.float16, False, "head_dim=32"),
        (2, 32, 2048, 64, torch.float16, False, "head_dim=64"),
        (2, 32, 2048, 128, torch.float16, False, "head_dim=128"),
        # Vector kernel cases (q_seq=1) - Flash doesn't have a special vector kernel
        # but we benchmark these cases for fair comparison with plain SDPA
        (16, 32, 1, 64, torch.float16, False, "Vector kernel (q_seq=1)"),
        (16, 32, 1, 128, torch.float16, False, "Vector kernel (q_seq=1, head_dim=128)"),
    ]

    print("\nFlash Attention Benchmarks:")
    print("-" * 80)
    print(f"{'Config':<40} {'Time (ms)':<15} {'TFLOPS':<15}")
    print("-" * 80)

    for batch_size, num_heads, seq_len, head_dim, dtype, causal, name in configs:
        time_ms = benchmark_flash_sdpa(
            batch_size, num_heads, seq_len, head_dim, dtype, causal
        )
        # Calculate FLOPS (approximate)
        # Attention: 2 * batch * heads * seq_len^2 * head_dim
        flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim
        tflops = (flops / 1e12) / (time_ms / 1000)
        print(f"{name:<40} {time_ms:<15.3f} {tflops:<15.2f}")

    # GQA benchmarks
    print("\n\nGrouped Query Attention (GQA) Benchmarks:")
    print("-" * 80)
    print(f"{'Config':<40} {'Time (ms)':<15} {'TFLOPS':<15}")
    print("-" * 80)

    gqa_configs = [
        # (batch_size, num_heads_q, num_heads_kv, seq_len, head_dim, dtype, causal, name)
        (2, 32, 8, 2048, 64, torch.float16, False, "GQA 4:1 ratio"),
        (2, 32, 4, 2048, 64, torch.float16, False, "GQA 8:1 ratio"),
        (2, 32, 1, 2048, 64, torch.float16, False, "MQA (32:1 ratio)"),
        (2, 32, 8, 2048, 128, torch.float16, False, "GQA 4:1, head_dim=128"),
    ]

    for batch_size, num_heads_q, num_heads_kv, seq_len, head_dim, dtype, causal, name in gqa_configs:
        time_ms = benchmark_flash_gqa(
            batch_size, num_heads_q, num_heads_kv, seq_len, head_dim, dtype, causal
        )
        # Calculate FLOPS for GQA
        flops = 2 * batch_size * num_heads_q * seq_len * seq_len * head_dim
        tflops = (flops / 1e12) / (time_ms / 1000)
        print(f"{name:<40} {time_ms:<15.3f} {tflops:<15.2f}")

    # Variable length sequences (unique to Flash Attention)
    print("\n\nVariable Length Sequence Benchmarks:")
    print("-" * 80)
    print(f"{'Config':<40} {'Time (ms)':<15} {'TFLOPS':<15}")
    print("-" * 80)

    varlen_configs = [
        # (seq_lengths, num_heads, head_dim, dtype, causal, name)
        ([512, 1024, 2048, 4096], 32, 64, torch.float16, False, "Variable [512-4096]"),
        ([128, 256, 512, 1024, 2048], 32, 64, torch.float16, False, "Variable [128-2048]"),
        ([2048, 2048, 2048, 2048], 32, 64, torch.float16, False, "Fixed 4x2048 (baseline)"),
    ]

    for seq_lengths, num_heads, head_dim, dtype, causal, name in varlen_configs:
        time_ms = benchmark_variable_length(
            seq_lengths, num_heads, head_dim, dtype, causal
        )
        # Calculate FLOPS for variable length
        total_flops = 0
        for seq_len in seq_lengths:
            total_flops += 2 * num_heads * seq_len * seq_len * head_dim
        tflops = (total_flops / 1e12) / (time_ms / 1000)
        print(f"{name:<40} {time_ms:<15.3f} {tflops:<15.2f}")

    print("\n" + "=" * 80)
    print("Benchmark completed!")


if __name__ == "__main__":
    main()
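
# Note on the "Vector kernel (q_seq=1)" rows: with seq_len=1 both the query and
# key/value lengths are a single token per sequence, so the FLOP count above is
# tiny and the reported TFLOPS is expected to be near zero. Those rows mostly
# measure per-call dispatch overhead rather than attention throughput, which is
# the comparison the configs call out against the plain SDPA vector kernel.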