Skip to content

Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics for Wgrad#3143

Open
vthumbe1503 wants to merge 11 commits into
NVIDIA:mainfrom
vthumbe1503:nvfp4_and_fp8_current_scaling
Open

Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics for Wgrad#3143
vthumbe1503 wants to merge 11 commits into
NVIDIA:mainfrom
vthumbe1503:nvfp4_and_fp8_current_scaling

Conversation

@vthumbe1503

Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 3 commits June 25, 2026 00:40
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Removed details about FP8 current scaling methods.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review June 25, 2026 00:57
@vthumbe1503

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

@greptile-apps

greptile-apps Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR extends the graph-safe GroupedLinear path (both the module and ops implementations) to support FP8 per-tensor current scaling on Hopper (CC 9.0) and fixes incorrect cuBLAS grouped GEMM heuristics that were supplying wrong dimension hints for the wgrad GEMM.

  • FP8 current-scaling graph-safe path: Float8CurrentScalingQuantizer is now recognised as a valid quantizer for the grouped-tensor GEMM path on Hopper and Blackwell; the Blackwell-only guard that previously rejected it has been moved after an early-return for current-scaling. The bug where rowwise_data/scale_inv were cleared unconditionally (destroying forward activations before save_for_backward when columnwise_data is None) is fixed by adding the grouped_x.columnwise_data is not None guard in both module and ops codepaths.
  • cuBLAS wgrad heuristics fix: nvte_grouped_gemm_with_discrete_out was computing avg_m/avg_n/avg_k from the wrong operands for the NT-layout wgrad GEMM call; avg_k now correctly uses first_dim(A) (token count, the contracted dimension) instead of last_dim(A) (in_features), and avg_m/avg_n sources are swapped to match the row-major/column-major convention used internally.
  • Test coverage: New fp8_current_scaling parametrised cases added to test_grouped_linear and test_grouped_mlp graph-safe tests, with Hopper-compatible skip logic replacing the previous blanket Blackwell-only skip.

Confidence Score: 5/5

The changes are well-scoped and additive: the correctness fix for the rowwise_data guard is straightforward, the heuristics change only affects algorithm selection (not numerical output), and the test additions provide direct coverage of the new path.

The core logic change — guarding rowwise_data/scale_inv clearing behind columnwise_data is not None — is clearly correct and resolves the previously flagged bug. The cuBLAS heuristic corrections are performance-only (wrong hints lead to suboptimal algorithm selection, not wrong results). No data-loss or correctness regressions are introduced by this PR.

tests/pytorch/test_grouped_mlp.py contains a stale skip message; all other files look correct.

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Fixes wrong avg_m/avg_n/avg_k heuristic sources in nvte_grouped_gemm_with_discrete_out for the wgrad GEMM case: avg_k now correctly uses first_dim(A) (token count — the true contracted dimension) instead of last_dim(A) (in_features) for the NT-layout wgrad call.
transformer_engine/pytorch/module/grouped_linear.py Correctly guards rowwise_data/scale_inv clearing with columnwise_data is not None, fixing the unconditional clear that destroyed FP8 per-tensor current-scaling activations before save_for_backward. Extends _is_grouped_tensor_path_supported to allow Float8CurrentScalingQuantizer on Hopper by returning True before the Blackwell-only check.
transformer_engine/pytorch/ops/basic/grouped_linear.py Mirrors the module-level guard fix, adds single_grouped_weight parameter to _is_graph_safe_path_supported to restrict NVFP4 to discrete weights, removes the erroneous float8_current_scaling runtime error for single_grouped_weight, and imports Float8CurrentScalingQuantizer.
tests/pytorch/test_grouped_linear.py Adds fp8_current_scaling parametrized cases to two graph-safe grouped linear tests and refines the Blackwell-only skip to allow current-scaling on Hopper (SM90).
tests/pytorch/test_grouped_mlp.py Expands cuda_graph_safe test to cover fp8_current_scaling and nvfp4_rht, adds BF16-only and discrete-weight-only skips for NVFP4, but retains a stale skip message that still mentions current scaling after the condition was narrowed to delayed scaling only.

Reviews (6): Last reviewed commit: "Merge branch 'main' into nvfp4_and_fp8_c..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
vthumbe1503 and others added 5 commits June 26, 2026 17:26
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
… weight being cuda graphable

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…3/TransformerEngine into nvfp4_and_fp8_current_scaling
@vthumbe1503

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 changed the title Graph Safe Current Scaling Support for GroupedLinear Module/Ops Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics for Wgrad Jun 27, 2026
@vthumbe1503

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant