Add top-k support to MLX sample#20564
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20564
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
|
@pytorchbot label "release notes: llm" |
| return torch.ops.mlx.sample(last, temperature, top_p, seed) | ||
| if top_k is not None and not isinstance(top_k, torch.Tensor): | ||
| top_k = torch.tensor(int(top_k), dtype=torch.int64) | ||
| return torch.ops.mlx.sample(last, temperature, top_p, seed, top_k) |
There was a problem hiding this comment.
Nit: Can we reorder the args here to be temp, top_k, top_p, seed?
We also need to modify this custom ops behavior to work with top_k correctly now.
Do huggingface style, where topk happens before topp, which requires renormalization, roughly something like this.
Something similar is required on the emit path.
def sample(logits, temperature, top_k=None, top_p=1.0, seed=None):
if float(temperature) <= 0:
return torch.argmax(logits, dim=-1)
scaled = logits.float() / temperature
# ── Top-k FIRST (on logits; monotonicity ⇒ top-k logits = top-k probs) ──
if top_k is not None:
k = int(top_k.item())
s_scaled, _ = torch.sort(scaled, dim=-1, descending=True)
kth = s_scaled[..., k - 1 : k]
scaled = torch.where(scaled >= kth, scaled, scaled.new_tensor(float("-inf")))
# ── Top-p on the *renormalized* distribution ──
probs = torch.softmax(scaled, dim=-1) # exp(-inf)=0 → renormalized over top-k
s_probs, _ = torch.sort(probs, dim=-1, descending=True)
cum = torch.cumsum(s_probs, dim=-1)
keep = (cum - s_probs) <= top_p
thresh = torch.where(keep, s_probs, s_probs.new_tensor(float("inf"))).amin(
dim=-1, keepdim=True
)
scaled = torch.where(probs >= thresh, scaled, scaled.new_tensor(float("-inf")))
# ── Gumbel-max ──
if seed is None:
u = torch.rand(scaled.shape)
else:
gen = torch.Generator().manual_seed(int(seed.item()))
u = torch.rand(scaled.shape, generator=gen)
gumbel = -torch.log(-torch.log(u))
return torch.argmax(scaled + gumbel, dim=-1)
|
|
||
| def test_top_k_restricts_to_top_k(self): | ||
| # probs [0.5, 0.3, 0.15, 0.05]; top_k=2 keeps {0,1}. | ||
| base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) |
There was a problem hiding this comment.
Can we have test case in non-sorted order, e.g., 0.15, 0.5, .05, 0.3?
| out=P.slot_to_tid(drop), | ||
| ) | ||
| ) | ||
| if top_k is not None: |
There was a problem hiding this comment.
Let's keep top_k specified at runtime, with default value int max (select everything)
| ) | ||
| if top_k is not None: | ||
| _, top_k_val = P.make_tmp_value_slot() | ||
| P.emit(ItemIntNode(x=P.slot_to_tid(top_k), out=P.slot_to_vid(top_k_val))) |
There was a problem hiding this comment.
Let's do Int clipping to vocab size if topk > vocab size.
| if top_k is not None: | ||
| _, top_k_val = P.make_tmp_value_slot() | ||
| P.emit(ItemIntNode(x=P.slot_to_tid(top_k), out=P.slot_to_vid(top_k_val))) | ||
| top_k_iov = P.to_int_or_vid(top_k_val) |
There was a problem hiding this comment.
top_k_val is already intorvid?
|
@claude review this code |
|
Claude finished @metascroy's task in 3m 58s —— View job Code review: Add top-k support to MLX sampleI read through all five changed files plus the emit path in 1. Top-k / top-p composition diverges from HuggingFace semantics (correctness) — confirms Scott's
|
|
thanks for your quick review @metascroy will take a look and address them. |
|
@metascroy Addressed below changes in the latest commit.
|
|
@claude review this code |
Fixes #20548
Summary
top_kthroughSamplingHeadandmlx::sample.Test plan
PYTHONPATH=src:. python3 -m unittest executorch.backends.mlx.test.test_sample.TestSampleOp executorch.backends.mlx.test.test_sample.TestSampleExportpython3 -m compileall -q backends/mlx/llm/sampling.py backends/mlx/custom_ops.py backends/mlx/ops.py backends/mlx/test/test_sample.py backends/mlx/test/test_ops.py