Kernels

[WIP] Add sliding-window attention support to the varlen kernel

#5
by ArthurZ HF Staff - opened

Status: WIP. Authored without a Metal build/test environment. Do not merge before building and running the parity tests (see checklist).

Why

Sliding-window (SWA) models (Mistral/Ministral, Gemma 2/3, Qwen sliding layers, the gpt-oss family) pass window_size into flash_attn_varlen_func. Today the wrapper raises NotImplementedError("Window attention is not supported"), so on MPS Transformers falls back to eager attention, which materializes the full [B, H, n, n] score matrix.

Measured on an M-series GPU (fp16, openai/privacy-filter, eager): driver memory grows 3.1 → 4.2 → 18.5 → 35.9 GiB across 512 → 2048 → 8192 → 16384 tokens, then crashes (integer out of range: the score tensor exceeds int32 indexing at 16k). Accepting the window lets these models use the flash path, whose memory is flat in sequence length.
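The int32 overflow can be checked with a quick back-of-the-envelope count of the [B, H, n, n] score tensor. This is only a sketch: the head count of openai/privacy-filter is not stated here, so heads=8 is an assumption borrowed from the synthetic benchmark shape (H=8) used later in this thread, and eager_score_elements is a hypothetical helper name.

```python
INT32_MAX = 2**31 - 1  # largest index representable in a signed 32-bit integer


def eager_score_elements(batch: int, heads: int, seq_len: int) -> int:
    """Element count of the [B, H, n, n] score tensor that eager attention materializes."""
    return batch * heads * seq_len * seq_len


if __name__ == "__main__":
    for n in (512, 2048, 8192, 16384):
        # heads=8 is an assumption, not a value taken from the model config.
        elems = eager_score_elements(batch=1, heads=8, seq_len=n)
        verdict = "exceeds int32" if elems > INT32_MAX else "fits in int32"
        print(f"n={n}: {elems} elements, {verdict}")
```

Under that assumption the element count crosses the int32 limit between 8192 and 16384 tokens, which is consistent with where the crash was observed.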

What

  • Plumb window_left / window_right through flash_attn_varlen_func → flash_attention_varlen → the flash_attention_varlen op → host launcher → AttnParams → the shader (two int32 fields appended to both the host and device AttnParams, keeping the ABI matched).
  • New function constant has_window (302), gating a mask block that is a direct analog of the existing do_causal block: a key at signed distance d from the query row is kept iff d <= window_left (past) and -d <= window_right (future); -1 = unbounded on that side. Follows flash-attn window_size=(left, right) semantics.
  • flash_attn_varlen_func now forwards window_size instead of raising.
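The keep rule described above can be written down as a small reference predicate. This is a minimal sketch, not the shader code: in_window and band_mask are hypothetical names, d is taken as q_idx - k_idx, and query and key positions are assumed to share one index space (equal query and key lengths).

```python
def in_window(q_idx: int, k_idx: int, window_left: int, window_right: int) -> bool:
    """Return True if the key at k_idx is visible to the query at q_idx.

    d > 0 means the key lies in the past, d < 0 means it lies in the future.
    A window value of -1 means "unbounded on that side".
    """
    d = q_idx - k_idx
    keep_past = window_left < 0 or d <= window_left
    keep_future = window_right < 0 or -d <= window_right
    return keep_past and keep_future


def band_mask(n_q: int, n_k: int, window_left: int, window_right: int) -> list[list[bool]]:
    """Dense boolean mask, usable as a reference for a banded-mask SDPA comparison."""
    return [
        [in_window(i, j, window_left, window_right) for j in range(n_k)]
        for i in range(n_q)
    ]


if __name__ == "__main__":
    # Causal window of 2 past tokens: window_size = (2, 0).
    for row in band_mask(5, 5, 2, 0):
        print("".join("x" if keep else "." for keep in row))
```

With window_size=(-1, -1) the predicate keeps every key, and with window_right=0 it reduces to a causal band, which is how the causal and bidirectional parity cases can share one reference.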

Scope / limitations

  • Correctness + the flat-memory win only. The K-block loop still visits every block (the band is applied as a mask, not by skipping out-of-band blocks). Block-skipping for a compute speedup is a natural follow-up.
  • Attention sinks (gpt-oss s_aux) are NOT in this PR. They are a denominator-only term in the online softmax: before the final Otile.row_bin_op<DivOp>(sum_score), add exp2(sink_h * log2(e) - max_score[i]) to sum_score[i] (rescaling the running max to include the sink for numerical safety). Happy to send that as a second PR.
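To make the "denominator-only" description concrete, here is a scalar sketch of one softmax row with a sink, written in the base-2 form mentioned above. It assumes the scores are already scaled logits and that the sink is a single per-head logit; softmax_with_sink is a hypothetical name and the code is a reference for the math, not the kernel epilogue.

```python
import math

LOG2E = math.log2(math.e)  # converts natural-log logits to the base-2 domain


def softmax_with_sink(scores: list[float], sink: float) -> list[float]:
    """Softmax over `scores` whose denominator also contains exp(sink).

    The sink contributes no value vector, so it only enlarges the denominator;
    the returned probabilities therefore sum to less than 1.
    """
    s2 = [s * LOG2E for s in scores]
    sink2 = sink * LOG2E
    # Fold the sink into the running max so no exponent is positive.
    m = max(max(s2), sink2)
    numerators = [2.0 ** (s - m) for s in s2]
    denom = sum(numerators) + 2.0 ** (sink2 - m)
    return [x / denom for x in numerators]


if __name__ == "__main__":
    probs = softmax_with_sink([1.0, 2.0, 3.0], sink=0.5)
    print(probs, "mass absorbed by sink:", 1.0 - sum(probs))
```

Including the sink in the max is what keeps the extra term bounded by 1 even when the sink logit is much larger than every score in the row.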

Validate before merge

  • Build with kernel-builder for the torchNN-metal-aarch64-darwin targets.
  • tests/test_flash_attention.py + a new windowed case vs a reference banded-mask SDPA.
  • Numeric parity for a bidirectional window (e.g. Gemma/Mistral) and a causal window.

Built with kernel-builder and validated on Apple Silicon (M-series, MPS).

Now includes attention sinks (s_aux) alongside sliding-window, plus two fixes the parity test surfaced:

  1. band-limited K-loop (kb_start..kb_lim) + fast-forward the K/V block loaders to the first in-band block (this is both the correctness fix and the perf win);
  2. guard the online-softmax against a fully-masked K-block (-inf - (-inf) → NaN).
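Both fixes can be illustrated in scalar Python. This is a sketch under assumptions: the names k_block_range and online_update are hypothetical, the block range is taken as half-open [kb_start, kb_lim), and the guard shown (skip the update while everything seen so far is masked) is one possible way to avoid the NaN, not necessarily the one used in the shader.

```python
import math


def k_block_range(q_start, q_end, n_k, block_k, window_left, window_right):
    """K-blocks [kb_start, kb_lim) that intersect the band for query rows [q_start, q_end)."""
    lo = 0 if window_left < 0 else max(0, q_start - window_left)
    hi = n_k if window_right < 0 else min(n_k, (q_end - 1) + window_right + 1)
    kb_start = lo // block_k                 # first block holding an in-band key
    kb_lim = (hi + block_k - 1) // block_k   # one past the last in-band block
    return kb_start, kb_lim


def online_update(m_old, l_old, block_scores):
    """One online-softmax step in base 2: returns the new running (max, sum)."""
    m_new = max(m_old, max(block_scores, default=-math.inf))
    if m_new == -math.inf:
        # Every score so far is masked. Computing m_old - m_new here would be
        # -inf - (-inf) = NaN, so leave the running state untouched.
        return m_old, l_old
    scale = 2.0 ** (m_old - m_new)  # equals 0.0 when m_old is -inf
    l_new = l_old * scale + sum(2.0 ** (s - m_new) for s in block_scores)
    return m_new, l_new


if __name__ == "__main__":
    print(k_block_range(256, 288, 1024, 32, 64, 64))
    state = online_update(-math.inf, 0.0, [-math.inf, -math.inf])
    print(state, online_update(*state, [0.0, -math.inf]))
```

Restricting the loop to the in-band range is what turns the window from a mask into a compute saving, and it also reduces how often a fully-masked block is visited in the first place.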

s_aux is a denominator-only term in the base-2 epilogue (re-folding the running max), matching gpt-oss. transformers auto-routes both window_size and s_aux by signature introspection, so no transformers-side change is needed.

Parity vs fp32 reference (cos / max-abs), fp16, MPS:

| case | cos | max_abs |
| --- | --- | --- |
| full attention | 1.000000 | 0.0001 |
| sliding window (64,64) | 1.000000 | 0.0001 |
| attention sink only | 1.000000 | 0.0001 |
| window + sink | 1.000000 | 0.0001 |
| + GQA (H=8,Hkv=2) | 1.000000 | 0.0001 |
| H=14,Hkv=2,win=127,sink | 1.000000 | 0.0001 |
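For readers reproducing the parity numbers, the two metrics can be computed as below. This is a plain-Python sketch over flattened outputs; cosine_similarity and max_abs_diff are hypothetical helper names, and the actual test presumably operates on tensors rather than lists.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two flattened outputs (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest element-wise absolute difference between two flattened outputs."""
    return max(abs(x - y) for x, y in zip(a, b))


if __name__ == "__main__":
    kernel_out = [0.10, -0.25, 0.40]      # made-up values for illustration only
    reference = [0.10, -0.25, 0.4001]
    print(cosine_similarity(kernel_out, reference), max_abs_diff(kernel_out, reference))
```

Cosine similarity is insensitive to a uniform scale error, which is why it is paired with the max-abs difference.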

Perf: single sequence, fp16, H=8 D=64, window=128 + sink (eager = full O(n^2) materialization):

| tokens | eager ms | eager MiB | this PR ms | this PR MiB |
| --- | --- | --- | --- | --- |
| 8192 | 134.7 | 9277 | 2.6 | 49 |
| 16384 | MPSGraph error | 35869 | 5.9 | 1041 |
| 131072 | n/a | n/a | 11.8 | 1041 |

Memory stays flat and it runs to 131k tokens where eager dies at ~16k.

Notes: the pinned kernel-builder build matrix tops out at torch 2.10; validated by loading the torch2.10 build under a torch 2.11 runtime. build.toml bumped to the current format (backends = ["metal"]) so it builds with the current kernel-builder. Built single-threaded: a parallel build hit a flaky configurePhase segfault unrelated to these sources.

End-to-end validated on openai/privacy-filter (MPS, M-series). Three fixes to make the kernel usable from the standard (non-continuous-batching) forward path:

  1. flash_attn_func non-varlen entry (wraps the varlen kernel via per-row cu_seqlens) so _flash_fn isn't None for the base forward path;
  2. force input contiguity β€” transformers hands the kernel a [B,S,H,D] view whose B=1 reshape keeps transposed strides; the kernel reads raw storage, so a non-contiguous input silently scrambles the output;
  3. cast sinks back to fp32 β€” the flash wrapper casts s_aux to the query dtype (fp16) but the kernel reads them as fp32.
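The first two fixes can be pictured without any tensor library. The sketch below assumes every row of a [B, S, H, D] batch has the same length S; per_row_cu_seqlens and contiguous_strides are hypothetical helpers, and the stride tuples are element strides in the usual row-major convention.

```python
def per_row_cu_seqlens(batch: int, seq_len: int) -> list[int]:
    """Cumulative sequence offsets for `batch` rows of equal length.

    After flattening [B, S, H, D] to [B*S, H, D], row i occupies
    tokens cu_seqlens[i] .. cu_seqlens[i+1] of the packed tensor.
    """
    return [i * seq_len for i in range(batch + 1)]


def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Row-major strides a kernel implicitly assumes when it reads raw storage."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


if __name__ == "__main__":
    print(per_row_cu_seqlens(batch=3, seq_len=4))

    # A [B, S, H, D] view produced by transposing a contiguous [B, H, S, D] tensor
    # keeps the old strides with dims 1 and 2 swapped.
    shape = (1, 512, 8, 64)
    b, h, s, d = contiguous_strides((1, 8, 512, 64))
    transposed_view = (b, s, h, d)
    print(contiguous_strides(shape), transposed_view)
```

When the two stride tuples differ, a kernel that indexes storage as if it were row-major reads the wrong elements, which is the silent scrambling described in item 2.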

With all three, eager-vs-kernel end-to-end logits: cos 0.9966, argmax-agree 100% (fp16, N=512).

Perf (single forward, fp16; eager = full O(n^2)):

| tokens | eager tok/s | eager MiB | this PR tok/s | this PR MiB |
| --- | --- | --- | --- | --- |
| 8192 | 1413 | 18516 | 2246 | 3158 |
| 16384 | OOM | 35904 | 3734 | 4241 |
| 131072 | n/a | n/a | 10546 | 11198 |

Runs to 131k tokens flat where eager OOMs at 16k. Pairs with the transformers PR adding metal-flash-sdpa to OpenAIPrivacyFilter._compatible_flash_implementations.

torch.compile compatibility. The host code opened its own compute encoder while the MPS stream may already have one open (inductor coalesces kernels under torch.compile), which tripped the Metal assertion "failed assertion: A command encoder is already encoding". Fixed by calling MPSStream::endKernelCoalescing() before [cmdBuf computeCommandEncoder].

With this, openai/privacy-filter runs under torch.compile(model.forward) on MPS (cos 0.997 / argmax 100% vs non-compiled). The custom op still forces graph breaks around attention, but inductor optimizes the surrounding MoE/dense work:

| tokens | metal | metal + compile |
| --- | --- | --- |
| 8192 | 1642 tok/s | 2041 tok/s (+24%) |
| 32768 | 5043 tok/s | 5381 tok/s (+7%) |

(Reusing the stream's existing encoder instead of opening a new one would also let the attention itself fuse into the compiled graph, which is a possible follow-up.)

