Releases: NVIDIA/TransformerEngine
v2.16.1
Transformer Engine v2.16.1 Patch Release Notes
Fixed Issues
- [PyTorch] Fixed an issue where the "extra state" of a checkpoint was unpickled even when the quantization recipe is stateless. To enhance security, loading the "extra state" portion of checkpoints for stateful recipes now requires an explicit opt-in via an environment variable. (#3123)
v2.16
Transformer Engine v2.16 Release Notes
Key Features and Enhancements
- [Common] Improved the performance of the split-overlap reduce-scatter GEMMs. (#2056)
- [Common] Improved the fused MoE auxiliary loss kernel performance for models with a large number of experts. (#2758)
- [Common] Optimized MXFP8 and NVFP4 dequantize kernels for improved performance. (#2865)
- [Common] Improved performance of the MXFP8 quantization kernels. (#2958)
- [PyTorch] Added pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) attention. (#2596)
- [PyTorch] Added role-based custom quantization control, enabling recipes to target specific modules and tensor types. (#2620)
- [PyTorch] Added end-to-end Mixtral MoE examples showing TE GroupedLinear integration with HuggingFace models for BF16 and FP8 training. (#2642)
- [PyTorch] Improved the performance of the CPU activation offloading path in some cases. (#2793)
- [PyTorch] Reduced CPU overhead in the GroupedLinear module and operation. (#2900) (#2957) (#2666)
- [PyTorch] Added CUDA Graph capture support for GroupedLinear and grouped MoE operations on supported configurations. (#2923)
- [PyTorch] Added FlashAttention 4 support for attention head dimension 256. (#2932)
- [JAX] Improved MoE permutation kernel performance. (#2975)
- [JAX] Improved JAX tutorial documentation with updated examples and guidance. (#2976)
- [Common, PyTorch] Added bias and dbias support for GroupedLinear layers. (#2885)
- [Common, PyTorch] Added variable grouped swizzle support for flexible grouped tensor memory layouts. (#2914)
- [Common, PyTorch] Implemented a row-scaled NVFP4 forward propagation recipe. (#2931)
- [Common, PyTorch] Expanded grouped GEMM support with NVFP4 on Blackwell and FP8 block scaling on Hopper. (#2971)
- [Common, JAX] Added a top-k operation for faster MoE routing. (#2890)
- [Common, JAX] Enabled the cuDNN fused attention backend for no-mask bidirectional sliding-window attention. (#2961)
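Conceptually, the top-k operation added in #2890 selects the k experts with the highest router scores for each token. The pure-Python sketch below illustrates only that selection step. It is not the TE/JAX kernel, and the function name topk_route is hypothetical.

```python
import heapq
from typing import List, Sequence, Tuple


def topk_route(
    scores: Sequence[Sequence[float]], k: int
) -> List[List[Tuple[int, float]]]:
    """For each token, return the k (expert_index, score) pairs with the highest scores."""
    if k <= 0:
        raise ValueError("k must be positive")
    routes = []
    for token_scores in scores:
        # nlargest behaves like sorted(..., reverse=True)[:k], which is stable,
        # so ties are resolved in favor of the lower expert index.
        top = heapq.nlargest(k, enumerate(token_scores), key=lambda pair: pair[1])
        routes.append(top)
    return routes


# Two tokens, three experts, route each token to its top-2 experts.
print(topk_route([[0.1, 0.7, 0.2], [0.5, 0.1, 0.4]], k=2))
# [[(1, 0.7), (2, 0.2)], [(0, 0.5), (2, 0.4)]]
```

Real MoE routers typically add score normalization and capacity handling on top of this selection. The sketch omits both.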
Fixed Issues
- [PyTorch] Fixed variable-length attention cache reuse across devices and inference/training modes. (#2728)
- [PyTorch] Fixed FSDP2 memory leaks for FP8 weight workspaces and transpose caches. (#2805)
- [PyTorch] Fixed TE fuser behavior in torch.no_grad() paths by avoiding invalid gradient-flag updates on non-leaf tensors. (#2919)
- [PyTorch] Fixed distributed checkpoint loading for FSDP2 for models initialized with QuantizedModelInit. (#2974)
- [Common, PyTorch] Fixed cuBLAS grouped GEMM when weight dimensions are not divisible by 128. (#2954)
- [Common, PyTorch] Fixed int32 overflow and -1 sentinel value handling in moe_permute. (#2907)
- [Common, PyTorch] Fixed context-parallel FlashAttention output handling when FA3 is installed without FA2. (#2825)
- [Common, PyTorch] Disabled RHT quantization fusion on unsupported GPU architectures to avoid launch failures. (#2968)
- [PyTorch] Fixed a crash coming from GroupedLinear weight-gradient allocation. (#3049)
Breaking Changes in This Release
- [Common, PyTorch] The original FP8 delayed-scaling fused attention path has been removed. FP8 attention now uses the current cuDNN-backed implementation. (#2959)
- [Common, PyTorch, JAX] Removed the legacy f16_max512 fused-attention backend. BF16/FP16 attention is routed through the maintained arbitrary-sequence backend, but explicit selections of the old backend must be updated. (#2949)
Deprecated Features
There are no deprecated features in this release.
v2.15
Transformer Engine v2.15 Release Notes
Key Features and Enhancements
- [PyTorch] Added support for Flash Attention 4. (#2432)
- [PyTorch] Added support for MXFP8 attention. (#2719)
- [PyTorch] Added support for QGeGLU activation both in te.ops and in the fused grouped MLP path using GEMM + activation fusion. (#2855)
- [PyTorch] Added support for per-token bias probability scaling both in te.ops and in the fused grouped MLP path using GEMM + activation fusion. (#2864)
- [PyTorch] Added support for NVFP4 weight quantization in the fused Adam optimizer. (#2797)
- [PyTorch, Common] Added Triton kernels to support mHC (Manifold-Constrained Hyper-Connections). (#2790)
- [PyTorch, Common] Added support for dequantizing MXFP8 grouped tensors. (#2722)
- [Common] Added support for unswizzling scaling factors. (#2837, #2732)
- [PyTorch] Added Newton–Schulz orthogonalization via cuSOLVERMp for distributed orthogonalization workloads. (#2706)
- [PyTorch] Added an NVTE_BACKWARD_OVERRIDE=high_precision|dequantized environment variable to control backward-pass precision behavior. (#2644)
- [PyTorch] Added a debug-tools feature that dumps tensors before and after quantization for numerical debugging. (#2645)
- [PyTorch] Optimized FP8 block-scaling AllGather for FSDP2 to reduce communication overhead. (#2789)
- [PyTorch] Added an example demonstrating high-precision weight initialization with fully_shard. (#2785)
- [PyTorch] Expanded fused grouped MLP support via te.ops by relaxing the weight dimension requirement from divisibility by 256 to divisibility by 64. (#2856)
- [PyTorch] Added torch.compile support for the MoE permute utility functions. (#2686)
- [Common, PyTorch] Improved the performance of NVFP4 quantization by refactoring the amax compute kernel. (#2820)
- [JAX] Reduced the memory of THD seqlen and offset computation from O(T·T) to O(T) for long sequences. (#2522)
- [JAX] Added MXFP8 grouped quantize + GEMM support. (#2763)
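The sketch below shows one way to compute THD sequence lengths and offsets in linear memory. It derives per-sequence lengths and cumulative offsets from a segment-ID array in a single pass, with no T x T comparison matrix. This is an illustrative assumption, not the TE/JAX implementation from #2522. It also assumes that nonzero segment IDs mark sequences, that 0 marks padding, and that padding only follows the last sequence. The function name is hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def thd_seqlens_and_offsets(
    segment_ids: Sequence[int], pad_id: int = 0
) -> Tuple[List[int], List[int]]:
    """Per-sequence lengths and cumulative start offsets for a packed (THD) batch.

    One pass over the T tokens; extra memory grows with the number of
    sequences (at most T), never with T * T.
    """
    seqlens: List[int] = []
    prev: Optional[int] = None
    for seg in segment_ids:
        if seg == pad_id:
            prev = None  # padding ends the current sequence and is not counted
            continue
        if seg == prev:
            seqlens[-1] += 1  # same segment as the previous token: extend it
        else:
            seqlens.append(1)  # a new sequence starts at this token
        prev = seg
    # offsets[i] = number of non-padding tokens before sequence i (cu_seqlens style);
    # this matches physical positions only if padding sits after the last sequence.
    offsets = [0]
    for n in seqlens:
        offsets.append(offsets[-1] + n)
    return seqlens, offsets


print(thd_seqlens_and_offsets([1, 1, 1, 2, 2, 3, 0, 0]))  # ([3, 2, 1], [0, 3, 5, 6])
```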
Fixed Issues
- [PyTorch] Fixed a numerical bug where stale columnwise weight data would be used for post-validation training steps. (#2929)
- [PyTorch] Fixed redundant memory usage when using NVFP4 parameters. (#2834)
- [JAX] Fixed the JAX extension build with NVTE_UB_WITH_MPI=1. (#2835)
- [Common] Fixed a numerical bug in the MoE fused router with large top-K values and expert counts. (#2821)
- [Common] Fixed an illegal memory access in register_user_buffer_collective on Ampere (and older) GPUs when using user buffers for COMM-GEMM overlap. (#2859)
- [Build] Fixed a build crash when compiling from source with NVTE_CUDA_ARCHS=120. (#2832)
Known Issues
- [PyTorch] When building a grouped MLP module via te.ops.Sequential in order to use the GEMM + activation fusion, the kernel may produce non-deterministic results in the single grouped-weight case (i.e., when the environment variable NVTE_GROUPED_LINEAR_SINGLE_PARAM and the corresponding module argument single_grouped_weight are set).
- [PyTorch] Enabling fused grouped MLP via te.ops requires cudnn-frontend version 1.23.0. If you encounter issues, ensure that the correct version of CuTeDSL is installed:
python -m pip uninstall -y \
cutlass \
nvidia-cutlass \
nvidia-cutlass-dsl \
nvidia-cutlass-dsl-libs-base \
nvidia-cutlass-dsl-libs-cu13 \
nvidia-cudnn-frontend
python -m pip install -U pip setuptools wheel
python -m pip install --no-cache-dir "nvidia-cutlass-dsl[cu13]==4.4.1"
python -m pip install --no-cache-dir "nvidia-cudnn-frontend[cutedsl]==1.23.0"
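Before or after running the reinstall commands above, it may help to confirm which versions are actually installed. The sketch below uses only the standard library's importlib.metadata. It assumes the exact pins from the commands above are what matter, and the helper names are hypothetical.

```python
from importlib import metadata
from typing import Callable, Dict, Optional

# Exact pins taken from the pip commands in this known-issues entry.
REQUIRED = {
    "nvidia-cudnn-frontend": "1.23.0",
    "nvidia-cutlass-dsl": "4.4.1",
}


def _installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_fused_grouped_mlp_deps(
    lookup: Callable[[str], Optional[str]] = _installed_version,
) -> Dict[str, Optional[str]]:
    """Return {distribution: found_version} for every pin that is not satisfied."""
    mismatches: Dict[str, Optional[str]] = {}
    for dist, wanted in REQUIRED.items():
        found = lookup(dist)
        if found != wanted:
            mismatches[dist] = found
    return mismatches


if __name__ == "__main__":
    problems = check_fused_grouped_mlp_deps()
    for dist, found in problems.items():
        print(f"{dist}: expected {REQUIRED[dist]}, found {found or 'not installed'}")
    if not problems:
        print("fused grouped MLP dependency pins look OK")
```

The lookup parameter exists only so the check can be exercised without the packages installed.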
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
There are no deprecated features in this release.
v2.14.1
v2.14
Transformer Engine v2.14 Release Notes
Key Features and Enhancements
- [PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. (#2559) (#2724)
- [C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. (#2748) (#2669)
- [PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward/backward. (#2769)
- [PyTorch] Added support for a single-parameter GroupedLinear configuration, in which the weights of all experts are stored in a single parameter, reducing CPU overhead. (#2731)
- [PyTorch] Added backwards-compatible checkpoint support for the new single-parameter GroupedLinear. (#2761)
- [PyTorch] Extended the fused attention API so it can return the softmax Stats (always) and Max (when return_max_logit=True), exposing more cuDNN intermediates to users. (#2677)
- [PyTorch] Enabled SM120 support for the fused attention path when cuDNN >= 9.18.1 is available. (#2693)
- [PyTorch] Added support for MXFP8BlockScaling and Float8BlockScaling quantized weights in FusedAdam. (#2753)
- [PyTorch] Added a CUDA Graph-compatible multi_tensor_scale_tensor API to the optimizer. (#2594)
- [PyTorch] Enabled CUDA Graph capture of modules with CPU offloading. (#2435)
- [PyTorch] Added support for non-FP32 params_dtype when using QK-normalization. (#2718)
- [PyTorch] Added precision debug-tools support for quantized model parameters. (#2141)
- [JAX] Added a JAX-side API to invoke the fused MoE router kernels. (#2711)
- [JAX] Integrated BF16 grouped GEMM with on-device group sizes. (#2680)
- [JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. (#2740)
- [JAX] Added Shardy support to the Collective GEMM (CGEMM) path. (#2714)
- [JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. (#2741)
- [C] Enabled the fused RMSNorm dLN + add backward path through cuDNN for faster fused-residual normalization. (#2778)
- [C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. (#2738) (#2674)
- [C] Enabled dequantization from an MXFP8 tensor that only carries column-wise data. (#2712)
- [C/PyTorch] Improved the performance of the NVFP4 recipe by fusing row-cast / RHT / transpose / column-cast. (#2555)
- [C] Made the number of Philox rounds for stochastic rounding configurable. (#2751)
- [Documentation] Added a documentation page describing CPU offloading in Transformer Engine. (#2520)
- [Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. (#2624)
- [Documentation] Improved error messages across the C, PyTorch, and JAX layers. (#2705)
- [Documentation] Added a custom-feature tutorial for the precision debug tools. (#2216)
- [Documentation] Added documentation for the operator fuser API. (#2447)
- [PyTorch, Documentation] Added end-to-end examples for fused_adam, quantized_model_init, and FSDP2 usage. (#2698) (#2662)
Fixed Issues
- [PyTorch] FSDP2 / Megatron-FSDP / DCP (distributed checkpointing): when model parameters are DTensors, optimizer states are now also DTensors, so sharded checkpoints are correct. (#2795)
- [PyTorch] Fixed async DCP checkpointing for Float8Tensor parameters. (#2721)
- [PyTorch] Fixed cross_entropy_forward producing wrong results for non-contiguous logits. (#2746)
- [PyTorch] Fixed excessive memory usage when using the operator fuser. (#2750)
- [PyTorch] Fixed a precision-debug-tools crash when tp_group=None. (#2733)
- [PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. (#2704)
- [PyTorch] Fixed the learnable softmax_offset parameter in DotProductAttention to be zero-initialized. (#2694)
- [PyTorch] Fixed an error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. (#2637)
- [PyTorch] Added a clear error when constructing LayerNormLinear with row-wise tensor parallelism (an unsupported configuration). Previously, this configuration failed with a CUDA error. (#2688)
- [JAX] Fixed a performance issue in THD/BSHD segment-position generation. (#2823)
- [JAX] Fixed an assertion error when using from_segment_ids_and_pos() with vmap. (#2692)
- [JAX] Fixed a performance issue for models using both FSDP and EP. (#2649)
- [JAX] Changed the dtype of the intermediate-result aval in fused_topk_and_score_function_fwd to fp32 to avoid precision loss. (#2752)
- [C] Fixed an incorrect MNNVL fabric-availability check that misreported support on some systems. (#2626)
- [C/PyTorch] Fixed score normalization in fused_score_for_moe_aux_loss when topk == 1. (#2720)
- [PyTorch] Fixed possible precision loss when copying from a quantized tensor to a high-precision tensor. (#2120, #2673)
Breaking Changes in This Release
- [JAX] GSPMD partitioning rules are no longer tested and will now warn on use; users on JAX with GSPMD should migrate to Shardy. (#2702)
Deprecated Features
There are no deprecated features in this release.
v2.13
Transformer Engine v2.13 Release Notes
Key Features and Enhancements
- Added detailed documentation for low-precision training with Transformer Engine, covering FP8, MXFP8, NVFP4, and other quantization recipes, with examples for both PyTorch and JAX. (#2343)
- [Build] Added the NVTE_BUILD_USE_NVIDIA_WHEELS environment variable to allow building TE using CUDA headers from PyPI NVIDIA wheels instead of a system CUDA installation. (#2623)
- [C] Enabled deterministic FP8 fused attention on Blackwell (SM100) GPUs. (#2621)
- [C] Updated cuBLASMp integration to version 0.8.0, replacing the nvshmem dependency with NCCL-based symmetric memory. (#2661)
- [C] Added MXFP8 quantization kernels for grouped tensors used in MoE, with fused scale-factor swizzling for improved performance. (#2586, #2630)
- [C] Added NVFP4 quantization kernels for grouped tensors used in MoE models. (#2655)
- [C] Reduced cuDNN graph recompilations in THD fused attention by rounding large batch sizes to 512-element increments. (#2653)
- [C] Added a sqrtsoftplus scoring function to the fused MoE router and improved router kernel performance on Blackwell GPUs. (#2633, #2683)
- [PyTorch] Introduced GroupedTensor, enabling MoE expert weights to be stored as a single contiguous allocation while remaining individually addressable. (#2654)
- [PyTorch] Added fusible GroupedLinear and ScaledSwiGLU ops for building fully fused MoE grouped MLP pipelines. (#2664)
- [PyTorch] Added register_forward_fusion and register_backward_fusion APIs, allowing users to define and register custom operator fusion patterns. (#2597)
- [PyTorch] Added the get_backward_dw_params API to TE modules, fixing weight gradient hook management when using wgrad CUDA Graphs with Megatron-LM. (#2614)
- [PyTorch] Fixed fused attention bias dimension handling and extended dbias support to additional bias shapes (b1ss, bhss, 11ss, 111s). (#2537)
- [PyTorch] Reduced peak memory usage in the fused Adam optimizer by fusing BF16 momentum scaling directly into CUDA kernels, which also enables CUDA Graph capture for this path. (#2632)
- [PyTorch] Added the sigmoid-gated GLU activation (activation="glu") to LayerNormMLP and TransformerLayer. (#2656)
- [PyTorch] Extended debug statistics tracking to NVFP4 quantization (underflow and MSE metrics), and stat logging is now gracefully skipped for layers not using quantization. (#2296, #2652)
- [PyTorch] Fixed CUDA Graph capture for Megatron-Core vision encoder models. (#2657)
- [JAX] Added an experimental inspect_array debugging utility for dumping tensor snapshots during multi-GPU execution. (#2651)
- [JAX] Fixed MoE permutation to correctly mask padding tokens and handle tensor sizes under expert parallelism. (#2672)
- [JAX] MoE permutation now always returns tokens_per_expert, which is required for ragged all-to-all communication in expert parallelism. (#2613)
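The recompilation reduction from #2653 (rounding large batch sizes to 512-element increments) can be pictured as bucketing. Many nearby batch sizes collapse onto one value, so a graph cache keyed on that value builds fewer graphs. The sketch below is an illustration under stated assumptions. It rounds up rather than to the nearest multiple, and it applies the rounding unconditionally because the note does not state the "large batch" threshold. The function name is hypothetical.

```python
def bucket_batch_size(batch_size: int, increment: int = 512) -> int:
    """Round batch_size up to the next multiple of `increment` (illustrative only)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return -(-batch_size // increment) * increment  # integer ceiling division


# All 512 batch sizes from 513 to 1024 share a single bucket (1024), so a cache
# keyed on the bucketed size would build one graph instead of 512.
print(len({bucket_batch_size(b) for b in range(513, 1025)}))  # 1
```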
Fixed Issues
- [C] Fixed incorrect results from the exp2f_rcp fast-math helper when inputs are NaN or have biased exponent 254. (#2647)
- [C] Fixed a race condition in Randomized Hadamard Transform amax kernels where a missing memory fence could cause incorrect amax values. (#2695)
- [PyTorch] Fixed the TE Llama example to work with HuggingFace Transformers 4.57+, which changed decoder layer output conventions. (#2572)
- [Build] Fixed a TypeError during build when NCCL is installed from PyPI as a namespace package without a __file__ attribute. (#2580)
- [Build] Fixed a ModuleNotFoundError when installing from cached source distributions (e.g., via uv) by including build_tools in MANIFEST.in. (#2684)
Breaking Changes in This Release
- [C] Removed the deprecated packed fused attention C APIs (nvte_fused_attn_{fwd,bwd}_{qkvpacked,kvpacked}); users must migrate to the non-packed API variants. (#2696)
- Versions of cuBLASMp prior to 0.8.0 are no longer supported.
Deprecated Features
No features deprecated in this release.
v2.12
Transformer Engine v2.12 Release Notes
Key Features and Enhancements
- Made miscellaneous improvements and fixes to the documentation.
- [C] Improved performance of NVFP4 quantization kernels. (#2412)
- [C] Documented environment variables. (#2552)
- [PyTorch] Added fused permute+pad and unpermute+unpad operations for FP8 optimization. (#1921)
- [PyTorch] Improved the performance in CPU-limited scenarios.
- [PyTorch] Added support for Sliding Window Attention (left, right) with fused attention. (#2477)
- [PyTorch] Improved the performance of MXFP8 and NVFP4 by fusing the swizzling into the quantization. (#2486)
- [PyTorch] Added CUDA Graph support for activation recomputation. (#2518)
- [JAX] Added a tutorial for integrating TE/JAX quantization into existing frameworks. (#2423)
- [JAX] Added custom partitioning for permutation primitives. (#2591)
Fixed Issues
- [C] Fixed SM120 compilation with CUDA 12. (#2482)
- [C] Fixed overflow in padding and unpadding kernels. (#2548)
- [C] Fixed a numerical issue in sort_chunks_by_index. (#2566)
- [C] Fixed a numerical issue in swizzling blockwise E8 scales. (#2589)
- [PyTorch] Fixed an AttributeError when checkpointing a model with MXFP8 parameters. (#2427)
- [PyTorch] Fixed cross-entropy loss calculation when some tokens are ignored. (#2476)
- [PyTorch] Fixed Float8Tensor.contiguous autograd support. (#2533)
- [PyTorch] Fixed multiple CPU offloading issues. (#2535)
- [PyTorch] Fixed uninitialized permuted_scale values. (#2547)
- [PyTorch] Fixed FP8 quantization for the second MLP in LayerNormMLP. (#2577)
- [PyTorch] Fixed ONNX tests and added FP8 attention export support. (#2598)
- [JAX] Removed unused TE DPA dtype handling to improve cuDNN backend dtype detection. (#2485)
- [JAX] Fixed segment-position calculation from segment IDs in the SequenceDescriptor class. (#2523)
- [JAX] Fixed bugs in permutation custom partitioning. (#2617)
- [JAX] Fixed the encoder and MNIST examples, which broke after the dataset path moved. (#2625)
Breaking Changes in This Release
No breaking changes in this release.
Deprecated Features
No features deprecated in this release.
v2.11
Transformer Engine v2.11 Release Notes
Key Features and Enhancements
- [PyTorch] Enabled the reference Current Scaling recipe for FP8 training. (#2368)
- [PyTorch] Improved Random Hadamard Transform (RHT) device tensor caching to reduce memory allocations and improve performance for NVFP4 quantization. (#2395)
- [PyTorch] Implemented selective activation checkpointing for the LayerNormMLP module. (#2311)
- [C, PyTorch, JAX] Improved performance of MXFP8 quantization. (#2062)
- [C, PyTorch] Improved performance of NVFP4 quantization. (#2351)
- [PyTorch] Improved FSDP2 all-gather performance and added support for FusedAdam optimizer with FSDP2. (#2370)
- [PyTorch] Extended debug tools to support GroupedLinear layers. (#1953)
- [JAX] Added Triton kernel bindings for JAX, enabling custom Triton kernels in JAX workflows. (#2437)
- [C] Introduced experimental NVTEGroupedTensor class and helper functions. (#2388)
- [C, PyTorch, JAX] Added FP8 support for primary weights in MXFP8 format with partial casting and amax calculations. (#2055)
- [JAX] Added support for context parallelism (CP) with the THD format and sliding window attention (SWA) using all-gather (AG), with striped load balancing for stripe sizes greater than 1. (#2379)
- [JAX] Implemented JAX primitives for token permutation operations on single GPU for mixture-of-experts routing. (#2473)
- [PyTorch] Added THD format support for max_logit clipping and MuonClip gradient clipping operations. (#2480)
Fixed Issues
- [PyTorch] Fixed a numerical issue when a non-contiguous tensor was passed to the cross_entropy backward pass. (#2402)
- [PyTorch] Fixed CUDA graph execution order for backward weight gradient computation when using chunked layers. (#2376)
- [C] Fixed runtime library loading logic to properly handle missing dependencies and load order. (#2297)
- [JAX] Ring attention no longer uses a scan loop by default, which improves performance. (#2503)
Breaking Changes in This Release
No breaking changes in this release.
Deprecated Features
No features deprecated in this release.
v2.10
Release v2.10
Key Features and Enhancements
- [PyTorch] Added support for the NVFP4 training recipe for the GroupedLinear module.
- [PyTorch] Added support for CUDA graphs when using quantized weights with Tensor Parallelism.
- [PyTorch] Added support for CUDA graphs when using delay_wgrad_compute.
- [PyTorch] Expanded debug tools to support more statistics.
- [PyTorch] Reduced the overhead of using debug tools.
- [PyTorch] Added support for clamped SwiGLU in the TransformerLayer module.
- [PyTorch] Added backwards compatibility for older Megatron-Core versions by introducing a keep_columnwise parameter to cast_master_weights_to_fp8 and related helper functions.
- [PyTorch] Added a reset interface to make_graphed_callables that clears internal CUDA graphs before distributed process group cleanup, preventing hangs.
- [PyTorch] Added support for FSDP2 with quantized weights.
- [PyTorch] Added support for Sliding Window Attention (SWA) with Context Parallelism with THD input format.
- [PyTorch] Integrated Flash Attention's num_splits parameter into the attention backend.
- [PyTorch] Made various improvements to mitigate CPU overhead, especially for the GroupedLinear module.
- [C][PyTorch] Enabled RoPE (Rotary Position Embedding) application with position offsets during training, removing the previous restriction that start_positions could only be used with cp_size=1 (context parallelism disabled).
- [Jax] Added options to disable Stochastic Rounding, Randomized Hadamard Transform, and 2D weight quantization in the NVFP4 training recipe.
- [Jax] Improved performance by using Transformer Engine quantization when fused normalization or fused activation are disabled.
- [Jax] Improved NVFP4 performance by using TE kernels for scaling-factor swizzles.
- [Jax] Added support for checkpointing quantization operations in JAX.
- [Jax] Added support for sink attention.
- [Jax] Added support for concurrent use of Data Parallelism (DP) and Fully-Sharded Data Parallelism (FSDP).
Fixed Issues
- Fixed an occasional crash when loading cuDNN library during runtime.
- [C] Fixed an out of bounds access in the NVFP4 dequantization kernel.
- [C] Fixed a numerical error in the amax computation in normalization kernels.
- [PyTorch] Fixed a crash in the permute kernel when using Triton v3.5.
- [PyTorch] Fixed a numerical issue when using gradient accumulation fusion with FSDP.
- [PyTorch] Fixed a crash when exporting modules that use RMSNorm to ONNX.
- [Jax] Fixed a partitioning issue for the NVFP4 training recipe with 1D Mesh.
- [Jax] Fixed a bug where the bias parameter could be added twice when using unfused attention backend.
- [Jax] Fixed a sharding bug in ring attention primitives when using packed sequences where segment position tensors were not properly sharded to match their corresponding segment ID tensors.
- [PyTorch][Jax] Fixed various logical issues in the backend selection process for attention.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
- [Jax] Default value for intermediate_dropout changed from 0.1 to 0.0.
- [Jax] Default value for return_layernorm_output changed from True to False.
- [Jax] Default activation changed from ReLU to GeLU.
- [Jax] Default input type for DotProductAttention changed to BSHD.
Deprecated Features
No features are deprecated in this release.
v2.9
Release v2.9
Key Features and Enhancements
- [PyTorch][Jax] Introduced recipe-agnostic functions and APIs in order to generalize to non-FP8 recipes. See the Deprecated Features section for a comprehensive list of affected APIs.
- [C][PyTorch][Jax] Added support for the clamped SwiGLU activation function.
- [C] Added support for precompiled wheels for CUDA 13 via PyPI.
- [PyTorch] Added support for custom training recipes in the autocast context. Transformer Engine quantizers, quantized tensor classes, and storage dataclasses are now part of the public API.
- [PyTorch] Added CPU offload support for all attention layouts.
- [PyTorch] Added support for the FP8 block scaling recipe (as used in the DeepSeek v3 Technical Report) on NVIDIA Blackwell architecture (SM100 family).
- [PyTorch] Added support for gradient accumulation fusion when using FSDP.
- [PyTorch] Added support for CPU offloading when using GroupedLinear with the distributed optimizer.
- [PyTorch] Exposed the following utility functions as public API: is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, is_nvfp4_available, is_bf16_available, get_cudnn_version, get_device_compute_capability, and get_default_recipe.
- [PyTorch] Added max_logit support for the MuonClip optimizer.
- [PyTorch][Jax] Improved the logic for selecting the attention backend, addressing various unsupported cases and preventing errors.
- [Jax] Added support for the NVFP4 training recipe.
- [Jax] Improved the performance of the current scaling recipes by enabling fused amax calculation in normalization and activation kernels.
- [Jax] Added support for bottom right causal mask for THD attention.
- Improved documentation and tutorials for the NVFP4 recipe.
Fixed Issues
- [Jax] Fixed a crash when using Context Parallelism with ring attention.
- [Jax] Fixed an issue with incorrect sharding when get_all_mesh_axes is used.
- [Jax] Fixed a numerical issue when using bias along with Tensor Parallelism.
- [PyTorch] Fixed an integer overflow issue in the triton permute kernel.
- [PyTorch] Fixed the known issue from release_v2.8 that resulted in worse performance for the FP8 current scaling recipe.
- Fixed a build issue when cuDNN is installed into a custom location or a Python virtual environment.
Known Issues in This Release
- [C][PyTorch] The cuDNN attention backend produces NaNs in the forward pass for cases using a non-causal mask with cuDNN 9.13 and cuDNN 9.14. As a workaround, please set the NVTE_FUSED_ATTN environment variable to 0 when using this configuration.
- [C][PyTorch] The backward pass of cuDNN attention is incompatible with CUDA graphs for BSHD inputs where the sequence (S) dimension is not divisible by 128 when used with a non-padding mask. As a workaround, please set the NVTE_FUSED_ATTN environment variable to 0 when using this configuration.
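Both workarounds come down to setting NVTE_FUSED_ATTN=0 for specific configurations. The sketch below encodes the two conditions as stated in these known issues and sets the variable from Python. It assumes you already know the cuDNN version as a (major, minor, patch) tuple; the public get_cudnn_version utility can provide it, but its exact return format is not assumed here. The helper name and its flags are hypothetical. To be safe, set the variable before importing transformer_engine and building the model.

```python
import os
from typing import Tuple


def needs_fused_attn_workaround(
    cudnn_version: Tuple[int, int, int],
    causal_mask: bool,
    uses_cuda_graphs: bool = False,
    bshd_input: bool = False,
    seq_len: int = 0,
    padding_mask: bool = False,
) -> bool:
    """True if a configuration matches one of the two v2.9 known issues."""
    # Issue 1: forward-pass NaNs with a non-causal mask on cuDNN 9.13 / 9.14.
    nan_issue = cudnn_version[:2] in {(9, 13), (9, 14)} and not causal_mask
    # Issue 2: backward pass vs. CUDA graphs for BSHD inputs whose sequence
    # length is not divisible by 128, when a non-padding mask is used.
    graph_issue = (
        uses_cuda_graphs and bshd_input and seq_len % 128 != 0 and not padding_mask
    )
    return nan_issue or graph_issue


if needs_fused_attn_workaround((9, 13, 0), causal_mask=False):
    # Disable the cuDNN fused attention backend, as the release notes suggest.
    os.environ["NVTE_FUSED_ATTN"] = "0"
```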
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
- [PyTorch] The function fp8_autocast is deprecated in favor of autocast. The new autocast function uses the arguments recipe and amax_reduction_group instead of fp8_recipe and fp8_group, respectively.
- [PyTorch] The function fp8_model_init is deprecated in favor of quantized_model_init.
- [PyTorch] The arguments fp8_enabled, fp8_calibrating, fp8_recipe, fp8_group, and fp8_weight_caching in the function make_graphed_callables are deprecated in favor of enabled, calibrating, recipe, amax_reduction_group, and cache_quantized_params, respectively.
- [Jax] The function fp8_autocast is deprecated in favor of autocast.
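The argument renames above are mechanical, so wrapper code that forwards keyword arguments could translate them in one place. The sketch below is a hypothetical migration helper built only from the rename lists in this section. It assumes arguments not listed here, such as enabled for autocast, keep their names. The commented call-site change assumes transformer_engine.pytorch is imported as te, and it is not executed here.

```python
import warnings
from typing import Any, Dict

# Renames listed in the v2.9 deprecation notes.
AUTOCAST_RENAMES = {"fp8_recipe": "recipe", "fp8_group": "amax_reduction_group"}
MAKE_GRAPHED_CALLABLES_RENAMES = {
    "fp8_enabled": "enabled",
    "fp8_calibrating": "calibrating",
    "fp8_recipe": "recipe",
    "fp8_group": "amax_reduction_group",
    "fp8_weight_caching": "cache_quantized_params",
}


def migrate_kwargs(kwargs: Dict[str, Any], renames: Dict[str, str]) -> Dict[str, Any]:
    """Return a copy of kwargs with deprecated argument names replaced by the new ones."""
    migrated: Dict[str, Any] = {}
    for name, value in kwargs.items():
        new_name = renames.get(name, name)
        if new_name in migrated:
            # e.g. both fp8_recipe= and recipe= were passed
            raise TypeError(
                f"argument {new_name!r} given more than once (possibly via a deprecated alias)"
            )
        if new_name != name:
            warnings.warn(
                f"{name!r} is deprecated; use {new_name!r} instead",
                DeprecationWarning,
                stacklevel=2,
            )
        migrated[new_name] = value
    return migrated


# Call-site change (requires Transformer Engine):
#   before: te.fp8_autocast(enabled=True, fp8_recipe=my_recipe, fp8_group=my_group)
#   after:  te.autocast(enabled=True, recipe=my_recipe, amax_reduction_group=my_group)
print(migrate_kwargs({"enabled": True, "fp8_recipe": "my_recipe"}, AUTOCAST_RENAMES))
# {'enabled': True, 'recipe': 'my_recipe'}
```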
Miscellaneous:
None