Add build files
Browse files- .gitattributes +2 -0
- README.md +4 -4
- benchmark_flash_sdpa.py +4 -4
- build.toml +1 -1
- {torch-ext/sdpa_flash → build/torch27-metal-aarch64-darwin/metal_flash_sdpa}/__init__.py +0 -0
- build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__pycache__/__init__.cpython-312.pyc +0 -0
- build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__pycache__/_custom_ops.cpython-312.pyc +0 -0
- build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__pycache__/_ops.cpython-312.pyc +0 -0
- {torch-ext/sdpa_flash → build/torch27-metal-aarch64-darwin/metal_flash_sdpa}/_custom_ops.py +0 -0
- build/torch27-metal-aarch64-darwin/metal_flash_sdpa/_metal_flash_sdpa_032c946.abi3.so +3 -0
- build/torch27-metal-aarch64-darwin/metal_flash_sdpa/_metal_flash_sdpa_032c946.metallib +3 -0
- build/torch27-metal-aarch64-darwin/metal_flash_sdpa/_ops.py +9 -0
- flake.lock +168 -0
- flake.nix +1 -1
- tests/test_flash_attention.py +24 -24
- torch-ext/metal_flash_sdpa/__init__.py +11 -0
- torch-ext/metal_flash_sdpa/_custom_ops.py +117 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.so filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.metallib filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -4,9 +4,9 @@ tags:
|
|
| 4 |
- kernel
|
| 5 |
---
|
| 6 |
|
| 7 |
-
# Metal Flash
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
## Supported Features
|
| 12 |
|
|
@@ -22,7 +22,7 @@ A PyTorch extension that provides optimized Metal implementations of Flash Atten
|
|
| 22 |
### flash_attention_varlen
|
| 23 |
|
| 24 |
```python
|
| 25 |
-
|
| 26 |
out: torch.Tensor,
|
| 27 |
query: torch.Tensor,
|
| 28 |
key: torch.Tensor,
|
|
@@ -50,7 +50,7 @@ sdpa_flash.flash_attention_varlen(
|
|
| 50 |
Compatibility wrapper matching the original Flash Attention API:
|
| 51 |
|
| 52 |
```python
|
| 53 |
-
out =
|
| 54 |
q: torch.Tensor,
|
| 55 |
k: torch.Tensor,
|
| 56 |
v: torch.Tensor,
|
|
|
|
| 4 |
- kernel
|
| 5 |
---
|
| 6 |
|
| 7 |
+
# Metal Flash SDPA
|
| 8 |
|
| 9 |
+
Optimized SDPA kernels inspired by Flash Attention for Metal.
|
| 10 |
|
| 11 |
## Supported Features
|
| 12 |
|
|
|
|
| 22 |
### flash_attention_varlen
|
| 23 |
|
| 24 |
```python
|
| 25 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 26 |
out: torch.Tensor,
|
| 27 |
query: torch.Tensor,
|
| 28 |
key: torch.Tensor,
|
|
|
|
| 50 |
Compatibility wrapper matching the original Flash Attention API:
|
| 51 |
|
| 52 |
```python
|
| 53 |
+
out = metal_flash_sdpa.flash_attn_varlen_func(
|
| 54 |
q: torch.Tensor,
|
| 55 |
k: torch.Tensor,
|
| 56 |
v: torch.Tensor,
|
benchmark_flash_sdpa.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import time
|
| 6 |
-
import
|
| 7 |
from typing import List, Tuple
|
| 8 |
import numpy as np
|
| 9 |
|
|
@@ -49,7 +49,7 @@ def benchmark_flash_sdpa(
|
|
| 49 |
|
| 50 |
# Define the function to benchmark
|
| 51 |
def run_flash_sdpa():
|
| 52 |
-
|
| 53 |
out=out,
|
| 54 |
query=query,
|
| 55 |
key=key,
|
|
@@ -108,7 +108,7 @@ def benchmark_flash_gqa(
|
|
| 108 |
|
| 109 |
# Define the function to benchmark
|
| 110 |
def run_flash_gqa():
|
| 111 |
-
|
| 112 |
out=out,
|
| 113 |
query=query,
|
| 114 |
key=key,
|
|
@@ -164,7 +164,7 @@ def benchmark_variable_length(
|
|
| 164 |
|
| 165 |
# Define the function to benchmark
|
| 166 |
def run_varlen():
|
| 167 |
-
|
| 168 |
out=out,
|
| 169 |
query=query,
|
| 170 |
key=key,
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import time
|
| 6 |
+
import metal_flash_sdpa
|
| 7 |
from typing import List, Tuple
|
| 8 |
import numpy as np
|
| 9 |
|
|
|
|
| 49 |
|
| 50 |
# Define the function to benchmark
|
| 51 |
def run_flash_sdpa():
|
| 52 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 53 |
out=out,
|
| 54 |
query=query,
|
| 55 |
key=key,
|
|
|
|
| 108 |
|
| 109 |
# Define the function to benchmark
|
| 110 |
def run_flash_gqa():
|
| 111 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 112 |
out=out,
|
| 113 |
query=query,
|
| 114 |
key=key,
|
|
|
|
| 164 |
|
| 165 |
# Define the function to benchmark
|
| 166 |
def run_varlen():
|
| 167 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 168 |
out=out,
|
| 169 |
query=query,
|
| 170 |
key=key,
|
build.toml
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
[general]
|
| 2 |
-
name = "
|
| 3 |
universal = false
|
| 4 |
|
| 5 |
[torch]
|
|
|
|
| 1 |
[general]
|
| 2 |
+
name = "metal_flash_sdpa"
|
| 3 |
universal = false
|
| 4 |
|
| 5 |
[torch]
|
{torch-ext/sdpa_flash → build/torch27-metal-aarch64-darwin/metal_flash_sdpa}/__init__.py
RENAMED
|
File without changes
|
build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (400 Bytes). View file
|
|
|
build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__pycache__/_custom_ops.cpython-312.pyc
ADDED
|
Binary file (3.99 kB). View file
|
|
|
build/torch27-metal-aarch64-darwin/metal_flash_sdpa/__pycache__/_ops.cpython-312.pyc
ADDED
|
Binary file (595 Bytes). View file
|
|
|
{torch-ext/sdpa_flash → build/torch27-metal-aarch64-darwin/metal_flash_sdpa}/_custom_ops.py
RENAMED
|
File without changes
|
build/torch27-metal-aarch64-darwin/metal_flash_sdpa/_metal_flash_sdpa_032c946.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b16d8835b2d0e86339095de249268b1f28ce41ba4dedca70326226e6267f8354
|
| 3 |
+
size 104672
|
build/torch27-metal-aarch64-darwin/metal_flash_sdpa/_metal_flash_sdpa_032c946.metallib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f8ae6e3b9eb9fb2d3a4d86983308d23009ee03097fa0668d973e561fb235110e
|
| 3 |
+
size 622095
|
build/torch27-metal-aarch64-darwin/metal_flash_sdpa/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _metal_flash_sdpa_032c946
|
| 3 |
+
ops = torch.ops._metal_flash_sdpa_032c946
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
|
| 6 |
+
"""
|
| 7 |
+
Prefix op by namespace.
|
| 8 |
+
"""
|
| 9 |
+
return f"_metal_flash_sdpa_032c946::{op_name}"
|
flake.lock
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nodes": {
|
| 3 |
+
"flake-compat": {
|
| 4 |
+
"locked": {
|
| 5 |
+
"lastModified": 1747046372,
|
| 6 |
+
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
| 7 |
+
"owner": "edolstra",
|
| 8 |
+
"repo": "flake-compat",
|
| 9 |
+
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
| 10 |
+
"type": "github"
|
| 11 |
+
},
|
| 12 |
+
"original": {
|
| 13 |
+
"owner": "edolstra",
|
| 14 |
+
"repo": "flake-compat",
|
| 15 |
+
"type": "github"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"flake-compat_2": {
|
| 19 |
+
"locked": {
|
| 20 |
+
"lastModified": 1733328505,
|
| 21 |
+
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
| 22 |
+
"owner": "edolstra",
|
| 23 |
+
"repo": "flake-compat",
|
| 24 |
+
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
| 25 |
+
"type": "github"
|
| 26 |
+
},
|
| 27 |
+
"original": {
|
| 28 |
+
"owner": "edolstra",
|
| 29 |
+
"repo": "flake-compat",
|
| 30 |
+
"type": "github"
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"flake-utils": {
|
| 34 |
+
"inputs": {
|
| 35 |
+
"systems": "systems"
|
| 36 |
+
},
|
| 37 |
+
"locked": {
|
| 38 |
+
"lastModified": 1731533236,
|
| 39 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 40 |
+
"owner": "numtide",
|
| 41 |
+
"repo": "flake-utils",
|
| 42 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 43 |
+
"type": "github"
|
| 44 |
+
},
|
| 45 |
+
"original": {
|
| 46 |
+
"owner": "numtide",
|
| 47 |
+
"repo": "flake-utils",
|
| 48 |
+
"type": "github"
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"flake-utils_2": {
|
| 52 |
+
"inputs": {
|
| 53 |
+
"systems": "systems_2"
|
| 54 |
+
},
|
| 55 |
+
"locked": {
|
| 56 |
+
"lastModified": 1731533236,
|
| 57 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 58 |
+
"owner": "numtide",
|
| 59 |
+
"repo": "flake-utils",
|
| 60 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 61 |
+
"type": "github"
|
| 62 |
+
},
|
| 63 |
+
"original": {
|
| 64 |
+
"owner": "numtide",
|
| 65 |
+
"repo": "flake-utils",
|
| 66 |
+
"type": "github"
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"hf-nix": {
|
| 70 |
+
"inputs": {
|
| 71 |
+
"flake-compat": "flake-compat_2",
|
| 72 |
+
"flake-utils": "flake-utils_2",
|
| 73 |
+
"nixpkgs": "nixpkgs"
|
| 74 |
+
},
|
| 75 |
+
"locked": {
|
| 76 |
+
"lastModified": 1751968576,
|
| 77 |
+
"narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
|
| 78 |
+
"owner": "huggingface",
|
| 79 |
+
"repo": "hf-nix",
|
| 80 |
+
"rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
|
| 81 |
+
"type": "github"
|
| 82 |
+
},
|
| 83 |
+
"original": {
|
| 84 |
+
"owner": "huggingface",
|
| 85 |
+
"repo": "hf-nix",
|
| 86 |
+
"type": "github"
|
| 87 |
+
}
|
| 88 |
+
},
|
| 89 |
+
"kernel-builder": {
|
| 90 |
+
"inputs": {
|
| 91 |
+
"flake-compat": "flake-compat",
|
| 92 |
+
"flake-utils": "flake-utils",
|
| 93 |
+
"hf-nix": "hf-nix",
|
| 94 |
+
"nixpkgs": [
|
| 95 |
+
"kernel-builder",
|
| 96 |
+
"hf-nix",
|
| 97 |
+
"nixpkgs"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
"locked": {
|
| 101 |
+
"lastModified": 1752505139,
|
| 102 |
+
"narHash": "sha256-gdIuOhU/adUjNGNgIk1cDTfN7J2tH0UuHSU3FanFfxE=",
|
| 103 |
+
"owner": "huggingface",
|
| 104 |
+
"repo": "kernel-builder",
|
| 105 |
+
"rev": "a5cebbc02f01a9d359d18ceb9e8bdadead2a289a",
|
| 106 |
+
"type": "github"
|
| 107 |
+
},
|
| 108 |
+
"original": {
|
| 109 |
+
"owner": "huggingface",
|
| 110 |
+
"repo": "kernel-builder",
|
| 111 |
+
"type": "github"
|
| 112 |
+
}
|
| 113 |
+
},
|
| 114 |
+
"nixpkgs": {
|
| 115 |
+
"locked": {
|
| 116 |
+
"lastModified": 1747820358,
|
| 117 |
+
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
| 118 |
+
"owner": "danieldk",
|
| 119 |
+
"repo": "nixpkgs",
|
| 120 |
+
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
| 121 |
+
"type": "github"
|
| 122 |
+
},
|
| 123 |
+
"original": {
|
| 124 |
+
"owner": "danieldk",
|
| 125 |
+
"ref": "cudatoolkit-12.9-kernel-builder",
|
| 126 |
+
"repo": "nixpkgs",
|
| 127 |
+
"type": "github"
|
| 128 |
+
}
|
| 129 |
+
},
|
| 130 |
+
"root": {
|
| 131 |
+
"inputs": {
|
| 132 |
+
"kernel-builder": "kernel-builder"
|
| 133 |
+
}
|
| 134 |
+
},
|
| 135 |
+
"systems": {
|
| 136 |
+
"locked": {
|
| 137 |
+
"lastModified": 1681028828,
|
| 138 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 139 |
+
"owner": "nix-systems",
|
| 140 |
+
"repo": "default",
|
| 141 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 142 |
+
"type": "github"
|
| 143 |
+
},
|
| 144 |
+
"original": {
|
| 145 |
+
"owner": "nix-systems",
|
| 146 |
+
"repo": "default",
|
| 147 |
+
"type": "github"
|
| 148 |
+
}
|
| 149 |
+
},
|
| 150 |
+
"systems_2": {
|
| 151 |
+
"locked": {
|
| 152 |
+
"lastModified": 1681028828,
|
| 153 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 154 |
+
"owner": "nix-systems",
|
| 155 |
+
"repo": "default",
|
| 156 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 157 |
+
"type": "github"
|
| 158 |
+
},
|
| 159 |
+
"original": {
|
| 160 |
+
"owner": "nix-systems",
|
| 161 |
+
"repo": "default",
|
| 162 |
+
"type": "github"
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
},
|
| 166 |
+
"root": "root",
|
| 167 |
+
"version": 7
|
| 168 |
+
}
|
flake.nix
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
description = "Flake for SDPA kernel";
|
| 3 |
|
| 4 |
inputs = {
|
| 5 |
-
kernel-builder.url = "
|
| 6 |
};
|
| 7 |
|
| 8 |
outputs =
|
|
|
|
| 2 |
description = "Flake for SDPA kernel";
|
| 3 |
|
| 4 |
inputs = {
|
| 5 |
+
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
};
|
| 7 |
|
| 8 |
outputs =
|
tests/test_flash_attention.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
import pytest
|
| 3 |
-
import
|
| 4 |
|
| 5 |
|
| 6 |
def create_cu_seqlens(seq_lengths):
|
|
@@ -34,7 +34,7 @@ def test_flash_attention_single_sequence(dtype, head_dim):
|
|
| 34 |
|
| 35 |
# Call Flash Attention
|
| 36 |
out = torch.empty_like(query)
|
| 37 |
-
|
| 38 |
out=out,
|
| 39 |
query=query,
|
| 40 |
key=key,
|
|
@@ -102,7 +102,7 @@ def test_flash_attention_variable_lengths(dtype, head_dim):
|
|
| 102 |
|
| 103 |
# Call Flash Attention
|
| 104 |
out = torch.empty_like(query)
|
| 105 |
-
|
| 106 |
out=out,
|
| 107 |
query=query,
|
| 108 |
key=key,
|
|
@@ -173,7 +173,7 @@ def test_flash_attention_causal(dtype, head_dim):
|
|
| 173 |
|
| 174 |
# Call Flash Attention with causal mask
|
| 175 |
out = torch.empty_like(query)
|
| 176 |
-
|
| 177 |
out=out,
|
| 178 |
query=query,
|
| 179 |
key=key,
|
|
@@ -247,7 +247,7 @@ def test_flash_attention_gqa(dtype, head_dim):
|
|
| 247 |
|
| 248 |
# Call Flash Attention
|
| 249 |
out = torch.empty_like(query)
|
| 250 |
-
|
| 251 |
out=out,
|
| 252 |
query=query,
|
| 253 |
key=key,
|
|
@@ -309,7 +309,7 @@ def test_flash_attention_head_dimensions(head_dim):
|
|
| 309 |
|
| 310 |
# Call Flash Attention
|
| 311 |
out = torch.empty_like(query)
|
| 312 |
-
|
| 313 |
out=out,
|
| 314 |
query=query,
|
| 315 |
key=key,
|
|
@@ -353,7 +353,7 @@ def test_flash_attention_large_head_dim(dtype):
|
|
| 353 |
|
| 354 |
# Call Flash Attention
|
| 355 |
out = torch.empty_like(query)
|
| 356 |
-
|
| 357 |
out=out,
|
| 358 |
query=query,
|
| 359 |
key=key,
|
|
@@ -420,7 +420,7 @@ def test_flash_attention_large_head_dim_causal(dtype):
|
|
| 420 |
|
| 421 |
# Call Flash Attention with causal mask
|
| 422 |
out = torch.empty_like(query)
|
| 423 |
-
|
| 424 |
out=out,
|
| 425 |
query=query,
|
| 426 |
key=key,
|
|
@@ -485,7 +485,7 @@ def test_flash_attention_large_head_dim_gqa():
|
|
| 485 |
|
| 486 |
# Call Flash Attention
|
| 487 |
out = torch.empty_like(query)
|
| 488 |
-
|
| 489 |
out=out,
|
| 490 |
query=query,
|
| 491 |
key=key,
|
|
@@ -528,7 +528,7 @@ def test_flash_attention_edge_cases():
|
|
| 528 |
cu_seqlens = create_cu_seqlens([1])
|
| 529 |
out = torch.empty_like(query)
|
| 530 |
|
| 531 |
-
|
| 532 |
out=out,
|
| 533 |
query=query,
|
| 534 |
key=key,
|
|
@@ -556,7 +556,7 @@ def test_flash_attention_edge_cases():
|
|
| 556 |
out = torch.empty_like(query)
|
| 557 |
|
| 558 |
# This should handle empty sequences gracefully
|
| 559 |
-
|
| 560 |
out=out,
|
| 561 |
query=query,
|
| 562 |
key=key,
|
|
@@ -582,7 +582,7 @@ def test_flash_attention_unsupported_cases():
|
|
| 582 |
out = torch.empty_like(query)
|
| 583 |
|
| 584 |
with pytest.raises(RuntimeError, match="Head dimension .* is not supported"):
|
| 585 |
-
|
| 586 |
out=out,
|
| 587 |
query=query,
|
| 588 |
key=key,
|
|
@@ -606,7 +606,7 @@ def test_flash_attention_unsupported_cases():
|
|
| 606 |
|
| 607 |
# The function signature no longer accepts mask parameter
|
| 608 |
with pytest.raises(TypeError):
|
| 609 |
-
|
| 610 |
out=out,
|
| 611 |
query=query,
|
| 612 |
key=key,
|
|
@@ -627,7 +627,7 @@ def test_flash_attention_unsupported_cases():
|
|
| 627 |
# This will silently fail (output will be unchanged)
|
| 628 |
# We can detect this by initializing output to a known value
|
| 629 |
out = torch.full_like(query, -999.0)
|
| 630 |
-
|
| 631 |
out=out,
|
| 632 |
query=query,
|
| 633 |
key=key,
|
|
@@ -668,7 +668,7 @@ def test_flash_attention_small_sequences(dtype, head_dim):
|
|
| 668 |
|
| 669 |
# Call Flash Attention
|
| 670 |
out = torch.empty_like(query)
|
| 671 |
-
|
| 672 |
out=out,
|
| 673 |
query=query,
|
| 674 |
key=key,
|
|
@@ -734,7 +734,7 @@ def test_flash_attention_cross_attention(dtype, head_dim):
|
|
| 734 |
|
| 735 |
# Call Flash Attention
|
| 736 |
out = torch.empty_like(query)
|
| 737 |
-
|
| 738 |
out=out,
|
| 739 |
query=query,
|
| 740 |
key=key,
|
|
@@ -794,7 +794,7 @@ def test_flash_attention_large_sequences(dtype):
|
|
| 794 |
|
| 795 |
# Call Flash Attention
|
| 796 |
out = torch.empty_like(query)
|
| 797 |
-
|
| 798 |
out=out,
|
| 799 |
query=query,
|
| 800 |
key=key,
|
|
@@ -854,7 +854,7 @@ def test_flash_attention_gqa_ratios(gqa_ratio, head_dim):
|
|
| 854 |
|
| 855 |
# Call Flash Attention
|
| 856 |
out = torch.empty_like(query)
|
| 857 |
-
|
| 858 |
out=out,
|
| 859 |
query=query,
|
| 860 |
key=key,
|
|
@@ -911,7 +911,7 @@ def test_flash_attention_single_query_token():
|
|
| 911 |
|
| 912 |
# Call Flash Attention
|
| 913 |
out = torch.empty_like(query)
|
| 914 |
-
|
| 915 |
out=out,
|
| 916 |
query=query,
|
| 917 |
key=key,
|
|
@@ -959,7 +959,7 @@ def test_flash_attn_varlen_func():
|
|
| 959 |
v = torch.randn(total_tokens, num_heads, head_dim, device="mps")
|
| 960 |
|
| 961 |
# Call the compatibility function
|
| 962 |
-
out =
|
| 963 |
q=q,
|
| 964 |
k=k,
|
| 965 |
v=v,
|
|
@@ -977,7 +977,7 @@ def test_flash_attn_varlen_func():
|
|
| 977 |
assert out.abs().max().item() > 0
|
| 978 |
|
| 979 |
# Test with causal
|
| 980 |
-
out_causal =
|
| 981 |
q=q,
|
| 982 |
k=k,
|
| 983 |
v=v,
|
|
@@ -1020,7 +1020,7 @@ def test_flash_attention_softcapping(dtype, head_dim):
|
|
| 1020 |
|
| 1021 |
# Call Flash Attention with softcapping
|
| 1022 |
out = torch.empty_like(query)
|
| 1023 |
-
|
| 1024 |
out=out,
|
| 1025 |
query=query,
|
| 1026 |
key=key,
|
|
@@ -1083,7 +1083,7 @@ def test_flash_attention_softcapping_edge_cases(dtype):
|
|
| 1083 |
|
| 1084 |
# With softcapping = 1.0 (no effect)
|
| 1085 |
out_no_cap = torch.empty_like(query)
|
| 1086 |
-
|
| 1087 |
out=out_no_cap,
|
| 1088 |
query=query,
|
| 1089 |
key=key,
|
|
@@ -1114,7 +1114,7 @@ def test_flash_attention_softcapping_edge_cases(dtype):
|
|
| 1114 |
|
| 1115 |
# Test with very large softcapping value
|
| 1116 |
out_large_cap = torch.empty_like(query)
|
| 1117 |
-
|
| 1118 |
out=out_large_cap,
|
| 1119 |
query=query,
|
| 1120 |
key=key,
|
|
|
|
| 1 |
import torch
|
| 2 |
import pytest
|
| 3 |
+
import metal_flash_sdpa
|
| 4 |
|
| 5 |
|
| 6 |
def create_cu_seqlens(seq_lengths):
|
|
|
|
| 34 |
|
| 35 |
# Call Flash Attention
|
| 36 |
out = torch.empty_like(query)
|
| 37 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 38 |
out=out,
|
| 39 |
query=query,
|
| 40 |
key=key,
|
|
|
|
| 102 |
|
| 103 |
# Call Flash Attention
|
| 104 |
out = torch.empty_like(query)
|
| 105 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 106 |
out=out,
|
| 107 |
query=query,
|
| 108 |
key=key,
|
|
|
|
| 173 |
|
| 174 |
# Call Flash Attention with causal mask
|
| 175 |
out = torch.empty_like(query)
|
| 176 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 177 |
out=out,
|
| 178 |
query=query,
|
| 179 |
key=key,
|
|
|
|
| 247 |
|
| 248 |
# Call Flash Attention
|
| 249 |
out = torch.empty_like(query)
|
| 250 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 251 |
out=out,
|
| 252 |
query=query,
|
| 253 |
key=key,
|
|
|
|
| 309 |
|
| 310 |
# Call Flash Attention
|
| 311 |
out = torch.empty_like(query)
|
| 312 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 313 |
out=out,
|
| 314 |
query=query,
|
| 315 |
key=key,
|
|
|
|
| 353 |
|
| 354 |
# Call Flash Attention
|
| 355 |
out = torch.empty_like(query)
|
| 356 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 357 |
out=out,
|
| 358 |
query=query,
|
| 359 |
key=key,
|
|
|
|
| 420 |
|
| 421 |
# Call Flash Attention with causal mask
|
| 422 |
out = torch.empty_like(query)
|
| 423 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 424 |
out=out,
|
| 425 |
query=query,
|
| 426 |
key=key,
|
|
|
|
| 485 |
|
| 486 |
# Call Flash Attention
|
| 487 |
out = torch.empty_like(query)
|
| 488 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 489 |
out=out,
|
| 490 |
query=query,
|
| 491 |
key=key,
|
|
|
|
| 528 |
cu_seqlens = create_cu_seqlens([1])
|
| 529 |
out = torch.empty_like(query)
|
| 530 |
|
| 531 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 532 |
out=out,
|
| 533 |
query=query,
|
| 534 |
key=key,
|
|
|
|
| 556 |
out = torch.empty_like(query)
|
| 557 |
|
| 558 |
# This should handle empty sequences gracefully
|
| 559 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 560 |
out=out,
|
| 561 |
query=query,
|
| 562 |
key=key,
|
|
|
|
| 582 |
out = torch.empty_like(query)
|
| 583 |
|
| 584 |
with pytest.raises(RuntimeError, match="Head dimension .* is not supported"):
|
| 585 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 586 |
out=out,
|
| 587 |
query=query,
|
| 588 |
key=key,
|
|
|
|
| 606 |
|
| 607 |
# The function signature no longer accepts mask parameter
|
| 608 |
with pytest.raises(TypeError):
|
| 609 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 610 |
out=out,
|
| 611 |
query=query,
|
| 612 |
key=key,
|
|
|
|
| 627 |
# This will silently fail (output will be unchanged)
|
| 628 |
# We can detect this by initializing output to a known value
|
| 629 |
out = torch.full_like(query, -999.0)
|
| 630 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 631 |
out=out,
|
| 632 |
query=query,
|
| 633 |
key=key,
|
|
|
|
| 668 |
|
| 669 |
# Call Flash Attention
|
| 670 |
out = torch.empty_like(query)
|
| 671 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 672 |
out=out,
|
| 673 |
query=query,
|
| 674 |
key=key,
|
|
|
|
| 734 |
|
| 735 |
# Call Flash Attention
|
| 736 |
out = torch.empty_like(query)
|
| 737 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 738 |
out=out,
|
| 739 |
query=query,
|
| 740 |
key=key,
|
|
|
|
| 794 |
|
| 795 |
# Call Flash Attention
|
| 796 |
out = torch.empty_like(query)
|
| 797 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 798 |
out=out,
|
| 799 |
query=query,
|
| 800 |
key=key,
|
|
|
|
| 854 |
|
| 855 |
# Call Flash Attention
|
| 856 |
out = torch.empty_like(query)
|
| 857 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 858 |
out=out,
|
| 859 |
query=query,
|
| 860 |
key=key,
|
|
|
|
| 911 |
|
| 912 |
# Call Flash Attention
|
| 913 |
out = torch.empty_like(query)
|
| 914 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 915 |
out=out,
|
| 916 |
query=query,
|
| 917 |
key=key,
|
|
|
|
| 959 |
v = torch.randn(total_tokens, num_heads, head_dim, device="mps")
|
| 960 |
|
| 961 |
# Call the compatibility function
|
| 962 |
+
out = metal_flash_sdpa.flash_attn_varlen_func(
|
| 963 |
q=q,
|
| 964 |
k=k,
|
| 965 |
v=v,
|
|
|
|
| 977 |
assert out.abs().max().item() > 0
|
| 978 |
|
| 979 |
# Test with causal
|
| 980 |
+
out_causal = metal_flash_sdpa.flash_attn_varlen_func(
|
| 981 |
q=q,
|
| 982 |
k=k,
|
| 983 |
v=v,
|
|
|
|
| 1020 |
|
| 1021 |
# Call Flash Attention with softcapping
|
| 1022 |
out = torch.empty_like(query)
|
| 1023 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 1024 |
out=out,
|
| 1025 |
query=query,
|
| 1026 |
key=key,
|
|
|
|
| 1083 |
|
| 1084 |
# With softcapping = 1.0 (no effect)
|
| 1085 |
out_no_cap = torch.empty_like(query)
|
| 1086 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 1087 |
out=out_no_cap,
|
| 1088 |
query=query,
|
| 1089 |
key=key,
|
|
|
|
| 1114 |
|
| 1115 |
# Test with very large softcapping value
|
| 1116 |
out_large_cap = torch.empty_like(query)
|
| 1117 |
+
metal_flash_sdpa.flash_attention_varlen(
|
| 1118 |
out=out_large_cap,
|
| 1119 |
query=query,
|
| 1120 |
key=key,
|
torch-ext/metal_flash_sdpa/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._custom_ops import (
|
| 2 |
+
flash_attention_varlen,
|
| 3 |
+
flash_attn_varlen_func,
|
| 4 |
+
)
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"flash_attention_varlen",
|
| 9 |
+
"flash_attn_varlen_func",
|
| 10 |
+
"ops",
|
| 11 |
+
]
|
torch-ext/metal_flash_sdpa/_custom_ops.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def flash_attention_varlen(
|
| 9 |
+
out: torch.Tensor,
|
| 10 |
+
query: torch.Tensor,
|
| 11 |
+
key: torch.Tensor,
|
| 12 |
+
value: torch.Tensor,
|
| 13 |
+
cu_seqlens_q: torch.Tensor,
|
| 14 |
+
cu_seqlens_k: torch.Tensor,
|
| 15 |
+
max_seqlen_q: int,
|
| 16 |
+
max_seqlen_k: int,
|
| 17 |
+
do_causal: bool = False,
|
| 18 |
+
scale: Optional[float] = None,
|
| 19 |
+
softcapping: float = 1.0,
|
| 20 |
+
) -> None:
|
| 21 |
+
"""
|
| 22 |
+
Flash Attention with variable-length sequences.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
out: Output tensor of shape [total_q_tokens, num_heads, head_dim]
|
| 26 |
+
query: Query tensor of shape [total_q_tokens, num_heads, head_dim]
|
| 27 |
+
key: Key tensor of shape [total_k_tokens, num_heads_kv, head_dim]
|
| 28 |
+
value: Value tensor of shape [total_k_tokens, num_heads_kv, head_dim]
|
| 29 |
+
cu_seqlens_q: Cumulative sequence lengths for queries, shape [batch_size + 1], dtype must be torch.int32
|
| 30 |
+
cu_seqlens_k: Cumulative sequence lengths for keys, shape [batch_size + 1], dtype must be torch.int32
|
| 31 |
+
max_seqlen_q: Maximum sequence length in the query batch
|
| 32 |
+
max_seqlen_k: Maximum sequence length in the key batch
|
| 33 |
+
do_causal: Whether to apply causal masking
|
| 34 |
+
scale: Attention scale factor (default: 1/sqrt(head_dim))
|
| 35 |
+
softcapping: Softcapping value (default: 1.0, must be 1.0 for this implementation)
|
| 36 |
+
|
| 37 |
+
Note:
|
| 38 |
+
- cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
|
| 39 |
+
- Supported head dimensions: 32, 64, 72, 80, 96, 128
|
| 40 |
+
- Masks are not supported
|
| 41 |
+
"""
|
| 42 |
+
if scale is None:
|
| 43 |
+
scale = query.shape[-1] ** -0.5
|
| 44 |
+
|
| 45 |
+
ops.flash_attention_varlen(
|
| 46 |
+
out,
|
| 47 |
+
query,
|
| 48 |
+
key,
|
| 49 |
+
value,
|
| 50 |
+
cu_seqlens_q,
|
| 51 |
+
cu_seqlens_k,
|
| 52 |
+
max_seqlen_q,
|
| 53 |
+
max_seqlen_k,
|
| 54 |
+
do_causal,
|
| 55 |
+
scale,
|
| 56 |
+
softcapping,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def flash_attn_varlen_func(
|
| 60 |
+
q: torch.Tensor,
|
| 61 |
+
k: torch.Tensor,
|
| 62 |
+
v: torch.Tensor,
|
| 63 |
+
cu_seqlens_q: torch.Tensor,
|
| 64 |
+
cu_seqlens_k: torch.Tensor,
|
| 65 |
+
max_seqlen_q: int,
|
| 66 |
+
max_seqlen_k: int,
|
| 67 |
+
dropout_p: float = 0.0,
|
| 68 |
+
softmax_scale: Optional[float] = None,
|
| 69 |
+
causal: bool = False,
|
| 70 |
+
window_size: tuple = (-1, -1),
|
| 71 |
+
alibi_slopes: Optional[torch.Tensor] = None,
|
| 72 |
+
deterministic: bool = False,
|
| 73 |
+
return_attn_probs: bool = False,
|
| 74 |
+
) -> torch.Tensor:
|
| 75 |
+
"""
|
| 76 |
+
Flash Attention function with API compatible with the original Flash Attention.
|
| 77 |
+
|
| 78 |
+
Note: This implementation does not support:
|
| 79 |
+
- dropout
|
| 80 |
+
- window attention
|
| 81 |
+
- alibi slopes
|
| 82 |
+
- returning attention probabilities
|
| 83 |
+
"""
|
| 84 |
+
if dropout_p > 0:
|
| 85 |
+
raise NotImplementedError("Dropout is not supported in this implementation")
|
| 86 |
+
if window_size != (-1, -1):
|
| 87 |
+
raise NotImplementedError("Window attention is not supported")
|
| 88 |
+
if alibi_slopes is not None:
|
| 89 |
+
raise NotImplementedError("ALiBi is not supported")
|
| 90 |
+
if return_attn_probs:
|
| 91 |
+
raise NotImplementedError("Returning attention probabilities is not supported")
|
| 92 |
+
|
| 93 |
+
# Create output tensor
|
| 94 |
+
out = torch.empty_like(q)
|
| 95 |
+
|
| 96 |
+
# Call the kernel
|
| 97 |
+
flash_attention_varlen(
|
| 98 |
+
out=out,
|
| 99 |
+
query=q,
|
| 100 |
+
key=k,
|
| 101 |
+
value=v,
|
| 102 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 103 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 104 |
+
max_seqlen_q=max_seqlen_q,
|
| 105 |
+
max_seqlen_k=max_seqlen_k,
|
| 106 |
+
do_causal=causal,
|
| 107 |
+
scale=softmax_scale,
|
| 108 |
+
softcapping=1.0,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
return out
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
__all__ = [
|
| 115 |
+
"flash_attention_varlen",
|
| 116 |
+
"flash_attn_varlen_func",
|
| 117 |
+
]
|