Skip to content

feat: opt autotuner#1373

Open
blueswhen wants to merge 1 commit into
mainfrom
opt_autotuner
Open

feat: opt autotuner#1373
blueswhen wants to merge 1 commit into
mainfrom
opt_autotuner

Conversation

@blueswhen

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces autotuning support for flash_attn_with_kvcache using a custom autotuner wrapper, integrates boundary synchronization to minimize tuning interference, and updates the Qwen2-VL model to handle text-only batches during decode. Feedback on these changes highlights potential issues in the autotuning utility: a possible out-of-bounds indexing error when benchmarking short sequence lengths, a potential TypeError if window_size is not a subscriptable sequence, and key collisions in the run key generator for ultra-long context lengths.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +124 to +132
def _flash_attn_decode_bench_kv_lens(page_table):
if page_table is None or page_table.dim() < 2:
return []

max_kv_len = int(page_table.shape[1])
if max_kv_len <= 0:
return []

return [10240, max_kv_len]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _flash_attn_decode_bench_kv_lens, returning a hardcoded benchmark sequence length of 10240 when max_kv_len is smaller than 10240 will cause out-of-bounds indexing on the page_table tensor during autotuning. This can lead to illegal memory accesses or GPU page faults. We should cap the benchmark sequence length to max_kv_len if it is smaller than 10240.

Suggested change
def _flash_attn_decode_bench_kv_lens(page_table):
if page_table is None or page_table.dim() < 2:
return []
max_kv_len = int(page_table.shape[1])
if max_kv_len <= 0:
return []
return [10240, max_kv_len]
def _flash_attn_decode_bench_kv_lens(page_table):
if page_table is None or page_table.dim() < 2:
return []
max_kv_len = int(page_table.shape[1])
if max_kv_len <= 0:
return []
if max_kv_len < 10240:
return [max_kv_len]
return [10240, max_kv_len]

Comment on lines +46 to +62
def _flash_attn_kvcache_static_key(q, k_cache, v_cache, causal, window_size, softcap, sinks):
return {
"qd": str(q.dtype),
"kd": str(k_cache.dtype),
"vd": str(v_cache.dtype),
"qh": int(q.shape[-2]),
"kh": int(k_cache.shape[-2]),
"hd": int(q.shape[-1]),
"vh": int(v_cache.shape[-1]),
"pb": int(k_cache.shape[-3]),
"c": int(bool(causal)),
"wl": int(window_size[0]),
"wr": int(window_size[1]),
"sc": int(softcap > 0.0),
"sk": int(sinks is not None),
"sgl": getattr(sgl_ops, "__version__", "unknown"),
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _flash_attn_kvcache_static_key, window_size is assumed to be a subscriptable sequence of at least 2 elements. However, in some contexts or configurations, window_size can be passed as an integer (for symmetric window size) or None. Accessing window_size[0] directly will raise a TypeError in those cases. We should handle None and integer types for window_size robustly.

def _flash_attn_kvcache_static_key(q, k_cache, v_cache, causal, window_size, softcap, sinks):
    if window_size is None:
        wl, wr = -1, -1
    elif isinstance(window_size, int):
        wl, wr = window_size, window_size
    else:
        wl, wr = window_size[0], window_size[1]
    return {
        "qd": str(q.dtype),
        "kd": str(k_cache.dtype),
        "vd": str(v_cache.dtype),
        "qh": int(q.shape[-2]),
        "kh": int(k_cache.shape[-2]),
        "hd": int(q.shape[-1]),
        "vh": int(v_cache.shape[-1]),
        "pb": int(k_cache.shape[-3]),
        "c": int(bool(causal)),
        "wl": int(wl),
        "wr": int(wr),
        "sc": int(softcap > 0.0),
        "sk": int(sinks is not None),
        "sgl": getattr(sgl_ops, "__version__", "unknown"),
    }

Comment on lines +73 to +77
def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
batch_size = int(page_table.shape[0])
max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
max_kv_len = int(page_table.shape[1])
return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _flash_attn_kvcache_run_key, using 1_000_000 as the multiplier for max_q_len can lead to key collisions for ultra-long context lengths (where max_kv_len exceeds 1,000,000 tokens). To support context lengths of 1M+ tokens robustly, we should increase the multipliers for max_q_len and batch_size.

Suggested change
def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
batch_size = int(page_table.shape[0])
max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
max_kv_len = int(page_table.shape[1])
return batch_size * 1_000_000_000_000 + max_q_len * 1_000_000 + max_kv_len
def _flash_attn_kvcache_run_key(q, page_table, max_seqlen_q):
batch_size = int(page_table.shape[0])
max_q_len = _flash_attn_max_q_len(q, max_seqlen_q)
max_kv_len = int(page_table.shape[1])
return batch_size * 1_000_000_000_000_000 + max_q_len * 100_000_000 + max_kv_len

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant