Kernels
ArthurZ HF Staff commited on
Commit
87b9857
·
verified ·
1 Parent(s): ffdd1a5

[WIP] Add sliding-window attention support to the varlen kernel

Browse files

**Status: WIP — authored without a Metal build/test environment. Do not merge before building and running the parity tests (see checklist).**

### Why
Sliding-window (SWA) models — Mistral/Ministral, Gemma 2/3, Qwen sliding layers, the gpt-oss family — pass `window_size` into `flash_attn_varlen_func`. Today the wrapper raises `NotImplementedError("Window attention is not supported")`, so on MPS Transformers falls back to **eager** attention, which materializes the full `[B, H, n, n]` score matrix.

Measured on an M-series GPU (fp16, `openai/privacy-filter`, eager): driver memory grows 3.1 → 4.2 → 18.5 → 35.9 GiB across 512 → 2048 → 8192 → 16384 tokens, then crashes (`integer out of range` — the score tensor exceeds int32 indexing at 16k). Accepting the window lets these models use the flash path, whose memory is flat in sequence length.

### What
- Plumb `window_left` / `window_right` through `flash_attn_varlen_func` → `flash_attention_varlen` → the `flash_attention_varlen` op → host launcher → `AttnParams` → the shader (appended two `int32` to both the host and device `AttnParams`, ABI matched).
- New function constant `has_window` (302), gating a mask block that is a direct analog of the existing `do_causal` block: a key at signed distance `d` from the query row is kept iff `d <= window_left` (past) and `-d <= window_right` (future); `-1` = unbounded on that side. Follows flash-attn `window_size=(left, right)` semantics.
- `flash_attn_varlen_func` now forwards `window_size` instead of raising.

### Scope / limitations
- Correctness + the flat-memory win only. The K-block loop still visits every block (the band is applied as a mask, not by skipping out-of-band blocks). Block-skipping for a compute speedup is a natural follow-up.
- **Attention sinks (gpt-oss `s_aux`) are NOT in this PR.** They are a denominator-only term in the online softmax: before the final `Otile.row_bin_op<DivOp>(sum_score)`, add `exp2(sink_h * log2(e) - max_score[i])` to `sum_score[i]` (rescaling the running max to include the sink for numerical safety). Happy to send that as a second PR.

### Validate before merge
- [ ] Build with kernel-builder for the `torchNN-metal-aarch64-darwin` targets.
- [ ] `tests/test_flash_attention.py` + a new windowed case vs a reference banded-mask SDPA.
- [ ] Numeric parity for a bidirectional window (e.g. Gemma/Mistral) and a causal window.

sdpa-metal/scaled_dot_product_attention.metal CHANGED
@@ -1506,6 +1506,10 @@ struct AttnParams {
1506
  int total_k_tokens; ///< Total number of key/value tokens
1507
  int max_seqlen_q; ///< Maximum query sequence length
1508
  int max_seqlen_k; ///< Maximum key/value sequence length
 
 
 
 
1509
  };
1510
 
1511
  struct AttnMaskParams {
@@ -1521,6 +1525,7 @@ constant bool align_K [[function_constant(201)]];
1521
 
1522
  constant bool has_mask [[function_constant(300)]];
1523
  constant bool do_causal [[function_constant(301)]];
 
1524
 
1525
  template <typename T>
1526
  struct TransformScale {
@@ -1894,6 +1899,38 @@ template <
1894
  }
1895
  }
1896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1897
  // Other masking as needed
1898
  if (has_mask) {
1899
  using stile_t = decltype(Stile);
 
1506
  int total_k_tokens; ///< Total number of key/value tokens
1507
  int max_seqlen_q; ///< Maximum query sequence length
1508
  int max_seqlen_k; ///< Maximum key/value sequence length
1509
+
1510
+ // Sliding-window attention support (-1 on a side = unbounded on that side)
1511
+ int window_left; ///< Max distance into the past a query may attend
1512
+ int window_right; ///< Max distance into the future a query may attend
1513
  };
1514
 
1515
  struct AttnMaskParams {
 
1525
 
1526
  constant bool has_mask [[function_constant(300)]];
1527
  constant bool do_causal [[function_constant(301)]];
1528
+ constant bool has_window [[function_constant(302)]];
1529
 
1530
  template <typename T>
1531
  struct TransformScale {
 
1899
  }
1900
  }
1901
 
1902
+ // Mask out keys outside the sliding window band [row - window_left, row + window_right]
1903
+ if (has_window) {
1904
+ using stile_t = decltype(Stile);
1905
+ using selem_t = typename stile_t::elem_type;
1906
+ constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
1907
+
1908
+ const int wl = params->window_left; // -1 => unbounded into the past
1909
+ const int wr = params->window_right; // -1 => unbounded into the future
1910
+
1911
+ STEEL_PRAGMA_UNROLL
1912
+ for (short i = 0; i < stile_t::kTileRows; i++) {
1913
+ // Same row-position machinery as the causal block above.
1914
+ int row_pos = block_idx * BQ + tm + sm + (i * stile_t::kFragRows);
1915
+ if (q_seq_len < k_seq_len) {
1916
+ row_pos += (k_seq_len - q_seq_len);
1917
+ }
1918
+ STEEL_PRAGMA_UNROLL
1919
+ for (short j = 0; j < stile_t::kTileCols; j++) {
1920
+ const int col_pos_in_seq = kb * BK + sn + (j * stile_t::kFragCols);
1921
+ STEEL_PRAGMA_UNROLL
1922
+ for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
1923
+ const int col = col_pos_in_seq + jj;
1924
+ const bool past_ok = (wl < 0) || ((row_pos - col) <= wl);
1925
+ const bool future_ok = (wr < 0) || ((col - row_pos) <= wr);
1926
+ if (!(past_ok && future_ok)) {
1927
+ Stile.frag_at(i, j)[jj] = neg_inf;
1928
+ }
1929
+ }
1930
+ }
1931
+ }
1932
+ }
1933
+
1934
  // Other masking as needed
1935
  if (has_mask) {
1936
  using stile_t = decltype(Stile);
sdpa-metal/scaled_dot_product_attention.mm CHANGED
@@ -69,6 +69,8 @@ struct AttnParams {
69
  int32_t total_k_tokens; // Total number of key/value tokens
70
  int32_t max_seqlen_q; // Maximum query sequence length
71
  int32_t max_seqlen_k; // Maximum key/value sequence length
 
 
72
  };
73
 
74
  // Forward declarations for kernel implementations
@@ -86,7 +88,9 @@ void call_flash_attention_varlen(
86
  int64_t max_seqlen_k,
87
  bool do_causal,
88
  double scale,
89
- double softcapping);
 
 
90
 
91
 
92
  void flash_attention_varlen(
@@ -100,7 +104,9 @@ void flash_attention_varlen(
100
  int64_t max_seqlen_k, // Maximum key sequence length
101
  bool do_causal, // Whether to use causal mask
102
  double scale, // Attention scale
103
- double softcapping) { // Softcapping value
 
 
104
 
105
  try {
106
  // Get device and stream
@@ -142,9 +148,9 @@ void flash_attention_varlen(
142
  // For variable-length Flash Attention, always use the full attention kernel
143
 
144
  // Call the Flash Attention kernel
145
- call_flash_attention_varlen(device, cmdBuf, lib, out, query, key, value,
146
  cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
147
- do_causal, scale, softcapping);
148
  } catch (const std::exception& e) {
149
  throw;
150
  } catch (...) {
@@ -167,7 +173,9 @@ void call_flash_attention_varlen(
167
  int64_t max_seqlen_k,
168
  bool do_causal,
169
  double scale,
170
- double softcapping) {
 
 
171
 
172
  // Get dimensions
173
  int64_t total_q_tokens = query.size(0);
@@ -197,7 +205,9 @@ void call_flash_attention_varlen(
197
  params.total_k_tokens = key.size(0);
198
  params.max_seqlen_q = max_seqlen_q;
199
  params.max_seqlen_k = max_seqlen_k;
200
-
 
 
201
  // Initialize fields that might be checked but aren't used in Flash Attention
202
  params.qL = 0; // Not used in variable-length attention
203
  params.kL = 0; // Not used in variable-length attention
@@ -227,11 +237,13 @@ void call_flash_attention_varlen(
227
  // The kernel will handle the cu_seqlens internally
228
 
229
  bool has_mask = false; // Masks are not supported in Flash Attention
 
230
 
231
  // Setup function constants
232
  MTLFunctionConstantValues *constants = [MTLFunctionConstantValues new];
233
  [constants setConstantValue:&has_mask type:MTLDataTypeBool atIndex:300];
234
  [constants setConstantValue:&do_causal type:MTLDataTypeBool atIndex:301];
 
235
 
236
  // Construct kernel name based on data type and head dimension
237
  std::string kernel_name = "steel_attention_";
 
69
  int32_t total_k_tokens; // Total number of key/value tokens
70
  int32_t max_seqlen_q; // Maximum query sequence length
71
  int32_t max_seqlen_k; // Maximum key/value sequence length
72
+ int32_t window_left; // Sliding window: max distance into the past (-1 = unbounded)
73
+ int32_t window_right; // Sliding window: max distance into the future (-1 = unbounded)
74
  };
75
 
76
  // Forward declarations for kernel implementations
 
88
  int64_t max_seqlen_k,
89
  bool do_causal,
90
  double scale,
91
+ double softcapping,
92
+ int64_t window_left,
93
+ int64_t window_right);
94
 
95
 
96
  void flash_attention_varlen(
 
104
  int64_t max_seqlen_k, // Maximum key sequence length
105
  bool do_causal, // Whether to use causal mask
106
  double scale, // Attention scale
107
+ double softcapping, // Softcapping value
108
+ int64_t window_left, // Sliding window past extent (-1 = unbounded)
109
+ int64_t window_right) { // Sliding window future extent (-1 = unbounded)
110
 
111
  try {
112
  // Get device and stream
 
148
  // For variable-length Flash Attention, always use the full attention kernel
149
 
150
  // Call the Flash Attention kernel
151
+ call_flash_attention_varlen(device, cmdBuf, lib, out, query, key, value,
152
  cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
153
+ do_causal, scale, softcapping, window_left, window_right);
154
  } catch (const std::exception& e) {
155
  throw;
156
  } catch (...) {
 
173
  int64_t max_seqlen_k,
174
  bool do_causal,
175
  double scale,
176
+ double softcapping,
177
+ int64_t window_left,
178
+ int64_t window_right) {
179
 
180
  // Get dimensions
181
  int64_t total_q_tokens = query.size(0);
 
205
  params.total_k_tokens = key.size(0);
206
  params.max_seqlen_q = max_seqlen_q;
207
  params.max_seqlen_k = max_seqlen_k;
208
+ params.window_left = static_cast<int32_t>(window_left);
209
+ params.window_right = static_cast<int32_t>(window_right);
210
+
211
  // Initialize fields that might be checked but aren't used in Flash Attention
212
  params.qL = 0; // Not used in variable-length attention
213
  params.kL = 0; // Not used in variable-length attention
 
237
  // The kernel will handle the cu_seqlens internally
238
 
239
  bool has_mask = false; // Masks are not supported in Flash Attention
240
+ bool has_window = (window_left >= 0) || (window_right >= 0);
241
 
242
  // Setup function constants
243
  MTLFunctionConstantValues *constants = [MTLFunctionConstantValues new];
244
  [constants setConstantValue:&has_mask type:MTLDataTypeBool atIndex:300];
245
  [constants setConstantValue:&do_causal type:MTLDataTypeBool atIndex:301];
246
+ [constants setConstantValue:&has_window type:MTLDataTypeBool atIndex:302];
247
 
248
  // Construct kernel name based on data type and head dimension
249
  std::string kernel_name = "steel_attention_";
torch-ext/metal_flash_sdpa/_custom_ops.py CHANGED
@@ -17,6 +17,8 @@ def flash_attention_varlen(
17
  do_causal: bool = False,
18
  scale: Optional[float] = None,
19
  softcapping: float = 1.0,
 
 
20
  ) -> None:
21
  """
22
  Flash Attention with variable-length sequences.
@@ -38,10 +40,11 @@ def flash_attention_varlen(
38
  - cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
39
  - Supported head dimensions: 32, 64, 72, 80, 96, 128
40
  - Masks are not supported
 
41
  """
42
  if scale is None:
43
  scale = query.shape[-1] ** -0.5
44
-
45
  ops.flash_attention_varlen(
46
  out,
47
  query,
@@ -54,6 +57,8 @@ def flash_attention_varlen(
54
  do_causal,
55
  scale,
56
  softcapping,
 
 
57
  )
58
 
59
  def flash_attn_varlen_func(
@@ -77,14 +82,14 @@ def flash_attn_varlen_func(
77
 
78
  Note: This implementation does not support:
79
  - dropout
80
- - window attention
81
  - alibi slopes
82
  - returning attention probabilities
 
 
 
83
  """
84
  if dropout_p > 0:
85
  raise NotImplementedError("Dropout is not supported in this implementation")
86
- if window_size != (-1, -1):
87
- raise NotImplementedError("Window attention is not supported")
88
  if alibi_slopes is not None:
89
  raise NotImplementedError("ALiBi is not supported")
90
  if return_attn_probs:
@@ -106,8 +111,10 @@ def flash_attn_varlen_func(
106
  do_causal=causal,
107
  scale=softmax_scale,
108
  softcapping=1.0,
 
 
109
  )
110
-
111
  return out
112
 
113
 
 
17
  do_causal: bool = False,
18
  scale: Optional[float] = None,
19
  softcapping: float = 1.0,
20
+ window_left: int = -1,
21
+ window_right: int = -1,
22
  ) -> None:
23
  """
24
  Flash Attention with variable-length sequences.
 
40
  - cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
41
  - Supported head dimensions: 32, 64, 72, 80, 96, 128
42
  - Masks are not supported
43
+ - window_left / window_right bound a sliding-window band (-1 = unbounded)
44
  """
45
  if scale is None:
46
  scale = query.shape[-1] ** -0.5
47
+
48
  ops.flash_attention_varlen(
49
  out,
50
  query,
 
57
  do_causal,
58
  scale,
59
  softcapping,
60
+ window_left,
61
+ window_right,
62
  )
63
 
64
  def flash_attn_varlen_func(
 
82
 
83
  Note: This implementation does not support:
84
  - dropout
 
85
  - alibi slopes
86
  - returning attention probabilities
87
+
88
+ `window_size = (left, right)` follows the flash-attn convention: a token attends to
89
+ keys in [pos - left, pos + right]; -1 means unbounded on that side.
90
  """
91
  if dropout_p > 0:
92
  raise NotImplementedError("Dropout is not supported in this implementation")
 
 
93
  if alibi_slopes is not None:
94
  raise NotImplementedError("ALiBi is not supported")
95
  if return_attn_probs:
 
111
  do_causal=causal,
112
  scale=softmax_scale,
113
  softcapping=1.0,
114
+ window_left=window_size[0],
115
+ window_right=window_size[1],
116
  )
117
+
118
  return out
119
 
120
 
torch-ext/torch_binding.cpp CHANGED
@@ -4,7 +4,7 @@
4
  #include "torch_binding.h"
5
 
6
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
- ops.def("flash_attention_varlen(Tensor! out, Tensor query, Tensor key, Tensor value, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, bool do_causal, float scale, float softcapping) -> ()");
8
  ops.impl("flash_attention_varlen", torch::kMPS, flash_attention_varlen);
9
  }
10
 
 
4
  #include "torch_binding.h"
5
 
6
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
+ ops.def("flash_attention_varlen(Tensor! out, Tensor query, Tensor key, Tensor value, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, bool do_causal, float scale, float softcapping, int window_left, int window_right) -> ()");
8
  ops.impl("flash_attention_varlen", torch::kMPS, flash_attention_varlen);
9
  }
10
 
torch-ext/torch_binding.h CHANGED
@@ -13,4 +13,6 @@ void flash_attention_varlen(
13
  int64_t max_seqlen_k,
14
  bool do_causal,
15
  double scale,
16
- double softcapping);
 
 
 
13
  int64_t max_seqlen_k,
14
  bool do_causal,
15
  double scale,
16
+ double softcapping,
17
+ int64_t window_left,
18
+ int64_t window_right);