---
license: apache-2.0
tags:
- kernels
---
# Metal Flash SDPA
Optimized scaled dot-product attention (SDPA) kernels for Metal, inspired by Flash Attention.
Some components of these kernels come from [mlx](https://github.com/ml-explore/mlx).
## Supported Features
- Variable-length sequences without padding (see the sketch after this list)
- Causal masking
- Grouped Query Attention (GQA) and Multi-Query Attention (MQA)
- Softcapping support for attention score regularization
- Data types: `float32`, `float16`, `bfloat16`
- Head dimensions: `32`, `64`, `72`, `80`, `96`, `128`, `256`
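Variable-length support operates on packed tensors: all sequences in a batch are concatenated along the token dimension, and cumulative sequence lengths mark the boundaries. A minimal sketch of building these offsets (the sequence lengths here are illustrative):

```python
import torch

# Three sequences of lengths 3, 5, and 2, packed along the token dimension.
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Cumulative offsets [0, 3, 8, 10]: sequence i occupies tokens
# cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

max_seqlen = int(seqlens.max())  # 5
```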
## API Reference
### flash_attention_varlen
```python
metal_flash_sdpa.flash_attention_varlen(
    out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    do_causal: bool,
    scale: float,
    softcapping: float
) -> None
```
- **out**: Output tensor `[total_q_tokens, num_heads, head_dim]`, modified in-place.
- **query**: Query tensor `[total_q_tokens, num_heads, head_dim]`.
- **key/value**: Key/value tensors `[total_k_tokens, num_kv_heads, head_dim]`; `num_kv_heads` may be smaller than `num_heads` for GQA/MQA.
- **cu_seqlens_q/cu_seqlens_k**: Cumulative sequence lengths (`torch.int32`), `[batch_size + 1]`.
- **max_seqlen_q/max_seqlen_k**: Maximum query/key sequence lengths in the batch.
- **do_causal**: Enable causal masking.
- **scale**: Attention score scaling factor (e.g., `1/sqrt(head_dim)`).
- **softcapping**: Softcapping value for score regularization (use `1.0` for no softcapping).
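A minimal end-to-end sketch of a call, following the parameter descriptions above. The `mps` device, dtypes, and shapes are illustrative assumptions; the module name is taken from this README:

```python
import math
import torch
import metal_flash_sdpa

num_heads, num_kv_heads, head_dim = 8, 2, 64

# Two packed sequences of lengths 4 and 6 (10 tokens total) for both query and key/value.
cu_seqlens_q = torch.tensor([0, 4, 10], dtype=torch.int32, device="mps")
cu_seqlens_k = torch.tensor([0, 4, 10], dtype=torch.int32, device="mps")

query = torch.randn(10, num_heads, head_dim, dtype=torch.float16, device="mps")
key = torch.randn(10, num_kv_heads, head_dim, dtype=torch.float16, device="mps")
value = torch.randn(10, num_kv_heads, head_dim, dtype=torch.float16, device="mps")
out = torch.empty_like(query)

metal_flash_sdpa.flash_attention_varlen(
    out, query, key, value,
    cu_seqlens_q, cu_seqlens_k,
    6,                           # max_seqlen_q
    6,                           # max_seqlen_k
    True,                        # do_causal
    1.0 / math.sqrt(head_dim),   # scale
    1.0,                         # softcapping (1.0 disables it)
)
```

`out` holds the attention output after the call returns.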
### flash_attn_varlen_func
Compatibility wrapper matching the original Flash Attention API:
```python
out = metal_flash_sdpa.flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False
)
```
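A hedged sketch of calling the wrapper with the same packed layout as above; keyword names follow the documented signature, and the remaining options (dropout, sliding window, ALiBi) are left at their defaults:

```python
import torch
import metal_flash_sdpa

head_dim = 64
q = torch.randn(10, 8, head_dim, dtype=torch.float16, device="mps")
k = torch.randn(10, 2, head_dim, dtype=torch.float16, device="mps")
v = torch.randn(10, 2, head_dim, dtype=torch.float16, device="mps")
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32, device="mps")

# softmax_scale is left at its default (None); causal masking enabled.
out = metal_flash_sdpa.flash_attn_varlen_func(
    q, k, v,
    cu_seqlens, cu_seqlens,
    max_seqlen_q=6,
    max_seqlen_k=6,
    causal=True,
)
```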