[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135
[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135denera wants to merge 10 commits into
Conversation
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128)
block-scaling recipes. A single CUDA kernel launch walks 128x128 tiles
across every tensor in the group, with each CTA decoding its owning
tensor from the device-side GroupedTensor metadata.
Supported shape representations:
- SAME_BOTH_DIMS (all tensors identical)
- VARYING_FIRST_DIM (constant K, varying R - the common MoE topology)
Supported directions: rowwise-only, columnwise-only, and both.
These kernels are gated to Hopper (sm_90) at the host dispatcher because
the consumer cuBLAS FP8 block-scaling *grouped* GEMM is itself
Hopper-only (cuBLAS does not provide native FP8 block-scaling grouped
GEMM on Blackwell; the recommended quantization recipe on Blackwell is
MXFP8). The device-side kernel bodies are gated on __CUDA_ARCH__ >= 900
so the kernels compile and link as part of multi-arch builds, but the
host gate prevents launches on Blackwell.
Three kernels share the dispatcher in
group_quantize_blockwise_{1d,2d}:
| Kernel | Dispatched when | Threading | Smem |
|--------|-----------------|-----------|------|
| group_block_scaled_1d_rw_kernel | 1D RW-only | 8 threads/row x 32 row-warps x 4 iters; reads gmem directly into vec-16 registers | none |
| group_block_scaled_1d_tma_kernel | 1D CW or 1D BOTH | TMA bulk-load fills 32 KB input cache. BOTH runs RW pass first (8 t/row, vec-16) then CW pass (2 t/col, 64-row register stage); CW-only skips the RW pass. CW writes the transposed-FP8 tile to a 16.5 KB smem_T staging buffer, then drains to gmem. | 32 KB + 16.5 KB |
| group_block_scaled_2d_tma_kernel | 2D RW / CW / BOTH | TMA bulk-load fills 32 KB cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits rowwise output, stages columnwise output to smem_T, then drains. | 32 KB + 16.5 KB |
The RW-only 1D path bypasses TMA because a streaming read has no reuse
- the smem round-trip and mbarrier overhead would just add latency.
The C++ test tests/cpp/operator/test_cast_float8blockwise_grouped.cu
exercises 72 configurations covering RW/CW/BOTH x 1D/2D x SAME/VARYING
shape representations against a per-tensor split-quantize reference.
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
| constexpr int kThreadsPerBlock = 256; | ||
| constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; | ||
|
|
||
| // Align a dynamic-smem pointer to 128 bytes (TMA requirement). |
There was a problem hiding this comment.
Could we reuse the existing align_smem_ptr_per_TMA_requirements() helper from transformer_engine/cast/core/common.h here?
| size_t total_row_blocks) { | ||
| using namespace transformer_engine::dispatch::mxfp8::swizzle; | ||
| const size_t num_tiles_X = | ||
| (total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / GEMM_SWIZZLED_SCALE_TILE_DIM_X; |
There was a problem hiding this comment.
We can also reuse the existing DIVUP() helper here (defined in transformer_engin/common/common.h).
|
|
||
| // ---- Tensor-lookup helpers ---------------------------------------------------- | ||
|
|
||
| // Map a global tile-row index to its owning tensor by binary-searching |
There was a problem hiding this comment.
We can also reuse the existing get_current_tensor_id() helper defined in transformer_engine/cast/core/common.cuh
Greptile SummaryThis PR adds fused grouped-quantize CUDA kernels for FP8 1D (1×128) and 2D (128×128) block-scaling, targeting Hopper (SM90–SM99). A single kernel launch walks 128×128 tiles across every tensor in a group, decoding per-expert tensor metadata from device-side
Confidence Score: 5/5Safe to merge; the new CUDA kernels are well-structured, per-expert scale layouts are consistent between quantize and dequantize, and host-side allocations correctly account for per-expert alignment slack. The three new kernels cover all (1D/2D × RW/CW/BOTH) combinations with consistent layout logic verified across quantize and dequantize sides. PyTorch dispatch wiring, scale-buffer sizing, and the force_pow_2_scales guard are all correct. The CC guard change in grouped_linear.py and the smem-attribute condition are non-blocking concerns that do not affect correctness or stability. grouped_linear.py for the dropped Hopper CC guard; group_quantize_fp8_blockwise.cuh for the cudaFuncSetAttribute condition. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[group_quantize / bgrad_group_quantize] --> B{scaling_mode?}
B -->|NVTE_BLOCK_SCALING_1D| C[group_quantize_blockwise_1d]
B -->|NVTE_BLOCK_SCALING_2D| D[group_quantize_blockwise_2d]
B -->|NVTE_MXFP8| E[group_quantize_mxfp8]
B -->|NVTE_NVFP4| F[group_quantize_transpose_nvfp4]
C --> G{use_colwise or dbias?}
G -->|No| H[group_block_scaled_1d_rw_kernel no smem vec-16 global loads]
G -->|Yes| I[group_block_scaled_1d_tma_kernel TMA bulk-load smem input cache]
D --> J[group_block_scaled_2d_tma_kernel Pass 1 tile amax via register staging Pass 2 quantize + smem_T transpose drain]
I --> K[RW pass 8 t/row vec-16 from smem rowwise gmem]
I --> L[CW pass 2 t/col 64-row regs smem_T colwise gmem]
J --> M[Rowwise gmem write from registers]
J --> N[smem_T XOR-swizzle stage colwise gmem drain]
C & D --> O[grouped_reduce_dbias if dbias requested]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[group_quantize / bgrad_group_quantize] --> B{scaling_mode?}
B -->|NVTE_BLOCK_SCALING_1D| C[group_quantize_blockwise_1d]
B -->|NVTE_BLOCK_SCALING_2D| D[group_quantize_blockwise_2d]
B -->|NVTE_MXFP8| E[group_quantize_mxfp8]
B -->|NVTE_NVFP4| F[group_quantize_transpose_nvfp4]
C --> G{use_colwise or dbias?}
G -->|No| H[group_block_scaled_1d_rw_kernel no smem vec-16 global loads]
G -->|Yes| I[group_block_scaled_1d_tma_kernel TMA bulk-load smem input cache]
D --> J[group_block_scaled_2d_tma_kernel Pass 1 tile amax via register staging Pass 2 quantize + smem_T transpose drain]
I --> K[RW pass 8 t/row vec-16 from smem rowwise gmem]
I --> L[CW pass 2 t/col 64-row regs smem_T colwise gmem]
J --> M[Rowwise gmem write from registers]
J --> N[smem_T XOR-swizzle stage colwise gmem drain]
C & D --> O[grouped_reduce_dbias if dbias requested]
Reviews (4): Last reviewed commit: "Add grouped FP8 block-scaling dequantize..." | Re-trigger Greptile |
| } | ||
|
|
||
| CType amax = compute_row_amax<IType, CType, kVec>(in_vec[it]); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); |
There was a problem hiding this comment.
Could we reuse the existing amax warp-reduction helpers (warp_reduce_max() or reduce_max()) from transformer_engine/common/utils.cuh here?
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 2)); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 4)); |
There was a problem hiding this comment.
We can also reuse reduce_max() or warp_reduce_max() here.
|
|
||
| // ----- Host-side dispatchers -------------------------------------------------------------------- | ||
|
|
||
| inline size_t align_up_to(size_t x, size_t a) { return ((x + a - 1) / a) * a; } |
There was a problem hiding this comment.
We can reuse DIVUP_TO_MULTIPLE() defined in transformer_engine/common/common.h.
| NVTE_CHECK(info.tensor_offsets_d != nullptr, | ||
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; |
There was a problem hiding this comment.
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; | |
| info.total_row_blocks = DIVUP(info.R_total, kTileDim); |
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; | ||
| info.blocks_X = (info.K + kTileDim - 1) / kTileDim; |
There was a problem hiding this comment.
| info.blocks_X = (info.K + kTileDim - 1) / kTileDim; | |
| info.blocks_X = DIVUP(info.K, kTileDim); |
| info.same_both_dims = same_both_dims; | ||
| info.num_tensors = output->num_tensors; | ||
| info.K = output->get_common_last_dim(); | ||
| NVTE_CHECK(info.K % 16 == 0, "Last dim must be multiple of 16 (FP8 alignment)."); |
There was a problem hiding this comment.
If this is a TMA requirement, we can use the TMA_GMEM_ALIGNMENT constant defined in transformer_engine/common/common.h
| const float* noop_ptr = | ||
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); |
There was a problem hiding this comment.
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); | |
| const size_t scale_stride_y = DIVUP_TO_MULTIPLE(info.blocks_X, 4); |
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); | ||
| // CW scales are stored [blocks_X, align4(total_row_blocks)] -- transposed to | ||
| // match the physically-transposed columnwise data the TN cuBLAS GEMM consumes. | ||
| const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4); |
There was a problem hiding this comment.
| const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4); | |
| const size_t scale_t_stride_y = DIVUP_TO_MULTIPLE(info.total_row_blocks, 4); |
| const float* noop_ptr = | ||
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); |
There was a problem hiding this comment.
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); | |
| const size_t scale_stride_aligned_R = DIVUP_TO_MULTIPLE(info.R_total, 4); |
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); | ||
| const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4); |
There was a problem hiding this comment.
| const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4); | |
| const size_t scale_t_stride_aligned_K = DIVUP_TO_MULTIPLE(info.K, 4); |
- Reuse shared helpers (DIVUP, DIVUP_TO_MULTIPLE, TMA_GMEM_ALIGNMENT, align_smem_ptr_per_TMA_requirements, get_current_tensor_id, subwarp_reduce_max_broadcast) in place of local equivalents. - Add proxy-async fence after mbarrier_init in 2D + 1D TMA kernels. - Enforce per-tensor first_dim % 128 device-side for VARYING_FIRST_DIM (matches MXFP8 grouped quantize behavior). - Fix Hopper SM range wording in 1D dispatcher. - Extend cpp tests to cover with_gemm_swizzled_scales path. Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
| // num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4) | ||
| __device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j, | ||
| size_t total_row_blocks) { | ||
| using namespace transformer_engine::dispatch::mxfp8::swizzle; |
There was a problem hiding this comment.
I think we should rename the namespace for swizzle...given that we use the same constants for mxfp8, nvfp4, fp8 block scaling
The swizzle helpers are shared across MXFP8, NVFP4, and FP8 block scaling. Relocate swizzle.cuh from cast/mxfp8/ to cast/ and drop the mxfp8:: namespace layer so callers don't reach across precisions. Signed-off-by: Alp Dener <adener@nvidia.com>
create_grouped_tensor hardcoded with_gemm_swizzled_scales=false, so the swizzled-scale kernel branch was unreachable from the PyTorch grouped quantize API even when the quantizer requested optimize_for_gemm. Plumb optimize_for_gemm into both the C++ wrapper and the Python GroupedTensor kwarg, and size the colwise scale buffer for the swizzled layout via a new with_gemm_swizzled_scales parameter on get_scale_shape. Add a parameterized PyTorch test that asserts the flag propagation end-to-end so a future regression of this kind fails loudly. Also fixes two swizzle.cuh includes missed in the prior namespace-rename commit (mxfp8/dequantize_mxfp8.cuh + mxfp8/quantize_mxfp8.cuh). Signed-off-by: Alp Dener <adener@nvidia.com>
…EMM swizzle cuBLAS FP8 block-scaling grouped GEMM reads each expert's scales from a contiguous compact per-expert sub-block, so the kernels write that layout directly and size the scale buffers from the logical totals — no per-tensor device->host copy, which keeps allocation CUDA-graph-safe. The MXFP8-style GEMM-swizzled scale layout is never consumed by FP8 block scaling, so its kernel template, per-expert index helper, and dispatch are removed. (The internal bank-conflict shared-memory swizzle is unrelated and kept.) Enable FP8 block scaling on the GroupedLinear CUDA-graph-safe path, Hopper-only (cuBLAS rejects it on Blackwell, where MXFP8 is preferred). Signed-off-by: Alp Dener <adener@nvidia.com>
Restore tests/pytorch/test_grouped_linear.py and transformer_engine/pytorch/module/grouped_linear.py to their main-branch state. The FP8 block-scaling GroupedTensor + Quantizer plumbing stays in this branch; the module-level wiring will be revisited in a follow-up PR alongside the cuBLAS H100 grouped-WGRAD miscompute investigation. Signed-off-by: Alp Dener <adener@nvidia.com>
Adds three TEST_P variants (CudaGraphCapture, CudaGraphCaptureDiscreteOut, CudaGraphCaptureDiscreteIn) that capture each of the three grouped GEMM API entry points on a non-blocking stream, instantiate a CUDA graph, and replay twice with per-replay output verification. Asserts the pure C++ implementation is graph-safe across all existing parametrizations (recipes, layouts, shape cases, NULL-C, output dtypes). Signed-off-by: Alp Dener <adener@nvidia.com>
Extend the GroupedTensor FP8 block-scaling path to match MXFP8's grouped quantize/dequantize/bgrad coverage on Hopper. Dequantize: * New group_dequantize_fp8_blockwise.cuh with kernels for all four modes (1D/2D x rowwise/columnwise), inverting the per-expert layouts that the grouped quantize kernels write. Wired into the group_dequantize dispatch. * group_dequantize (PyTorch) derives the scale_inv dtype from the scaling mode (FP8BS -> Float32) instead of hardcoding E8M0. Bias gradient: * bgrad_group_quantize now accepts Float8Block quantizers. dbias is computed in-kernel as a per-tile column partial (mirroring MXFP8) and reduced per expert via the existing common::grouped_reduce_dbias; no separate reduction kernel. The TMA quantize kernels compute the partial from the smem-resident tile; RW-only-with-dbias is routed through the TMA kernel. Scale constraints: * The fused grouped FP8BS path supports only unconstrained FP32 scales. force_pow_2_scales=True is rejected in Float8BlockQuantizer::create_grouped_tensor, and the pow_2_scaling code path is stripped from the grouped quantize kernels. Power-of-2 scales remain available on the non-grouped / unfused split-quantize path (used for the Blackwell MXFP8 emulation), which is unchanged. Tests: * Consolidate test_grouped_tensor.py so MXFP8, NVFP4, and FP8 block scaling share parametrized quantize / dequantize / bgrad tests across tensor-usage directions, block-scaling dims, and output dtypes. Signed-off-by: Alp Dener <adener@nvidia.com>
Description
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization directions. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata with (N, R, K) shapes. Supports
SAME_BOTH_DIMS(all tensors identical) andVARYING_FIRST_DIM(constant K, varying R) shape representations.Three kernels share the dispatcher in
group_quantize_blockwise_{1d,2d}:group_block_scaled_1d_rw_kernel— RW-only dispatch; 8 threads/row, reads global memory directly into vec-16 registers; bypasses TMA because the shared memory roundtrip andptx::mbarrierdoes not buy anything without re-use in CW path.group_block_scaled_1d_tma_kernel— CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. BOTH runs RW pass first (8 threads/row, vec-16 read from shared memory) then CW pass (2 threads/column, 64-row register stage); CW-only skips the RW pass. CW path writes the transposed-FP8 tile to a shared memory transpose staging buffer, then drains to global memory.group_block_scaled_2d_tma_kernel— RW-only, CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits row-wise output, stages column-wise output to shared memory transpose staging buffer, then drains to global memory.Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt grouped GEMM supports FP8 block-scaling only on Hopper).
PR includes PyTorch integration.
JAX integration is intentionally left out-of-scope and deferred to a follow-up PR because it requires non-trivial new scaffolding on the framework side.
Resolves #2525
Performance
Table below measures performance on H200 with a sweep of grouped tensors in (N, M, K) shapes with:
The shapes are split into two buckets:
Reported kernel times and throughput ratios are bucket medians.
Speedup is measured relative to the split-quantized fallback that loops over the grouped tensor and sequentially quantizes each one.
% of "mono" throughput is measured relative to the throughput of a single non-grouped FP8 block-scaling quantize kernel invoked with the equivalent monolithic (NxM, K) tensor where the # of experts are collapsed with # of tokens/expert.
Notes
Known Sub-Optimalities
1D CW has bank conflicts on ~35% of load wavefronts (reading from the shared memory input-cache)
CU_TENSOR_MAP_SWIZZLE_128Bhas the right pattern but caps FP16/BF16 at 64-elements; does not fit the 128-element tile for FP8 block-scaling without doubling per-tile launch overhead (quadrupling for FP32).1D BOTH reads the shared memory input-cache twice
2D CW/BOTH has bank conflicts on ~16% of store wavefronts (when writing to the shared memory transpose buffer)
No TMA-store
Type of change
Checklist: