Skip to content

Add top-k support to MLX sample#20564

Open
goutamadwant wants to merge 3 commits into
pytorch:mainfrom
goutamadwant:fix-mlx-sample-top-k
Open

Add top-k support to MLX sample#20564
goutamadwant wants to merge 3 commits into
pytorch:mainfrom
goutamadwant:fix-mlx-sample-top-k

Conversation

@goutamadwant

Copy link
Copy Markdown

Fixes #20548

Summary

  • Thread optional top_k through SamplingHead and mlx::sample.
  • Apply top-k filtering as an additional threshold mask that composes with the existing top-p nucleus mask.
  • Add eager, export, end-to-end, and MLX lowering coverage for top-k sampling.

Test plan

  • PYTHONPATH=src:. python3 -m unittest executorch.backends.mlx.test.test_sample.TestSampleOp executorch.backends.mlx.test.test_sample.TestSampleExport
  • python3 -m compileall -q backends/mlx/llm/sampling.py backends/mlx/custom_ops.py backends/mlx/ops.py backends/mlx/test/test_sample.py backends/mlx/test/test_ops.py

@pytorch-bot

pytorch-bot Bot commented Jun 27, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20564

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 27, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 27, 2026

Copy link
Copy Markdown

CLA Signed
The committers listed above are authorized under a signed CLA.

  • ✅ login: goutamadwant / name: goutamadwant (c944a5a)

@goutamadwant

Copy link
Copy Markdown
Author

@pytorchbot label "release notes: llm"

Comment thread backends/mlx/llm/sampling.py Outdated
return torch.ops.mlx.sample(last, temperature, top_p, seed)
if top_k is not None and not isinstance(top_k, torch.Tensor):
top_k = torch.tensor(int(top_k), dtype=torch.int64)
return torch.ops.mlx.sample(last, temperature, top_p, seed, top_k)

@metascroy metascroy Jun 27, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can we reorder the args here to be temp, top_k, top_p, seed?

We also need to modify this custom ops behavior to work with top_k correctly now.

Do huggingface style, where topk happens before topp, which requires renormalization, roughly something like this.

Something similar is required on the emit path.

def sample(logits, temperature, top_k=None, top_p=1.0, seed=None):
    if float(temperature) <= 0:
        return torch.argmax(logits, dim=-1)

    scaled = logits.float() / temperature

    # ── Top-k FIRST (on logits; monotonicity ⇒ top-k logits = top-k probs) ──
    if top_k is not None:
        k = int(top_k.item())
        s_scaled, _ = torch.sort(scaled, dim=-1, descending=True)
        kth = s_scaled[..., k - 1 : k]
        scaled = torch.where(scaled >= kth, scaled, scaled.new_tensor(float("-inf")))

    # ── Top-p on the *renormalized* distribution ──
    probs = torch.softmax(scaled, dim=-1)          # exp(-inf)=0 → renormalized over top-k
    s_probs, _ = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(s_probs, dim=-1)
    keep = (cum - s_probs) <= top_p
    thresh = torch.where(keep, s_probs, s_probs.new_tensor(float("inf"))).amin(
        dim=-1, keepdim=True
    )
    scaled = torch.where(probs >= thresh, scaled, scaled.new_tensor(float("-inf")))

    # ── Gumbel-max ──
    if seed is None:
        u = torch.rand(scaled.shape)
    else:
        gen = torch.Generator().manual_seed(int(seed.item()))
        u = torch.rand(scaled.shape, generator=gen)
    gumbel = -torch.log(-torch.log(u))
    return torch.argmax(scaled + gumbel, dim=-1)

Comment thread backends/mlx/test/test_sample.py Outdated

def test_top_k_restricts_to_top_k(self):
# probs [0.5, 0.3, 0.15, 0.05]; top_k=2 keeps {0,1}.
base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have test case in non-sorted order, e.g., 0.15, 0.5, .05, 0.3?

Comment thread backends/mlx/ops.py Outdated
out=P.slot_to_tid(drop),
)
)
if top_k is not None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep top_k specified at runtime, with default value int max (select everything)

Comment thread backends/mlx/ops.py Outdated
)
if top_k is not None:
_, top_k_val = P.make_tmp_value_slot()
P.emit(ItemIntNode(x=P.slot_to_tid(top_k), out=P.slot_to_vid(top_k_val)))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do Int clipping to vocab size if topk > vocab size.

Comment thread backends/mlx/ops.py Outdated
if top_k is not None:
_, top_k_val = P.make_tmp_value_slot()
P.emit(ItemIntNode(x=P.slot_to_tid(top_k), out=P.slot_to_vid(top_k_val)))
top_k_iov = P.to_int_or_vid(top_k_val)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top_k_val is already intorvid?

@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 27, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 3m 58s —— View job


Code review: Add top-k support to MLX sample

I read through all five changed files plus the emit path in ops.py and the helpers it relies on. Scott's inline comments already cover the main direction; below I confirm them with a concrete numeric case and add a few additional findings.

1. Top-k / top-p composition diverges from HuggingFace semantics (correctness) — confirms Scott's sampling.py:41 comment

Both the reference op (custom_ops.py) and the emit path (ops.py) compute the top-k mask and the top-p mask independently on the same un-renormalized probs and then OR the drop masks together. HF does top-k first, renormalizes, then applies top-p to the renormalized distribution. These are not equivalent. Worked example, probs = [0.5, 0.3, 0.15, 0.05], top_k=3, top_p=0.83:

  • This PR: top-k keeps {0,1,2}, top-p (prefix mass [0, 0.5, 0.8, 0.95] ≤ 0.83) keeps {0,1,2} → final keep {0,1,2}.
  • HF: top-k → renormalize over 0.95 → [0.526, 0.316, 0.158]; top-p on that (prefix [0, 0.526, 0.842] ≤ 0.83) keeps {0,1} → final keep {0,1}.

So the result genuinely differs once both filters are active. Scott's suggested sample() rewrite (top-k on logits → softmax/renormalize → top-p on the renormalized probs) is the fix, and the same restructuring is needed on the emit path — the current LogicalOrNode combine mirrors the reference's OR-compose, so both must move to sequential top-k-then-renormalize-then-top-p. The new test_top_k_and_top_p_compose test passes only because the chosen values happen to agree under both semantics; a case like the one above would catch the divergence.

2. top_k > vocab_size is an out-of-bounds bug (correctness) — relates to Scott's ops.py:3681 clipping comment

custom_ops.py:431:

k = int(top_k.item())
kth = s_probs[..., k - 1 : k]

If k > vocab, the slice [k-1:k] is empty, so probs >= kth broadcasts against a size-0 tensor and the mask is wrong/raises. On the emit path the corresponding TakeNode(index = top_k - 1) indexes out of bounds. Clipping k = min(k, vocab) (Scott's point) fixes both. Worth a test with top_k larger than the vocab to lock this in.

3. top_k as runtime default instead of Optional — Scott's ops.py:3679 comment

Making top_k always-present with a default of INT_MAX (clipped to vocab per #2) removes the if top_k is not None branch in both custom_ops.py and the _sample_handler emit block, and the is not None checks in sampling.py. That collapses two code paths into one and naturally yields "keep everything" when unset.

4. Argument ordering — Scott's sampling.py:41 comment

Reorder to temperature, top_k, top_p, seed consistently across mlx::sample, sample_fake, _sample_handler's args[...] unpacking, and the _sample test helper. Note this is a signature change to the registered custom op, so every call site and the require_args(args, 3, 5, ...) indices need to move together.

5. to_int_or_vid already handles the literal/vid split — Scott's ops.py:3682 comment

ops.py:3688-3692:

index=(
    IntOrVidOrTid.from_vid(top_k_index.vid)
    if top_k_index.is_vid
    else IntOrVidOrTid.from_literal(top_k_index.literal)
)

top_k_val comes from ItemIntNode so it's always a Vid here — top_k_iov/top_k_index will never be a literal in this path, making the else branch dead. P.to_int_or_vid_or_tid(top_k_val) (program_builder.py:318) already does this conversion in one call; prefer it over the manual ternary.

6. Test coverage — Scott's test_sample.py:165 comment

  • Add a non-sorted-logits case (e.g. probs 0.15, 0.5, 0.05, 0.3) so the test exercises the sort path rather than already-descending input.
  • Add a top_k > vocab case (Re-sync with internal repository #2) and a combined case where HF vs OR-compose semantics actually differ (Add support for quantized LeakyReLU #1), so the renormalization fix is verified rather than coincidentally passing.

Minor

  • custom_ops.py:431 int(top_k.item()) is fine for the eager/meta reference, but confirm the .item() data-dependent read doesn't trip export's symbolic tracing on this op (the seed path already uses ItemIntNode on device, so the device side is fine).
  • Docstring in sampling.py still reads cleanly, but once Rename _pt2e to pt2e #3 lands ("None disables top-k") the wording should reflect the always-present default instead.

Overall the lowering scaffolding (Take → ExpandDims → Less → LogicalOr) is well-built and the node-count test is a nice touch; the substantive work is switching both the reference and emit paths to HF-style sequential filtering and adding the vocab-size clip.
· branch fix-mlx-sample-top-k

@goutamadwant

Copy link
Copy Markdown
Author

thanks for your quick review @metascroy will take a look and address them.

@goutamadwant

Copy link
Copy Markdown
Author

@metascroy Addressed below changes in the latest commit.

  • Reordered the custom op args to temperature, top_k, top_p, seed.
  • Switched sampling to top-k first.. and then top-p on the renormalized probabilities.
  • Made the default top_k path use a max-int sentinel clipped to vocab size.
  • Removed the independent OR-composed top-k/top-p masks in the emit path.
  • Added coverage for non-sorted logits, top_k > vocab, and the combined top-k/top-p case where renormalization changes the result.
    let me know if you have any more suggestions. thanks!

@goutamadwant

Copy link
Copy Markdown
Author

@claude review this code

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. release notes: llm Changes to llm utilities

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Good First Issue: top-k filtering for mlx::sample (MLX backend)

3 participants