Instructions to use kernels-community/metal-flash-sdpa with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/metal-flash-sdpa with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/metal-flash-sdpa")
- Notebooks
- Google Colab
- Kaggle
[WIP] Add sliding-window attention support to the varlen kernel
Browse files
**Status: WIP — authored without a Metal build/test environment. Do not merge before building and running the parity tests (see the "Validate before merge" checklist below).**
### Why
Sliding-window (SWA) models — Mistral/Ministral, Gemma 2/3, Qwen sliding layers, the gpt-oss family — pass `window_size` into `flash_attn_varlen_func`. Today the wrapper raises `NotImplementedError("Window attention is not supported")`, so on MPS Transformers falls back to **eager** attention, which materializes the full `[B, H, n, n]` score matrix.
Measured on an M-series GPU (fp16, `openai/privacy-filter`, eager): driver memory grows 3.1 → 4.2 → 18.5 → 35.9 GiB across 512 → 2048 → 8192 → 16384 tokens, then crashes (`integer out of range` — the score tensor exceeds int32 indexing at 16k). Accepting the window lets these models use the flash path, whose memory is flat in sequence length.
### What
- Plumb `window_left` / `window_right` through `flash_attn_varlen_func` → `flash_attention_varlen` → the `flash_attention_varlen` op → host launcher → `AttnParams` → the shader. Two `int32` fields are appended to the end of both the host and device `AttnParams` structs so that their layouts stay ABI-matched.
- New function constant `has_window` (index 302) gates a mask block that is a direct analog of the existing `do_causal` block. With `d = row_pos - col` as the signed distance from the query row to a key, the key is kept iff `d <= window_left` (past side) and `-d <= window_right` (future side); `-1` means unbounded on that side. This follows flash-attn's `window_size=(left, right)` semantics.
- `flash_attn_varlen_func` now forwards `window_size` instead of raising.
### Scope / limitations
- Correctness + the flat-memory win only. The K-block loop still visits every block (the band is applied as a mask, not by skipping out-of-band blocks). Block-skipping for a compute speedup is a natural follow-up.
- **Attention sinks (gpt-oss `s_aux`) are NOT in this PR.** They are a denominator-only term in the online softmax: before the final `Otile.row_bin_op<DivOp>(sum_score)`, add `exp2(sink_h * log2(e) - max_score[i])` to `sum_score[i]` (rescaling the running max to include the sink for numerical safety). Happy to send that as a second PR.
### Validate before merge
- [ ] Build with kernel-builder for the `torchNN-metal-aarch64-darwin` targets.
- [ ] `tests/test_flash_attention.py` + a new windowed case vs a reference banded-mask SDPA.
- [ ] Numeric parity for a bidirectional window (e.g. Gemma/Mistral) and a causal window.
|
@@ -1506,6 +1506,10 @@ struct AttnParams {
|
|
| 1506 |
int total_k_tokens; ///< Total number of key/value tokens
|
| 1507 |
int max_seqlen_q; ///< Maximum query sequence length
|
| 1508 |
int max_seqlen_k; ///< Maximum key/value sequence length
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1509 |
};
|
| 1510 |
|
| 1511 |
struct AttnMaskParams {
|
|
@@ -1521,6 +1525,7 @@ constant bool align_K [[function_constant(201)]];
|
|
| 1521 |
|
| 1522 |
constant bool has_mask [[function_constant(300)]];
|
| 1523 |
constant bool do_causal [[function_constant(301)]];
|
|
|
|
| 1524 |
|
| 1525 |
template <typename T>
|
| 1526 |
struct TransformScale {
|
|
@@ -1894,6 +1899,38 @@ template <
|
|
| 1894 |
}
|
| 1895 |
}
|
| 1896 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1897 |
// Other masking as needed
|
| 1898 |
if (has_mask) {
|
| 1899 |
using stile_t = decltype(Stile);
|
|
|
|
| 1506 |
int total_k_tokens; ///< Total number of key/value tokens
|
| 1507 |
int max_seqlen_q; ///< Maximum query sequence length
|
| 1508 |
int max_seqlen_k; ///< Maximum key/value sequence length
|
| 1509 |
+
|
| 1510 |
+
// Sliding-window attention support (-1 on a side = unbounded on that side)
|
| 1511 |
+
int window_left; ///< Max distance into the past a query may attend
|
| 1512 |
+
int window_right; ///< Max distance into the future a query may attend
|
| 1513 |
};
|
| 1514 |
|
| 1515 |
struct AttnMaskParams {
|
|
|
|
| 1525 |
|
| 1526 |
constant bool has_mask [[function_constant(300)]];
|
| 1527 |
constant bool do_causal [[function_constant(301)]];
|
| 1528 |
+
constant bool has_window [[function_constant(302)]];
|
| 1529 |
|
| 1530 |
template <typename T>
|
| 1531 |
struct TransformScale {
|
|
|
|
| 1899 |
}
|
| 1900 |
}
|
| 1901 |
|
| 1902 |
+
// Mask out keys outside the sliding window band [row - window_left, row + window_right]
|
| 1903 |
+
if (has_window) {
|
| 1904 |
+
using stile_t = decltype(Stile);
|
| 1905 |
+
using selem_t = typename stile_t::elem_type;
|
| 1906 |
+
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
| 1907 |
+
|
| 1908 |
+
const int wl = params->window_left; // -1 => unbounded into the past
|
| 1909 |
+
const int wr = params->window_right; // -1 => unbounded into the future
|
| 1910 |
+
|
| 1911 |
+
STEEL_PRAGMA_UNROLL
|
| 1912 |
+
for (short i = 0; i < stile_t::kTileRows; i++) {
|
| 1913 |
+
// Same row-position machinery as the causal block above.
|
| 1914 |
+
int row_pos = block_idx * BQ + tm + sm + (i * stile_t::kFragRows);
|
| 1915 |
+
if (q_seq_len < k_seq_len) {
|
| 1916 |
+
row_pos += (k_seq_len - q_seq_len);
|
| 1917 |
+
}
|
| 1918 |
+
STEEL_PRAGMA_UNROLL
|
| 1919 |
+
for (short j = 0; j < stile_t::kTileCols; j++) {
|
| 1920 |
+
const int col_pos_in_seq = kb * BK + sn + (j * stile_t::kFragCols);
|
| 1921 |
+
STEEL_PRAGMA_UNROLL
|
| 1922 |
+
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
| 1923 |
+
const int col = col_pos_in_seq + jj;
|
| 1924 |
+
const bool past_ok = (wl < 0) || ((row_pos - col) <= wl);
|
| 1925 |
+
const bool future_ok = (wr < 0) || ((col - row_pos) <= wr);
|
| 1926 |
+
if (!(past_ok && future_ok)) {
|
| 1927 |
+
Stile.frag_at(i, j)[jj] = neg_inf;
|
| 1928 |
+
}
|
| 1929 |
+
}
|
| 1930 |
+
}
|
| 1931 |
+
}
|
| 1932 |
+
}
|
| 1933 |
+
|
| 1934 |
// Other masking as needed
|
| 1935 |
if (has_mask) {
|
| 1936 |
using stile_t = decltype(Stile);
|
|
@@ -69,6 +69,8 @@ struct AttnParams {
|
|
| 69 |
int32_t total_k_tokens; // Total number of key/value tokens
|
| 70 |
int32_t max_seqlen_q; // Maximum query sequence length
|
| 71 |
int32_t max_seqlen_k; // Maximum key/value sequence length
|
|
|
|
|
|
|
| 72 |
};
|
| 73 |
|
| 74 |
// Forward declarations for kernel implementations
|
|
@@ -86,7 +88,9 @@ void call_flash_attention_varlen(
|
|
| 86 |
int64_t max_seqlen_k,
|
| 87 |
bool do_causal,
|
| 88 |
double scale,
|
| 89 |
-
double softcapping
|
|
|
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
void flash_attention_varlen(
|
|
@@ -100,7 +104,9 @@ void flash_attention_varlen(
|
|
| 100 |
int64_t max_seqlen_k, // Maximum key sequence length
|
| 101 |
bool do_causal, // Whether to use causal mask
|
| 102 |
double scale, // Attention scale
|
| 103 |
-
double softcapping
|
|
|
|
|
|
|
| 104 |
|
| 105 |
try {
|
| 106 |
// Get device and stream
|
|
@@ -142,9 +148,9 @@ void flash_attention_varlen(
|
|
| 142 |
// For variable-length Flash Attention, always use the full attention kernel
|
| 143 |
|
| 144 |
// Call the Flash Attention kernel
|
| 145 |
-
call_flash_attention_varlen(device, cmdBuf, lib, out, query, key, value,
|
| 146 |
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
| 147 |
-
do_causal, scale, softcapping);
|
| 148 |
} catch (const std::exception& e) {
|
| 149 |
throw;
|
| 150 |
} catch (...) {
|
|
@@ -167,7 +173,9 @@ void call_flash_attention_varlen(
|
|
| 167 |
int64_t max_seqlen_k,
|
| 168 |
bool do_causal,
|
| 169 |
double scale,
|
| 170 |
-
double softcapping
|
|
|
|
|
|
|
| 171 |
|
| 172 |
// Get dimensions
|
| 173 |
int64_t total_q_tokens = query.size(0);
|
|
@@ -197,7 +205,9 @@ void call_flash_attention_varlen(
|
|
| 197 |
params.total_k_tokens = key.size(0);
|
| 198 |
params.max_seqlen_q = max_seqlen_q;
|
| 199 |
params.max_seqlen_k = max_seqlen_k;
|
| 200 |
-
|
|
|
|
|
|
|
| 201 |
// Initialize fields that might be checked but aren't used in Flash Attention
|
| 202 |
params.qL = 0; // Not used in variable-length attention
|
| 203 |
params.kL = 0; // Not used in variable-length attention
|
|
@@ -227,11 +237,13 @@ void call_flash_attention_varlen(
|
|
| 227 |
// The kernel will handle the cu_seqlens internally
|
| 228 |
|
| 229 |
bool has_mask = false; // Masks are not supported in Flash Attention
|
|
|
|
| 230 |
|
| 231 |
// Setup function constants
|
| 232 |
MTLFunctionConstantValues *constants = [MTLFunctionConstantValues new];
|
| 233 |
[constants setConstantValue:&has_mask type:MTLDataTypeBool atIndex:300];
|
| 234 |
[constants setConstantValue:&do_causal type:MTLDataTypeBool atIndex:301];
|
|
|
|
| 235 |
|
| 236 |
// Construct kernel name based on data type and head dimension
|
| 237 |
std::string kernel_name = "steel_attention_";
|
|
|
|
| 69 |
int32_t total_k_tokens; // Total number of key/value tokens
|
| 70 |
int32_t max_seqlen_q; // Maximum query sequence length
|
| 71 |
int32_t max_seqlen_k; // Maximum key/value sequence length
|
| 72 |
+
int32_t window_left; // Sliding window: max distance into the past (-1 = unbounded)
|
| 73 |
+
int32_t window_right; // Sliding window: max distance into the future (-1 = unbounded)
|
| 74 |
};
|
| 75 |
|
| 76 |
// Forward declarations for kernel implementations
|
|
|
|
| 88 |
int64_t max_seqlen_k,
|
| 89 |
bool do_causal,
|
| 90 |
double scale,
|
| 91 |
+
double softcapping,
|
| 92 |
+
int64_t window_left,
|
| 93 |
+
int64_t window_right);
|
| 94 |
|
| 95 |
|
| 96 |
void flash_attention_varlen(
|
|
|
|
| 104 |
int64_t max_seqlen_k, // Maximum key sequence length
|
| 105 |
bool do_causal, // Whether to use causal mask
|
| 106 |
double scale, // Attention scale
|
| 107 |
+
double softcapping, // Softcapping value
|
| 108 |
+
int64_t window_left, // Sliding window past extent (-1 = unbounded)
|
| 109 |
+
int64_t window_right) { // Sliding window future extent (-1 = unbounded)
|
| 110 |
|
| 111 |
try {
|
| 112 |
// Get device and stream
|
|
|
|
| 148 |
// For variable-length Flash Attention, always use the full attention kernel
|
| 149 |
|
| 150 |
// Call the Flash Attention kernel
|
| 151 |
+
call_flash_attention_varlen(device, cmdBuf, lib, out, query, key, value,
|
| 152 |
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
| 153 |
+
do_causal, scale, softcapping, window_left, window_right);
|
| 154 |
} catch (const std::exception& e) {
|
| 155 |
throw;
|
| 156 |
} catch (...) {
|
|
|
|
| 173 |
int64_t max_seqlen_k,
|
| 174 |
bool do_causal,
|
| 175 |
double scale,
|
| 176 |
+
double softcapping,
|
| 177 |
+
int64_t window_left,
|
| 178 |
+
int64_t window_right) {
|
| 179 |
|
| 180 |
// Get dimensions
|
| 181 |
int64_t total_q_tokens = query.size(0);
|
|
|
|
| 205 |
params.total_k_tokens = key.size(0);
|
| 206 |
params.max_seqlen_q = max_seqlen_q;
|
| 207 |
params.max_seqlen_k = max_seqlen_k;
|
| 208 |
+
params.window_left = static_cast<int32_t>(window_left);
|
| 209 |
+
params.window_right = static_cast<int32_t>(window_right);
|
| 210 |
+
|
| 211 |
// Initialize fields that might be checked but aren't used in Flash Attention
|
| 212 |
params.qL = 0; // Not used in variable-length attention
|
| 213 |
params.kL = 0; // Not used in variable-length attention
|
|
|
|
| 237 |
// The kernel will handle the cu_seqlens internally
|
| 238 |
|
| 239 |
bool has_mask = false; // Masks are not supported in Flash Attention
|
| 240 |
+
bool has_window = (window_left >= 0) || (window_right >= 0);
|
| 241 |
|
| 242 |
// Setup function constants
|
| 243 |
MTLFunctionConstantValues *constants = [MTLFunctionConstantValues new];
|
| 244 |
[constants setConstantValue:&has_mask type:MTLDataTypeBool atIndex:300];
|
| 245 |
[constants setConstantValue:&do_causal type:MTLDataTypeBool atIndex:301];
|
| 246 |
+
[constants setConstantValue:&has_window type:MTLDataTypeBool atIndex:302];
|
| 247 |
|
| 248 |
// Construct kernel name based on data type and head dimension
|
| 249 |
std::string kernel_name = "steel_attention_";
|
|
@@ -17,6 +17,8 @@ def flash_attention_varlen(
|
|
| 17 |
do_causal: bool = False,
|
| 18 |
scale: Optional[float] = None,
|
| 19 |
softcapping: float = 1.0,
|
|
|
|
|
|
|
| 20 |
) -> None:
|
| 21 |
"""
|
| 22 |
Flash Attention with variable-length sequences.
|
|
@@ -38,10 +40,11 @@ def flash_attention_varlen(
|
|
| 38 |
- cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
|
| 39 |
- Supported head dimensions: 32, 64, 72, 80, 96, 128
|
| 40 |
- Masks are not supported
|
|
|
|
| 41 |
"""
|
| 42 |
if scale is None:
|
| 43 |
scale = query.shape[-1] ** -0.5
|
| 44 |
-
|
| 45 |
ops.flash_attention_varlen(
|
| 46 |
out,
|
| 47 |
query,
|
|
@@ -54,6 +57,8 @@ def flash_attention_varlen(
|
|
| 54 |
do_causal,
|
| 55 |
scale,
|
| 56 |
softcapping,
|
|
|
|
|
|
|
| 57 |
)
|
| 58 |
|
| 59 |
def flash_attn_varlen_func(
|
|
@@ -77,14 +82,14 @@ def flash_attn_varlen_func(
|
|
| 77 |
|
| 78 |
Note: This implementation does not support:
|
| 79 |
- dropout
|
| 80 |
-
- window attention
|
| 81 |
- alibi slopes
|
| 82 |
- returning attention probabilities
|
|
|
|
|
|
|
|
|
|
| 83 |
"""
|
| 84 |
if dropout_p > 0:
|
| 85 |
raise NotImplementedError("Dropout is not supported in this implementation")
|
| 86 |
-
if window_size != (-1, -1):
|
| 87 |
-
raise NotImplementedError("Window attention is not supported")
|
| 88 |
if alibi_slopes is not None:
|
| 89 |
raise NotImplementedError("ALiBi is not supported")
|
| 90 |
if return_attn_probs:
|
|
@@ -106,8 +111,10 @@ def flash_attn_varlen_func(
|
|
| 106 |
do_causal=causal,
|
| 107 |
scale=softmax_scale,
|
| 108 |
softcapping=1.0,
|
|
|
|
|
|
|
| 109 |
)
|
| 110 |
-
|
| 111 |
return out
|
| 112 |
|
| 113 |
|
|
|
|
| 17 |
do_causal: bool = False,
|
| 18 |
scale: Optional[float] = None,
|
| 19 |
softcapping: float = 1.0,
|
| 20 |
+
window_left: int = -1,
|
| 21 |
+
window_right: int = -1,
|
| 22 |
) -> None:
|
| 23 |
"""
|
| 24 |
Flash Attention with variable-length sequences.
|
|
|
|
| 40 |
- cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
|
| 41 |
- Supported head dimensions: 32, 64, 72, 80, 96, 128
|
| 42 |
- Masks are not supported
|
| 43 |
+
- window_left / window_right bound a sliding-window band (-1 = unbounded)
|
| 44 |
"""
|
| 45 |
if scale is None:
|
| 46 |
scale = query.shape[-1] ** -0.5
|
| 47 |
+
|
| 48 |
ops.flash_attention_varlen(
|
| 49 |
out,
|
| 50 |
query,
|
|
|
|
| 57 |
do_causal,
|
| 58 |
scale,
|
| 59 |
softcapping,
|
| 60 |
+
window_left,
|
| 61 |
+
window_right,
|
| 62 |
)
|
| 63 |
|
| 64 |
def flash_attn_varlen_func(
|
|
|
|
| 82 |
|
| 83 |
Note: This implementation does not support:
|
| 84 |
- dropout
|
|
|
|
| 85 |
- alibi slopes
|
| 86 |
- returning attention probabilities
|
| 87 |
+
|
| 88 |
+
`window_size = (left, right)` follows the flash-attn convention: a token attends to
|
| 89 |
+
keys in [pos - left, pos + right]; -1 means unbounded on that side.
|
| 90 |
"""
|
| 91 |
if dropout_p > 0:
|
| 92 |
raise NotImplementedError("Dropout is not supported in this implementation")
|
|
|
|
|
|
|
| 93 |
if alibi_slopes is not None:
|
| 94 |
raise NotImplementedError("ALiBi is not supported")
|
| 95 |
if return_attn_probs:
|
|
|
|
| 111 |
do_causal=causal,
|
| 112 |
scale=softmax_scale,
|
| 113 |
softcapping=1.0,
|
| 114 |
+
window_left=window_size[0],
|
| 115 |
+
window_right=window_size[1],
|
| 116 |
)
|
| 117 |
+
|
| 118 |
return out
|
| 119 |
|
| 120 |
|
|
@@ -4,7 +4,7 @@
|
|
| 4 |
#include "torch_binding.h"
|
| 5 |
|
| 6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 7 |
-
ops.def("flash_attention_varlen(Tensor! out, Tensor query, Tensor key, Tensor value, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, bool do_causal, float scale, float softcapping) -> ()");
|
| 8 |
ops.impl("flash_attention_varlen", torch::kMPS, flash_attention_varlen);
|
| 9 |
}
|
| 10 |
|
|
|
|
| 4 |
#include "torch_binding.h"
|
| 5 |
|
| 6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 7 |
+
ops.def("flash_attention_varlen(Tensor! out, Tensor query, Tensor key, Tensor value, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, bool do_causal, float scale, float softcapping, int window_left, int window_right) -> ()");
|
| 8 |
ops.impl("flash_attention_varlen", torch::kMPS, flash_attention_varlen);
|
| 9 |
}
|
| 10 |
|
|
@@ -13,4 +13,6 @@ void flash_attention_varlen(
|
|
| 13 |
int64_t max_seqlen_k,
|
| 14 |
bool do_causal,
|
| 15 |
double scale,
|
| 16 |
-
double softcapping
|
|
|
|
|
|
|
|
|
| 13 |
int64_t max_seqlen_k,
|
| 14 |
bool do_causal,
|
| 15 |
double scale,
|
| 16 |
+
double softcapping,
|
| 17 |
+
int64_t window_left,
|
| 18 |
+
int64_t window_right);
|