[MLX] Gemma4-31B ondevice sampling#20561
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20561
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 Pending, 2 Unclassified FailuresAs of commit 373b79d with merge base 825bd30 ( UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
| print("Done.") | ||
|
|
||
|
|
||
| class _MLXSampleWrapper(nn.Module): |
There was a problem hiding this comment.
Ideally SamplingHead from backends/mlx/llm would be directly usable. The fact that we keep defining wrappers suggest that maybe it should be refactored to work?
In terms of shape, I expect sampling to work on (B, vocab) and return (B) (tokens).
I don't see a case for (B, S, vocab).
|
@claude review this code |
|
Claude finished @metascroy's task in 2m 18s —— View job Code review
Overall this is a clean, well-guarded port of the Qwen3.5 MoE on-device sampling pattern (#20545). The runtime guards, range-checks, and the per-token seed schedule are consistent with that PR. A few notes below. On @metascroy's question: why not reuse
|
Summary
Lets the MLX-exported Gemma 4 31B model sample the next token on-device instead of returning logits for host-side sampling. Sampling is opt-in at export (--sample); temperature, top_p, and seed are runtime inputs, and the runner increments the seed per token.
Changes
top_p, seed) → int64 token; records a use_sampling constant-method flag. Non-sample export unchanged.
only for sampling models (non-sample keep seed 0, so they don't trip the guard).