---
license: apache-2.0
tags:
- kernels
---

# Metal Flash SDPA

Optimized SDPA (scaled dot-product attention) kernels for Metal, inspired by Flash Attention.

Some components of these kernels are adapted from [mlx](https://github.com/ml-explore/mlx).

## Supported Features

- Variable-length sequences without padding, described by cumulative sequence lengths (see the sketch after this list)
- Causal masking
- Grouped Query Attention (GQA) and Multi-Query Attention (MQA)
- Softcapping for attention score regularization
- Data types: `float32`, `float16`, `bfloat16`
- Head dimensions: `32`, `64`, `72`, `80`, `96`, `128`, `256`
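
Variable-length batches are packed by concatenating all sequences along the token dimension; their boundaries are described by a cumulative-lengths tensor. A minimal sketch of building one in plain PyTorch (illustrative; the lengths here are made up):

```python
import torch

# Three sequences of lengths 3, 5, and 2, packed with no padding.
# cu_seqlens[i] is the start offset of sequence i; cu_seqlens[-1] is the
# total token count.
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
print(cu_seqlens)                # tensor([ 0,  3,  8, 10], dtype=torch.int32)
max_seqlen = int(seqlens.max())  # 5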
```

## API Reference

### flash_attention_varlen

```python
metal_flash_sdpa.flash_attention_varlen(
    out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    do_causal: bool,
    scale: float,
    softcapping: float,
) -> None
```

- **out**: Output tensor `[total_q_tokens, num_heads, head_dim]`, modified in place.
- **query/key/value**: Input tensors `[total_tokens, num_heads(_kv), head_dim]`.
- **cu_seqlens_q/cu_seqlens_k**: Cumulative sequence lengths (`torch.int32`), shape `[batch_size + 1]`.
- **max_seqlen_q/max_seqlen_k**: Maximum query/key sequence lengths in the batch.
- **do_causal**: Enable causal masking.
- **scale**: Attention score scaling factor (e.g., `1/sqrt(head_dim)`).
- **softcapping**: Softcapping value for score regularization (use `1.0` for no softcapping).
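
A minimal end-to-end sketch of a causal self-attention call, assuming the package imports as `metal_flash_sdpa` and tensors live on the `mps` device (shapes and sequence layout reuse the illustrative example above):

```python
import math

import torch
import metal_flash_sdpa

num_heads, head_dim = 8, 64
total_tokens = 10  # sequences of lengths 3, 5, and 2 packed together

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="mps")
k = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="mps")
v = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="mps")
out = torch.empty_like(q)

cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32, device="mps")

metal_flash_sdpa.flash_attention_varlen(
    out, q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5,
    max_seqlen_k=5,
    do_causal=True,
    scale=1.0 / math.sqrt(head_dim),
    softcapping=1.0,  # 1.0 disables softcapping
)
# `out` now holds the result, shape [total_tokens, num_heads, head_dim].
```

For GQA/MQA, allocate `k`/`v` with the smaller `num_heads_kv` head count, per the shapes listed above.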

### flash_attn_varlen_func

Compatibility wrapper matching the original Flash Attention API:

```python
out = metal_flash_sdpa.flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
)
```
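
Reusing the packed tensors from the sketch above, a drop-in call mirrors the upstream `flash_attn_varlen_func`. Leaving `softmax_scale` unset is assumed here to fall back to `1/sqrt(head_dim)`, as in the original Flash Attention API (not verified against this package):

```python
out = metal_flash_sdpa.flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5,
    max_seqlen_k=5,
    causal=True,
    # softmax_scale=None -> assumed to default to 1/sqrt(head_dim),
    # matching upstream Flash Attention behavior.
)
```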