feat: opt autotuner #1373
Code Review
This pull request introduces autotuning support for flash_attn_with_kvcache using a custom autotuner wrapper, integrates boundary synchronization to minimize tuning interference, and updates the Qwen2-VL model to handle text-only batches during decode. Feedback on these changes highlights potential issues in the autotuning utility: a possible out-of-bounds indexing error when benchmarking short sequence lengths, a potential TypeError if window_size is not a subscriptable sequence, and key collisions in the run key generator for ultra-long context lengths.
def _flash_attn_decode_bench_kv_lens(page_table):
    if page_table is None or page_table.dim() < 2:
        return []

    max_kv_len = int(page_table.shape[1])
    if max_kv_len <= 0:
        return []

    return [10240, max_kv_len]
In _flash_attn_decode_bench_kv_lens, returning a hardcoded benchmark sequence length of 10240 when max_kv_len is smaller than 10240 will cause out-of-bounds indexing on the page_table tensor during autotuning. This can lead to illegal memory accesses or GPU page faults. We should cap the benchmark sequence length to max_kv_len if it is smaller than 10240.
Suggested change:

 def _flash_attn_decode_bench_kv_lens(page_table):
     if page_table is None or page_table.dim() < 2:
         return []
     max_kv_len = int(page_table.shape[1])
     if max_kv_len <= 0:
         return []
+    if max_kv_len < 10240:
+        return [max_kv_len]
     return [10240, max_kv_len]
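One edge case the capping does not cover: when max_kv_len is exactly 10240, the function still returns the same length twice. Whether a duplicate matters depends on how the autotuner consumes the list, which this review does not show. The sketch below is a pure-Python illustration, not the PR's code; `bench_kv_lens` and its `target` parameter are hypothetical names, and it takes the length directly instead of a page table so that it runs without a GPU.

```python
def bench_kv_lens(max_kv_len, target=10240):
    """Return benchmark KV lengths that never exceed max_kv_len, without repeats.

    Hypothetical helper: mirrors the capping idea from the review on a plain int.
    """
    if max_kv_len <= 0:
        return []
    # Cap the fixed benchmark length so it cannot index past the page table.
    capped = min(target, max_kv_len)
    # dict.fromkeys keeps first-seen order and drops the repeated value.
    return list(dict.fromkeys([capped, max_kv_len]))


for n in (0, 4096, 10240, 32768):
    print(n, bench_kv_lens(n))
```

Deduplicating keeps the autotuner from benchmarking the same shape twice; if the caller relies on a fixed-length list, the plain capping from the suggestion is the safer choice.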
def _flash_attn_kvcache_static_key(q, k_cache, v_cache, causal, window_size, softcap, sinks):
    return {
        "qd": str(q.dtype),
        "kd": str(k_cache.dtype),
        "vd": str(v_cache.dtype),
        "qh": int(q.shape[-2]),
        "kh": int(k_cache.shape[-2]),
        "hd": int(q.shape[-1]),
        "vh": int(v_cache.shape[-1]),
        "pb": int(k_cache.shape[-3]),
        "c": int(bool(causal)),
        "wl": int(window_size[0]),
        "wr": int(window_size[1]),
        "sc": int(softcap > 0.0),
        "sk": int(sinks is not None),
        "sgl": getattr(sgl_ops, "__version__", "unknown"),
    }
In _flash_attn_kvcache_static_key, window_size is assumed to be a subscriptable sequence of at least 2 elements. However, in some contexts or configurations, window_size can be passed as an integer (for symmetric window size) or None. Accessing window_size[0] directly will raise a TypeError in those cases. We should handle None and integer types for window_size robustly.
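The failure mode can be reproduced without the kernel, since subscripting an int or None raises TypeError in plain Python. The sketch below assumes, as the comment does, that callers may pass an int or None; `normalize_window_size` is a hypothetical helper name, and using -1 for "no window" mirrors the None branch of the suggested change, not a confirmed convention of this codebase.

```python
def normalize_window_size(window_size):
    """Return (left, right) as ints for None, int, or 2-element sequence input.

    Hypothetical helper; -1 is assumed to mean "unbounded" on that side.
    """
    if window_size is None:
        return -1, -1
    if isinstance(window_size, int):
        # Symmetric window: the same extent on both sides.
        return window_size, window_size
    # Unpacking raises ValueError unless there are exactly two elements.
    left, right = window_size
    return int(left), int(right)


# Direct subscripting fails for the two problem inputs.
for bad in (None, 128):
    try:
        bad[0]
    except TypeError as exc:
        print(type(bad).__name__, "->", exc)

print(normalize_window_size(None))
print(normalize_window_size(128))
print(normalize_window_size((-1, 0)))
```

Unpacking into two names also rejects sequences of the wrong length up front, which plain indexing with `[0]` and `[1]` would silently accept for longer inputs.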
Suggested change:

def _flash_attn_kvcache_static_key(q, k_cache, v_cache, causal, window_size, softcap, sinks):
    if window_size is None:
        wl, wr = -1, -1
    elif isinstance(window_size, int):
        wl, wr = window_size, window_size
    else:
        wl, wr = window_size[0], window_size[1]
    return {
        "qd": str(q.dtype),
        "kd": str(k_cache.dtype),
        "vd": str(v_cache.dtype),
        "qh": int(q.shape[-2]),
        "kh": int(k_cache.shape[-2]),
        "hd": int(q.shape[-1]),
        "vh": int(v_cache.shape[-1]),
        "pb": int(k_cache.shape[-3]),
        "c": int(bool(causal)),
        "wl": int(wl),
        "wr": int(wr),
        "sc": int(softcap > 0.0),
        "sk": int(sinks is not None),
        "sgl": getattr(sgl_ops, "__version__", "unknown"),
    }

def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
    batch_size = int(page_table.shape[0])
    max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
    max_kv_len = int(page_table.shape[1])
    return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len
In _flash_attn_kvcache_run_key, the multiplier of 1_000_000 for max_q_len leaves only six decimal digits for max_kv_len, so keys collide once max_kv_len reaches 1,000,000 tokens: the overflow spills into the max_q_len field. To support context lengths of 1M+ tokens robustly, we should increase the multipliers for max_q_len and batch_size.
Suggested change:

 def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
     batch_size = int(page_table.shape[0])
     max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
     max_kv_len = int(page_table.shape[1])
-    return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len
+    return batch_size * 1_000_000_000_000_000 + max_q_len * 100_000_000 + max_kv_len
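The collision can be checked with plain integer arithmetic. The sketch below copies the two packing formulas from the snippets above into standalone functions (`run_key_old`, `run_key_new`, and `run_key_tuple` are hypothetical names) and shows one colliding pair for each. Larger multipliers move the limit to 100,000,000 tokens but do not remove it. The tuple variant assumes the autotuner only needs a hashable key and not specifically an int, which this PR does not confirm.

```python
def run_key_old(batch_size, max_q_len, max_kv_len):
    # Original packing: six decimal digits each for max_kv_len and max_q_len.
    return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len


def run_key_new(batch_size, max_q_len, max_kv_len):
    # Suggested packing: eight digits for max_kv_len, seven for max_q_len.
    return batch_size * 1_000_000_000_000_000 + max_q_len * 100_000_000 + max_kv_len


def run_key_tuple(batch_size, max_q_len, max_kv_len):
    # Assumed alternative: a tuple cannot collide, whatever the magnitudes.
    return (batch_size, max_q_len, max_kv_len)


a = (1, 1, 0)          # one query token, empty KV
b = (1, 0, 1_000_000)  # no query tokens, 1M KV tokens

print(run_key_old(*a) == run_key_old(*b))      # same key under the old packing
print(run_key_new(*a) == run_key_new(*b))      # distinct under the new packing
print(run_key_new(*a) == run_key_new(1, 0, 100_000_000))  # new limit at 1e8
print(run_key_tuple(*a) == run_key_tuple(*b))  # tuples stay distinct
```

If the key must remain a single integer, an explicit range check on max_q_len and max_kv_len before packing would turn a silent collision into a visible error.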