[WIP] Add sliding-window attention support to the varlen kernel
Status: WIP, authored without a Metal build/test environment. Do not merge before building and running the parity tests (see checklist).
Why
Sliding-window (SWA) models (Mistral/Ministral, Gemma 2/3, Qwen sliding layers, the gpt-oss family) pass `window_size` into `flash_attn_varlen_func`. Today the wrapper raises `NotImplementedError("Window attention is not supported")`, so on MPS Transformers falls back to eager attention, which materializes the full `[B, H, n, n]` score matrix.
Measured on an M-series GPU (fp16, openai/privacy-filter, eager): driver memory grows 3.1 → 4.2 → 18.5 → 35.9 GiB across 512 → 2048 → 8192 → 16384 tokens, then crashes (integer out of range: the score tensor exceeds int32 indexing at 16k). Accepting the window lets these models use the flash path, whose memory is flat in sequence length.
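As a rough sanity check on the int32 claim, the element count of one dense score tensor can be computed directly. The sketch below assumes batch 1 and 8 attention heads; the head count is an illustrative assumption and is not read from the model config. It counts a single score tensor only, so it will not reproduce the measured driver memory, which includes everything else the process allocates.

```python
INT32_MAX = 2**31 - 1


def score_elems(batch: int, heads: int, n: int) -> int:
    """Number of elements in a dense [B, H, n, n] attention score tensor."""
    return batch * heads * n * n


def score_gib(batch: int, heads: int, n: int, bytes_per_elem: int = 2) -> float:
    """Size of that tensor in GiB (2 bytes per element for fp16)."""
    return score_elems(batch, heads, n) * bytes_per_elem / 2**30


# heads=8 is an assumption chosen for illustration.
for n in (512, 2048, 8192, 16384):
    elems = score_elems(1, 8, n)
    print(f"n={n:6d} elems={elems:11d} over_int32={elems > INT32_MAX} "
          f"fp16={score_gib(1, 8, n):.3f} GiB")
```

Under these assumptions the 16384-token case is the first one whose element count no longer fits in a signed 32-bit index.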
What
- Plumb `window_left`/`window_right` through `flash_attn_varlen_func` → `flash_attention_varlen` → the `flash_attention_varlen` op → host launcher → `AttnParams` → the shader (appended two `int32` to both the host and device `AttnParams`, ABI matched).
- New function constant `has_window` (302), gating a mask block that is a direct analog of the existing `do_causal` block: a key at signed distance `d` from the query row is kept iff `d <= window_left` (past) and `-d <= window_right` (future); `-1` = unbounded on that side. Follows flash-attn `window_size=(left, right)` semantics.
- `flash_attn_varlen_func` now forwards `window_size` instead of raising.
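The keep/drop rule for the window mask can be written out as a small reference, which is also the shape a banded-mask test oracle would take. This is a minimal sketch, not the shader code: `in_window` and `band_mask` are hypothetical helper names, and it assumes queries and keys share the same positions (equal sequence lengths), with `d = q_idx - k_idx`.

```python
def in_window(q_idx: int, k_idx: int, window_left: int, window_right: int) -> bool:
    """Return True if key k_idx is visible to query q_idx.

    d > 0 means the key lies in the past, d < 0 in the future.
    A value of -1 on either side means unbounded on that side.
    """
    d = q_idx - k_idx
    left_ok = window_left < 0 or d <= window_left
    right_ok = window_right < 0 or -d <= window_right
    return left_ok and right_ok


def band_mask(n: int, window_left: int, window_right: int) -> list[list[bool]]:
    """Dense n x n keep-mask; row = query position, column = key position."""
    return [[in_window(q, k, window_left, window_right) for k in range(n)]
            for q in range(n)]


# Causal window over the current token plus 2 past tokens: window_size=(2, 0).
for row in band_mask(5, 2, 0):
    print("".join("x" if keep else "." for keep in row))
# x....
# xx...
# xxx..
# .xxx.
# ..xxx
```

With `(-1, -1)` the mask keeps everything (full attention), and with `(-1, 0)` it reduces to the ordinary causal mask.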
Scope / limitations
- Correctness + the flat-memory win only. The K-block loop still visits every block (the band is applied as a mask, not by skipping out-of-band blocks). Block-skipping for a compute speedup is a natural follow-up.
- Attention sinks (gpt-oss `s_aux`) are NOT in this PR. They are a denominator-only term in the online softmax: before the final `Otile.row_bin_op<DivOp>(sum_score)`, add `exp2(sink_h * log2(e) - max_score[i])` to `sum_score[i]` (rescaling the running max to include the sink for numerical safety). Happy to send that as a second PR.
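To make the "denominator-only" description concrete, here is a scalar sketch of a softmax row with a sink logit, worked in base 2 as the kernel description suggests. `softmax_with_sink` is a hypothetical name, and the code is a reference for the arithmetic under that reading of the text, not a transcription of the shader.

```python
import math

LOG2E = math.log2(math.e)


def softmax_with_sink(scores: list[float], sink: float) -> list[float]:
    """Attention weights for one row; the sink only enlarges the denominator."""
    # Move to base 2: exp(x) == 2 ** (x * log2(e)).
    s2 = [s * LOG2E for s in scores]
    sink2 = sink * LOG2E
    # Fold the sink into the max so no exponent below is positive.
    m = max(max(s2), sink2)
    denom = sum(2.0 ** (s - m) for s in s2) + 2.0 ** (sink2 - m)
    # The sink has no value vector, so it contributes no numerator term.
    return [2.0 ** (s - m) / denom for s in s2]


weights = softmax_with_sink([1.0, 2.0, 3.0], sink=0.5)
print(weights, sum(weights))  # the weights sum to less than 1
```

Passing `sink=float("-inf")` makes the extra term vanish and recovers the ordinary softmax.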
Validate before merge
- Build with kernel-builder for the `torchNN-metal-aarch64-darwin` targets.
- `tests/test_flash_attention.py` + a new windowed case vs a reference banded-mask SDPA.
- Numeric parity for a bidirectional window (e.g. Gemma/Mistral) and a causal window.
Built with kernel-builder and validated on Apple Silicon (M-series, MPS).
Now includes attention sinks (s_aux) alongside sliding-window, plus two fixes the parity test surfaced:
- band-limited K-loop (`kb_start..kb_lim`) + fast-forward the K/V block loaders to the first in-band block; this is both correctness and the perf win;
- guard the online-softmax against a fully-masked K-block (`-inf - (-inf)` → NaN).
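Both fixes can be illustrated with a scalar model of the K-block loop. The sketch below is an assumption-laden reference, not the Metal code: `band_block_range` and `online_softmax_row` are hypothetical names, the half-open `[kb_start, kb_lim)` convention is a guess at what `kb_start..kb_lim` means, and queries and keys are assumed to share positions.

```python
import math

NEG_INF = float("-inf")


def band_block_range(q_start, q_len, n_keys, block, window_left, window_right):
    """Half-open range [kb_start, kb_lim) of K-blocks that can hold in-band keys.

    -1 on either side of the window means unbounded on that side.
    """
    q_end = q_start + q_len - 1  # last query row in this query block
    first_key = 0 if window_left < 0 else max(0, q_start - window_left)
    last_key = n_keys - 1 if window_right < 0 else min(n_keys - 1, q_end + window_right)
    return first_key // block, last_key // block + 1


def online_softmax_row(score_blocks):
    """Streaming (max, sum) softmax statistics for one row; masked scores are -inf."""
    run_max, run_sum = NEG_INF, 0.0
    for blk in score_blocks:
        new_max = max(run_max, max(blk))
        if new_max == NEG_INF:
            # Everything seen so far is masked. Without this guard the rescale
            # below would evaluate exp(-inf - (-inf)) = exp(nan) = nan.
            continue
        scale = math.exp(run_max - new_max)  # 0.0 when this is the first live block
        run_sum = run_sum * scale + sum(math.exp(s - new_max) for s in blk)
        run_max = new_max
    return run_max, run_sum


print(band_block_range(256, 32, 1024, 32, 64, 64))           # (6, 11)
print(online_softmax_row([[NEG_INF, NEG_INF], [0.0, 0.0]]))  # (0.0, 2.0)
```

In this model the loop visits 5 of 32 K-blocks for the example query block, which is where the compute saving comes from; the guard matters because a block inside the range can still be entirely masked for an individual row.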
`s_aux` is a denominator-only term in the base-2 epilogue (re-folding the running max), matching gpt-oss. transformers auto-routes both `window_size` and `s_aux` by signature introspection, so no transformers-side change is needed.
Parity vs fp32 reference (cos / max-abs), fp16, MPS:
| case | cos | max_abs |
|---|---|---|
| full attention | 1.000000 | 0.0001 |
| sliding window (64,64) | 1.000000 | 0.0001 |
| attention sink only | 1.000000 | 0.0001 |
| window + sink | 1.000000 | 0.0001 |
| + GQA (H=8,Hkv=2) | 1.000000 | 0.0001 |
| H=14,Hkv=2,win=127,sink | 1.000000 | 0.0001 |
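For readers reproducing the table, the two columns are presumably the cosine similarity and the maximum absolute difference between the flattened kernel output and the fp32 reference. The exact script is not shown in this thread, so the following is a minimal sketch of those two metrics on plain Python lists, with hypothetical function names.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two flattened outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))


reference = [0.25, -0.5, 1.0, 0.125]
candidate = [0.25, -0.5, 1.0, 0.25]  # one element perturbed
print(cosine_similarity(reference, candidate), max_abs_diff(reference, candidate))
```

Cosine similarity is insensitive to a uniform scale error, which is why it is reported together with the max-abs column.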
Perf: single sequence, fp16, H=8 D=64, window=128 + sink (eager = full O(n^2) materialization):
| tokens | eager ms | eager MiB | this PR ms | this PR MiB |
|---|---|---|---|---|
| 8192 | 134.7 | 9277 | 2.6 | 49 |
| 16384 | MPSGraph error | 35869 | 5.9 | 1041 |
| 131072 | n/a | n/a | 11.8 | 1041 |
Memory stays flat and it runs to 131k tokens where eager dies at ~16k.
Notes: the pinned kernel-builder build matrix tops out at torch 2.10; validated by loading the torch2.10 build under a torch 2.11 runtime. `build.toml` bumped to the current format (`backends = ["metal"]`) so it builds with the current kernel-builder. Built single-threaded: a parallel build hit a flaky configurePhase segfault unrelated to these sources.
End-to-end validated on openai/privacy-filter (MPS, M-series). Three fixes to make the kernel usable from the standard (non-continuous-batching) forward path:
- `flash_attn_func` non-varlen entry (wraps the varlen kernel via per-row `cu_seqlens`) so `_flash_fn` isn't `None` for the base forward path;
- force input contiguity: transformers hands the kernel a `[B,S,H,D]` view whose `B=1` reshape keeps transposed strides; the kernel reads raw storage, so a non-contiguous input silently scrambles the output;
- cast `sinks` back to fp32: the flash wrapper casts `s_aux` to the query dtype (fp16) but the kernel reads them as fp32.
With all three, eager-vs-kernel end-to-end logits: cos 0.9966, argmax-agree 100% (fp16, N=512).
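Two of these fixes are easy to model without a GPU. The sketch below uses plain lists and hypothetical helper names: `dense_cu_seqlens` shows the boundaries a padded `[B, S, ...]` batch would need when handed to a varlen kernel (assuming every row has the same length `S`), and the two readers show why ignoring strides scrambles a transposed view.

```python
def dense_cu_seqlens(batch: int, seq_len: int) -> list[int]:
    """Cumulative boundaries [0, S, 2S, ..., B*S] for B rows of equal length S."""
    return [i * seq_len for i in range(batch + 1)]


def read_strided(storage, shape, strides):
    """Read a 2-D view element by element, honouring its strides."""
    rows, cols = shape
    return [[storage[r * strides[0] + c * strides[1]] for c in range(cols)]
            for r in range(rows)]


def read_raw(storage, shape):
    """Read the same shape assuming row-major contiguity (strides ignored)."""
    rows, cols = shape
    return [[storage[r * cols + c] for c in range(cols)] for r in range(rows)]


print(dense_cu_seqlens(3, 4))  # [0, 4, 8, 12]

storage = [0, 1, 2, 3, 4, 5]       # a contiguous [2, 3] tensor
shape, strides = (3, 2), (1, 3)    # its transpose, expressed as a view
print(read_strided(storage, shape, strides))  # [[0, 3], [1, 4], [2, 5]]
print(read_raw(storage, shape))               # [[0, 1], [2, 3], [4, 5]]
```

The second reader is what a kernel that indexes raw storage effectively does. In PyTorch, calling `.contiguous()` on the inputs before the launch makes the storage order match the logical shape.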
Perf (single forward, fp16; eager = full O(n^2)):
| tokens | eager tok/s | eager MiB | this PR tok/s | this PR MiB |
|---|---|---|---|---|
| 8192 | 1413 | 18516 | 2246 | 3158 |
| 16384 | OOM | 35904 | 3734 | 4241 |
| 131072 | n/a | n/a | 10546 | 11198 |
Runs to 131k tokens flat where eager OOMs at 16k. Pairs with the transformers PR adding metal-flash-sdpa to OpenAIPrivacyFilter._compatible_flash_implementations.
torch.compile compatibility. The host code opened its own compute encoder while the MPS stream may already have one open (inductor coalesces kernels under torch.compile), tripping failed assertion: A command encoder is already encoding. Fixed by calling MPSStream::endKernelCoalescing() before [cmdBuf computeCommandEncoder].
With this, openai/privacy-filter runs under torch.compile(model.forward) on MPS (cos 0.997 / argmax 100% vs non-compiled). The custom op still forces graph breaks around attention, but inductor optimizes the surrounding MoE/dense work:
| tokens | metal | metal + compile |
|---|---|---|
| 8192 | 1642 tok/s | 2041 (+24%) |
| 32768 | 5043 tok/s | 5381 (+7%) |
(Reusing the stream's existing encoder instead of opening a new one would also let the attention itself fuse into the compiled graph, a possible follow-up.)