Releases: NVIDIA/TransformerEngine

v2.16.1

26 Jun 01:05

Transformer Engine v2.16.1 Patch Release Notes

Fixed Issues

  • [PyTorch] Fixed an issue where the "extra state" of a checkpoint was unpickled even when the quantization recipe is stateless. To improve security, loading the "extra state" portion of a checkpoint for stateful recipes now requires an explicit opt-in via an environment variable. (#3123)

v2.16

09 Jun 01:15

Transformer Engine v2.16 Release Notes

Key Features and Enhancements

  • [Common] Improved the performance of the split-overlap reduce-scatter GEMMs. (#2056)
  • [Common] Improved the fused MoE auxiliary loss kernel performance for models with a large number of experts. (#2758)
  • [Common] Optimized MXFP8 and NVFP4 dequantize kernels for improved performance. (#2865)
  • [Common] Improved performance of the MXFP8 quantization kernels. (#2958)
  • [PyTorch] Added pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) attention. (#2596)
  • [PyTorch] Added role-based custom quantization control, enabling recipes to target specific modules and tensor types. (#2620)
  • [PyTorch] Added end-to-end Mixtral MoE examples showing TE GroupedLinear integration with HuggingFace models for BF16 and FP8 training. (#2642)
  • [PyTorch] Improved the performance of the CPU activation offloading path in some cases. (#2793)
  • [PyTorch] Reduced the CPU overhead in the GroupedLinear module and operation. (#2900) (#2957) (#2666)
  • [PyTorch] Added CUDA Graph capture support for GroupedLinear and grouped MoE operations on supported configurations. (#2923)
  • [PyTorch] Added FlashAttention 4 support for attention head dimension 256. (#2932)
  • [JAX] Improved MoE permutation kernel performance. (#2975)
  • [JAX] Improved JAX tutorial documentation with updated examples and guidance. (#2976)
  • [Common, PyTorch] Added bias and dbias support for GroupedLinear layers. (#2885)
  • [Common, PyTorch] Added variable grouped swizzle support for flexible grouped tensor memory layouts. (#2914)
  • [Common, PyTorch] Implemented a row-scaled NVFP4 forward propagation recipe. (#2931)
  • [Common, PyTorch] Expanded grouped GEMM support with NVFP4 on Blackwell and FP8 block scaling on Hopper. (#2971)
  • [Common, JAX] Added a top-k operation for faster MoE routing. (#2890)
  • [Common, JAX] Enabled the cuDNN fused attention backend for no-mask bidirectional sliding-window attention. (#2961)
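
The MXFP8 entries above (#2865, #2958) concern kernels that apply one shared power-of-two scale per 32-element block. As a reading aid, the sketch below computes that block scale in plain Python using the rule from the OCP Microscaling (MX) v1.0 specification (floor(log2(amax)) minus the element format's maximum exponent, with saturation to the E4M3 range). It is not Transformer Engine's kernel, and TE's exact rounding of the scale may differ; the function names are hypothetical.

```python
import math

E4M3_MAX = 448.0     # largest finite FP8 E4M3 value: 1.75 * 2**8
E4M3_EMAX = 8        # exponent of that largest normal value
MX_BLOCK_SIZE = 32   # MX formats share one scale per 32 consecutive elements
E8M0_MIN, E8M0_MAX = -127, 127  # representable range of the E8M0 scale exponent


def block_scale_exponent(block):
    """Shared scale exponent for one block, per the OCP MX v1.0 rule:
    floor(log2(amax)) - emax of the element format."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return E8M0_MIN  # all-zero block: any scale works
    _, exp = math.frexp(amax)  # amax = m * 2**exp with 0.5 <= m < 1
    e = (exp - 1) - E4M3_EMAX  # exp - 1 == floor(log2(amax)), computed exactly
    return max(E8M0_MIN, min(E8M0_MAX, e))


def mxfp8_scale(values):
    """Scale a 1D list block by block; return (exponents, scaled values).

    Scaled values are saturated to +/-448 but not rounded to the E4M3 grid.
    """
    exponents, scaled = [], []
    for start in range(0, len(values), MX_BLOCK_SIZE):
        block = values[start:start + MX_BLOCK_SIZE]
        e = block_scale_exponent(block)
        exponents.append(e)
        for v in block:
            q = math.ldexp(v, -e)  # v / 2**e, exact barring overflow/underflow
            scaled.append(max(-E4M3_MAX, min(E4M3_MAX, q)))
    return exponents, scaled
```

Because the exponent is floored, a block's largest scaled value lands in [256, 512), so values above 448 saturate; for example, a block of 1.9s gets exponent -8 and 1.9 * 256 = 486.4 is clamped to 448.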

Fixed Issues

  • [PyTorch] Fixed variable-length attention cache reuse across devices and inference/training modes. (#2728)
  • [PyTorch] Fixed FSDP2 memory leaks for FP8 weight workspaces and transpose caches. (#2805)
  • [PyTorch] Fixed TE fuser behavior in torch.no_grad() paths by avoiding invalid gradient-flag updates on non-leaf tensors. (#2919)
  • [PyTorch] Fixed distributed checkpoint loading for FSDP2 for models initialized with QuantizedModelInit. (#2974)
  • [Common, PyTorch] Fixed cuBLAS grouped GEMM when weight dimensions are not divisible by 128. (#2954)
  • [Common, PyTorch] Fixed int32 overflow and -1 sentinel value handling in moe_permute. (#2907)
  • [Common, PyTorch] Fixed context-parallel FlashAttention output handling when FA3 is installed without FA2. (#2825)
  • [Common, PyTorch] Disabled RHT quantization fusion on unsupported GPU architectures to avoid launch failures. (#2968)
  • [PyTorch] Fixed a crash in GroupedLinear weight-gradient allocation. (#3049)
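
The moe_permute fix (#2907) names two classic hazards: a -1 sentinel that must never be used as an index, and element offsets that exceed int32. The plain-Python sketch below illustrates both under a hypothetical row_id_map layout (one destination row per routing slot, -1 for an unused slot); it does not reproduce TE's moe_permute signature or internal data layout.

```python
INT32_MAX = 2**31 - 1


def permute_tokens(tokens, row_id_map):
    """Gather tokens into expert-sorted rows (hypothetical layout).

    row_id_map[t][k] is the destination row for token t's k-th routing slot,
    or -1 if that slot is unused (for example, a dropped token).
    """
    num_rows = sum(1 for slots in row_id_map for dst in slots if dst != -1)
    out = [None] * num_rows
    for t, slots in enumerate(row_id_map):
        for dst in slots:
            if dst == -1:
                # Sentinel: skip it. Using -1 as an index would silently write
                # the last row in Python, or before the buffer in a CUDA kernel.
                continue
            out[dst] = tokens[t]
    return out


def flat_offset_fits_int32(num_rows, hidden_size):
    """True if the largest element offset, num_rows * hidden_size - 1, fits in int32."""
    return num_rows * hidden_size - 1 <= INT32_MAX
```

For example, 65536 tokens routed top-8 give 524288 permuted rows; with a hidden size of 8192 the last element offset is 4294967295, which no longer fits in int32, so such kernels need 64-bit offsets.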

Breaking Changes in This Release

  • [Common, PyTorch] The original FP8 delayed-scaling fused attention path has been removed. FP8 attention now uses the current cuDNN-backed implementation. (#2959)
  • [Common, PyTorch, JAX] Removed the legacy f16_max512 fused-attention backend. BF16/FP16 attention is routed through the maintained arbitrary-sequence backend, but explicit selections of the old backend must be updated. (#2949)

Deprecated Features

There are no deprecated features in this release.

v2.15

13 May 04:39

Transformer Engine v2.15 Release Notes

Key Features and Enhancements

  • [PyTorch] Added support for Flash Attention 4. (#2432)
  • [PyTorch] Added support for MXFP8 attention. (#2719)
  • [PyTorch] Added support for QGeGLU activation in both te.ops and the fused grouped MLP path using GEMM + activation fusion. (#2855)
  • [PyTorch] Added support for per-token bias probability scaling in both te.ops and the fused grouped MLP path using GEMM + activation fusion. (#2864)
  • [PyTorch] Added support for NVFP4 weight quantization in the fused Adam optimizer. (#2797)
  • [PyTorch, Common] Added triton kernels to support mHC (Manifold-Constrained Hyper-Connections). (#2790)
  • [PyTorch, Common] Added support for dequantizing MXFP8 grouped tensors. (#2722)
  • [Common] Added support for unswizzling scaling factors. (#2837, #2732)
  • [PyTorch] Added Newton–Schulz orthogonalization via cuSOLVERMp for distributed orthogonalization workloads. (#2706)
  • [PyTorch] Added an NVTE_BACKWARD_OVERRIDE=high_precision|dequantized environment variable to control backward precision behavior. (#2644)
  • [PyTorch] Added a feature to debug tools to allow tensor dumps before and after quantization for numerical debugging. (#2645)
  • [PyTorch] Optimized FP8 block-scaling AllGather for FSDP2 to reduce communication overhead. (#2789)
  • [PyTorch] Added an example demonstrating high-precision weight initialization with fully_shard. (#2785)
  • [PyTorch] Expanded fused grouped MLP support via te.ops by lowering the weight dimension requirements to being divisible by 64 (previously 256). (#2856)
  • [PyTorch] Added torch.compile support for the MoE permute utility functions. (#2686)
  • [Common, PyTorch] Improved the performance of NVFP4 quantization by refactoring the amax compute kernel. (#2820)
  • [JAX] Reduced THD seqlen and offset computation from O(T·T) memory down to O(T) for long sequences. (#2522)
  • [JAX] Added MXFP8 grouped quantize + GEMM support. (#2763)
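
The THD change in #2522 reduces sequence-length and offset computation from O(T·T) to O(T) memory. One way to see why a linear pass is enough: if each packed sequence is a contiguous run of identical segment IDs, lengths and start offsets can be read off in a single scan instead of by building a T×T comparison mask. The sketch below assumes 0 marks padding tokens and uses a hypothetical function name; it is not the JAX implementation in TE.

```python
def seqlens_and_offsets(segment_ids, pad_id=0):
    """Per-segment lengths and start offsets from packed (THD) segment IDs.

    Assumes every segment occupies one contiguous run of tokens and that
    pad_id marks padding. O(T) time and O(number of segments) extra memory.
    """
    seqlens, offsets = [], []
    prev = None
    for pos, seg in enumerate(segment_ids):
        if seg == pad_id:
            prev = seg
            continue
        if seg != prev:
            # A new segment starts at this token position.
            offsets.append(pos)
            seqlens.append(0)
        seqlens[-1] += 1
        prev = seg
    return seqlens, offsets
```

For segment IDs [1, 1, 1, 2, 2, 0, 0, 3, 3, 3, 3, 0] this yields lengths [3, 2, 4] and offsets [0, 3, 7].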

Fixed Issues

  • [PyTorch] Fixed a numerical bug where stale columnwise weight data would be used for post-validation training steps. (#2929)
  • [PyTorch] Fixed redundant memory usage when using NVFP4 parameters. (#2834)
  • [JAX] Fixed the JAX extension build with NVTE_UB_WITH_MPI=1. (#2835)
  • [Common] Fixed a numerical bug for the MoE fused router for large top-K and expert counts. (#2821)
  • [Common] Fixed an illegal memory access in register_user_buffer_collective on Ampere (and older) GPUs when using user buffers for COMM-GEMM overlap. (#2859)
  • [Build] Fixed a build crash when compiling from source with NVTE_CUDA_ARCHS=120. (#2832)

Known Issues

  • [PyTorch] When building a grouped MLP module via te.ops.Sequential in order to use the GEMM + activation fusion, the kernel may produce non-deterministic results in the single grouped-weight case (i.e., when the environment variable NVTE_GROUPED_LINEAR_SINGLE_PARAM and the corresponding module argument single_grouped_weight are set).
  • [PyTorch] Enabling the fused grouped MLP via te.ops requires cudnn-frontend version 1.23.0. If you run into issues, ensure that the correct version of CuTeDSL is installed:
python -m pip uninstall -y \
  cutlass \
  nvidia-cutlass \
  nvidia-cutlass-dsl \
  nvidia-cutlass-dsl-libs-base \
  nvidia-cutlass-dsl-libs-cu13 \
  nvidia-cudnn-frontend
python -m pip install -U pip setuptools wheel
python -m pip install --no-cache-dir "nvidia-cutlass-dsl[cu13]==4.4.1"
python -m pip install --no-cache-dir "nvidia-cudnn-frontend[cutedsl]==1.23.0"
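
After running the commands above, one quick way to confirm which versions actually ended up in the environment is to query the installed distribution metadata. The stdlib-only sketch below is not a Transformer Engine utility; the helper name is hypothetical, and it only compares exact version strings against the pins used in the commands.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install commands of the v2.15 known issue.
REQUIRED = {
    "nvidia-cudnn-frontend": "1.23.0",
    "nvidia-cutlass-dsl": "4.4.1",
}


def check_deps(required=REQUIRED):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for pkg, want in required.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None  # not installed in this environment
        report[pkg] = (have, have == want)
    return report


if __name__ == "__main__":
    for pkg, (have, ok) in check_deps().items():
        print(f"{pkg}: installed={have} matches_pin={ok}")
```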

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

v2.14.1

24 Apr 22:30

Transformer Engine v2.14.1 Patch Release Notes

Fixed Issues

  • [All] Fixed an issue where the MXFP8 quantization + dbias fusion could nondeterministically produce wrong results. (#2921)

v2.14

21 Apr 21:57

Transformer Engine v2.14 Release Notes

Key Features and Enhancements

  • [PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. (#2559) (#2724)
  • [C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. (#2748) (#2669)
  • [PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward/backward. (#2769)
  • [PyTorch] Added support for a single-parameter GroupedLinear configuration, where the weights of all experts are stored in a single parameter, which reduces CPU overheads. (#2731)
  • [PyTorch] Added backwards-compatible checkpoint support for the new single-parameter GroupedLinear. (#2761)
  • [PyTorch] Extended the fused attention API to optionally return the softmax Stats tensor in all cases, and the Max tensor when return_max_logit=True, exposing more cuDNN intermediates to users. (#2677)
  • [PyTorch] Enabled SM120 support for the fused attention path when cuDNN >= 9.18.1 is available. (#2693)
  • [PyTorch] Added support for MXFP8BlockScaling and Float8BlockScaling quantized weight in FusedAdam. (#2753)
  • [PyTorch] Added CUDA graph-compatible multi_tensor_scale_tensor API in the optimizer. (#2594)
  • [PyTorch] Enabled CUDA Graph capture of modules with CPU offloading. (#2435)
  • [PyTorch] Added support for non-FP32 params_dtype when using QK-normalization. (#2718)
  • [PyTorch] Added precision debug-tools support for quantized model parameters. (#2141)
  • [JAX] Added a JAX-side API to invoke the fused MoE router kernels. (#2711)
  • [JAX] Integrated BF16 grouped GEMM with on-device group sizes. (#2680)
  • [JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. (#2740)
  • [JAX] Added Shardy support to the Collective GEMM (CGEMM) path. (#2714)
  • [JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. (#2741)
  • [C] Enabled the fused RMSNorm dLN + add backward path through cuDNN for faster fused-residual normalization. (#2778)
  • [C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. (#2738) (#2674)
  • [C] Enabled dequantization from an MXFP8 tensor that only carries column-wise data. (#2712)
  • [C/PyTorch] Improved the performance of the NVFP4 recipe by fusing row-cast / RHT / transpose / column-cast. (#2555)
  • [C] Made the number of Philox rounds for stochastic rounding configurable. (#2751)
  • [Documentation] Added a documentation page describing CPU offloading in Transformer Engine. (#2520)
  • [Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. (#2624)
  • [Documentation] Improved error messages across the C, PyTorch, and JAX layers. (#2705)
  • [Documentation] Added a custom-feature tutorial for the precision debug tools. (#2216)
  • [Documentation] Added documentation for the operator fuser API. (#2447)
  • [PyTorch, Documentation] Added end-to-end examples for fused_adam, quantized_model_init, and FSDP2 usage. (#2698) (#2662)
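
Several entries above (#2748, #2669, #2738, #2674) concern grouped GEMMs and grouped dbias, as used by GroupedLinear in MoE layers. The plain-Python reference below shows the semantics those kernels accelerate, as an assumption-level sketch rather than TE's implementation: consecutive slices of input rows, sized by the group sizes, each use their own weight and bias, and the bias gradient is a per-group column sum of the output gradient. The function names are hypothetical.

```python
def grouped_linear_forward(x, m_splits, weights, biases):
    """Reference grouped linear: y_rows = x_rows @ W_g^T + b_g for each group g.

    x: list of rows, each with in_features floats
    m_splits: rows per group (zero-sized groups allowed), sum(m_splits) == len(x)
    weights[g]: out_features x in_features matrix; biases[g]: out_features vector
    """
    assert sum(m_splits) == len(x), "group sizes must cover every row"
    out, start = [], 0
    for g, m in enumerate(m_splits):
        w, b = weights[g], biases[g]
        for row in x[start:start + m]:
            out.append([sum(wi * xi for wi, xi in zip(w_row, row)) + bj
                        for w_row, bj in zip(w, b)])
        start += m
    return out


def grouped_dbias(dy, m_splits):
    """Bias gradient per group: column-wise sum of dy over that group's rows."""
    width = len(dy[0]) if dy else 0
    grads, start = [], 0
    for m in m_splits:
        rows = dy[start:start + m]
        grads.append([sum(r[j] for r in rows) for j in range(width)])
        start += m
    return grads
```

With on-device group sizes, m_splits is known only on the GPU, which is what lets the grouped kernels avoid a host round trip; the reference above simply takes it as a Python list.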

Fixed Issues

  • [PyTorch] Fixed sharded checkpoints with FSDP2, Megatron-FSDP, and DCP (distributed checkpointing) by ensuring that optimizer states are also DTensors when model parameters are DTensors. (#2795)
  • [PyTorch] Fixed async DCP checkpointing for Float8Tensor parameters. (#2721)
  • [PyTorch] Fixed cross_entropy_forward producing incorrect results for non-contiguous logits. (#2746)
  • [PyTorch] Fixed excessive memory usage when using the operator fuser. (#2750)
  • [PyTorch] Fixed a precision-debug-tools crash when tp_group=None. (#2733)
  • [PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. (#2704)
  • [PyTorch] Fixed the learnable softmax_offset parameter in DotProductAttention so that it is zero-initialized. (#2694)
  • [PyTorch] Fixed an error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. (#2637)
  • [PyTorch] Added a clear error when constructing LayerNormLinear with row-wise tensor parallelism (an unsupported configuration). Previously, this configuration failed with a CUDA error. (#2688)
  • [JAX] Fixed a performance issue with THD/BSHD segment-position generation. (#2823)
  • [JAX] Fixed an assertion error when using from_segment_ids_and_pos() with vmap. (#2692)
  • [JAX] Fixed a performance issue for models using both FSDP and EP. (#2649)
  • [JAX] Changed the dtype of the intermediate-result aval in fused_topk_and_score_function_fwd to fp32 to avoid precision loss. (#2752)
  • [C] Fixed an incorrect MNNVL fabric-availability check that misreported support on some systems. (#2626)
  • [C/PyTorch] Fixed score normalization in fused_score_for_moe_aux_loss when topk == 1. (#2720)
  • [PyTorch] Fixed possible precision loss when copying from a quantized tensor to a high-precision tensor. (#2120, #2673)

Breaking Changes in This Release

  • [JAX] GSPMD partitioning rules are no longer tested and will now warn on use; users on JAX with GSPMD should migrate to Shardy. (#2702)

Deprecated Features

There are no deprecated features in this release.

v2.13

31 Mar 18:32

Transformer Engine v2.13 Release Notes

Key Features and Enhancements

  • Added detailed documentation for low-precision training with Transformer Engine, covering FP8, MXFP8, NVFP4, and other quantization recipes with examples for both PyTorch and JAX. (#2343)
  • [Build] Added NVTE_BUILD_USE_NVIDIA_WHEELS environment variable to allow building TE using CUDA headers from PyPI NVIDIA wheels instead of a system CUDA installation. (#2623)
  • [C] Enabled deterministic FP8 fused attention on Blackwell (SM100) GPUs. (#2621)
  • [C] Updated cuBLASMp integration to version 0.8.0, replacing the nvshmem dependency with NCCL-based symmetric memory. (#2661)
  • [C] Added MXFP8 quantization kernels for grouped tensors used in MoE, with fused scale-factor swizzling for improved performance. (#2586, #2630)
  • [C] Added NVFP4 quantization kernels for grouped tensors used in MoE models. (#2655)
  • [C] Reduced cuDNN graph recompilations in THD fused attention by rounding large batch sizes to 512-element increments. (#2653)
  • [C] Added sqrtsoftplus scoring function to the fused MoE router and improved router kernel performance on Blackwell GPUs. (#2633, #2683)
  • [PyTorch] Introduced GroupedTensor, enabling MoE expert weights to be stored as a single contiguous allocation while remaining individually addressable. (#2654)
  • [PyTorch] Added fusible GroupedLinear and ScaledSwiGLU ops for building fully fused MoE grouped MLP pipelines. (#2664)
  • [PyTorch] Added register_forward_fusion and register_backward_fusion APIs, allowing users to define and register custom operator fusion patterns. (#2597)
  • [PyTorch] Added get_backward_dw_params API to TE modules, fixing weight gradient hook management when using wgrad CUDA Graphs with Megatron-LM. (#2614)
  • [PyTorch] Fixed fused attention bias dimension handling and extended dbias support to additional bias shapes (b1ss, bhss, 11ss, 111s). (#2537)
  • [PyTorch] Reduced peak memory usage in fused Adam optimizer by fusing BF16 momentum scaling directly into CUDA kernels, also enabling CUDA Graph capture for this path. (#2632)
  • [PyTorch] Added the sigmoid-gated GLU activation (activation="glu") to LayerNormMLP and TransformerLayer. (#2656)
  • [PyTorch] Extended debug statistics tracking to NVFP4 quantization (underflow and MSE metrics), and gracefully skipped stat logging for layers not using quantization. (#2296, #2652)
  • [PyTorch] Fixed CUDA Graph capture for Megatron-Core vision encoder models. (#2657)
  • [JAX] Added experimental inspect_array debugging utility for dumping tensor snapshots during multi-GPU execution. (#2651)
  • [JAX] Fixed MoE permutation to correctly mask padding tokens and handle tensor sizes under expert parallelism. (#2672)
  • [JAX] MoE permutation now always returns tokens_per_expert, required for ragged all-to-all communication in expert parallelism. (#2613)
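
#2653 reduces cuDNN graph recompilations in THD fused attention by rounding large batch sizes to 512-element increments. The effect is that many distinct batch sizes share one cached graph. The sketch below shows that arithmetic under two assumptions the note does not state: that rounding is upward (so the graph has enough capacity) and that "large" means above a caller-supplied threshold. Both the threshold parameter and the function names are hypothetical.

```python
ROUND_TO = 512  # increment mentioned in #2653


def cache_key_batch(batch_size, threshold):
    """Batch size used as the graph-cache key (hypothetical rule).

    Sizes at or below `threshold` are kept as-is; larger ones are rounded
    up to the next multiple of ROUND_TO.
    """
    if batch_size <= threshold:
        return batch_size
    return -(-batch_size // ROUND_TO) * ROUND_TO  # integer ceiling to a multiple


def num_distinct_graphs(batch_sizes, threshold):
    """How many graphs would be built for the given sequence of batch sizes."""
    return len({cache_key_batch(b, threshold) for b in batch_sizes})
```

Under these assumptions, the 1024 batch sizes from 513 to 1536 collapse to just two cache keys (1024 and 1536) instead of 1024 separate graphs.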

Fixed Issues

  • [C] Fixed incorrect results from the exp2f_rcp fast-math helper when inputs are NaN or have biased exponent 254. (#2647)
  • [C] Fixed a race condition in Randomized Hadamard Transform amax kernels where a missing memory fence could cause incorrect amax values. (#2695)
  • [PyTorch] Fixed the TE Llama example to work with HuggingFace Transformers 4.57+, which changed decoder layer output conventions. (#2572)
  • [Build] Fixed TypeError during build when NCCL is installed from PyPI as a namespace package without a __file__ attribute. (#2580)
  • [Build] Fixed ModuleNotFoundError when installing from cached source distributions (e.g., via uv) by including build_tools in MANIFEST.in. (#2684)

Breaking Changes in This Release

  • [C] Removed the deprecated packed fused attention C APIs (nvte_fused_attn_{fwd,bwd}_{qkvpacked,kvpacked}); users must migrate to the non-packed API variants. (#2696)
  • Versions of cuBLASMp prior to 0.8.0 are no longer supported.

Deprecated Features

No features deprecated in this release.

v2.12

24 Feb 00:03

Transformer Engine v2.12 Release Notes

Key Features and Enhancements

  • Made miscellaneous improvements and fixes to the documentation.
  • [C] Improved performance of NVFP4 quantization kernels. (#2412)
  • [C] Documented environment variables. (#2552)
  • [PyTorch] Added fused permute+pad and unpermute+unpad operations for FP8 optimization. (#1921)
  • [PyTorch] Improved performance in CPU-limited scenarios.
  • [PyTorch] Added support for Sliding Window Attention (left, right) with fused attention. (#2477)
  • [PyTorch] Improved the performance of MXFP8 and NVFP4 by fusing the swizzling into the quantization. (#2486)
  • [PyTorch] Added CUDA Graph support for activation recomputation. (#2518)
  • [JAX] Added a tutorial for integrating TE/JAX quantization into existing frameworks. (#2423)
  • [JAX] Added custom partitioning for permutation primitives. (#2591)
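
#1921 fuses permutation with padding (and unpermutation with unpadding). Padding is needed because grouped FP8 GEMMs often require each expert's row count to be a multiple of some alignment. The plain-Python sketch below computes such a padded layout; the alignment is an assumption passed in by the caller (16 in the usage below), since the note does not state it, and the function names are hypothetical rather than TE's API.

```python
def pad_splits(tokens_per_expert, align):
    """Round each expert's token count up to a multiple of `align`."""
    return [-(-n // align) * align for n in tokens_per_expert]


def padded_row_map(tokens_per_expert, align):
    """Map every real (unpadded) row, in expert order, to its padded-buffer row.

    Returns (row_map, padded_total). Rows of the padded buffer that never
    appear in row_map are padding and would be zero-filled.
    """
    padded = pad_splits(tokens_per_expert, align)
    row_map, dst_start = [], 0
    for n, p in zip(tokens_per_expert, padded):
        row_map.extend(dst_start + i for i in range(n))
        dst_start += p
    return row_map, dst_start
```

For example, token counts [3, 0, 17] with alignment 16 pad to [16, 0, 32]: the three rows of expert 0 go to rows 0 to 2, expert 2's seventeen rows start at row 16, and the padded buffer holds 48 rows. Fusing this with the permutation avoids materializing the unpadded permuted tensor first.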

Fixed Issues

  • [C] Fixed SM120 compilation with CUDA 12. (#2482)
  • [C] Fixed overflow in padding and unpadding kernels. (#2548)
  • [C] Fixed a numerical issue in sort_chunks_by_index. (#2566)
  • [C] Fixed a numerical issue in swizzling blockwise E8 scales. (#2589)
  • [PyTorch] Fixed an AttributeError when checkpointing a model with MXFP8 parameters. (#2427)
  • [PyTorch] Fixed cross-entropy loss calculation when some tokens are ignored. (#2476)
  • [PyTorch] Fixed Float8Tensor.contiguous autograd support. (#2533)
  • [PyTorch] Fixed multiple CPU offloading issues. (#2535)
  • [PyTorch] Fixed uninitialized permuted_scale values. (#2547)
  • [PyTorch] Fixed FP8 quantization for the second MLP in LayerNormMLP. (#2577)
  • [PyTorch] Fixed ONNX tests and added FP8 attention export support. (#2598)
  • [JAX] Removed unused TE DPA dtype handling to improve cuDNN backend dtype detection. (#2485)
  • [JAX] Fixed segment-position calculation from segment IDs in SequenceDescriptor class. (#2523)
  • [JAX] Fixed bugs in permutation custom partitioning. (#2617)
  • [JAX] Fixed the encoder and MNIST examples, which broke after the dataset path moved. (#2625)
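
For context on #2476 (cross-entropy when some tokens are ignored): a common convention, used for example by PyTorch's cross_entropy with reduction="mean", is to skip tokens whose target equals ignore_index and average only over the remaining tokens. The plain-Python sketch below follows that convention; it is not a description of what the TE bug was, and the -100 default is PyTorch's ignore_index, assumed here rather than taken from TE.

```python
import math

IGNORE_INDEX = -100  # PyTorch's default ignore_index (assumed, not TE-specific)


def cross_entropy_mean(logits, targets, ignore_index=IGNORE_INDEX):
    """Mean token cross-entropy over tokens whose target is not ignore_index.

    Dividing by the number of all tokens instead would bias the loss
    downward whenever some tokens are ignored. Returns 0.0 if every token
    is ignored (a choice made for this sketch).
    """
    total, count = 0.0, 0
    for row, t in zip(logits, targets):
        if t == ignore_index:
            continue
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        lse = m + math.log(sum(math.exp(v - m) for v in row))
        total += lse - row[t]
        count += 1
    return total / count if count else 0.0
```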

Breaking Changes in This Release

No breaking changes in this release.

Deprecated Features

No features deprecated in this release.

v2.11

15 Jan 01:44

Transformer Engine v2.11 Release Notes

Key Features and Enhancements

  • [PyTorch] Enabled the reference Current Scaling recipe for FP8 training. (#2368)
  • [PyTorch] Improved Random Hadamard Transform (RHT) device tensor caching to reduce memory allocations and improve performance for NVFP4 quantization. (#2395)
  • [PyTorch] Implemented selective activation checkpointing for the LayerNormMLP module. (#2311)
  • [C, PyTorch, JAX] Improved performance of MXFP8 quantization. (#2062)
  • [C, PyTorch] Improved performance of NVFP4 quantization. (#2351)
  • [PyTorch] Improved FSDP2 all-gather performance and added support for FusedAdam optimizer with FSDP2. (#2370)
  • [PyTorch] Extended debug tools to support GroupedLinear layers. (#1953)
  • [JAX] Added Triton kernel bindings for JAX, enabling custom Triton kernels in JAX workflows. (#2437)
  • [C] Introduced experimental NVTEGroupedTensor class and helper functions. (#2388)
  • [C, PyTorch, JAX] Added FP8 support for primary weights in MXFP8 format with partial casting and amax calculations. (#2055)
  • [JAX] Added support for context parallelism (CP) with the THD format and sliding window attention (SWA), using all-gather (AG) with striped load balancing for stripe sizes greater than 1. (#2379)
  • [JAX] Implemented JAX primitives for token permutation operations on single GPU for mixture-of-experts routing. (#2473)
  • [PyTorch] Added THD format support for max_logit clipping and MuonClip gradient clipping operations. (#2480)
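
Striped load balancing (#2379) addresses the fact that, with a causal mask, later tokens attend to more keys, so a contiguous split of the sequence across context-parallel ranks gives the last rank the most work. The sketch below shows the general striping idea in plain Python: stripes of stripe_size tokens are dealt round-robin to ranks. It is an illustration, not TE's JAX reordering code, and its divisibility requirement is a simplification made for this sketch.

```python
def striped_assignment(seq_len, cp_size, stripe_size):
    """Assign token positions to CP ranks in round-robin stripes.

    Stripe k covers positions [k * stripe_size, (k + 1) * stripe_size) and
    goes to rank k % cp_size. This sketch requires seq_len to be divisible
    by cp_size * stripe_size.
    """
    if seq_len % (cp_size * stripe_size) != 0:
        raise ValueError("seq_len must be divisible by cp_size * stripe_size")
    ranks = [[] for _ in range(cp_size)]
    for pos in range(seq_len):
        ranks[(pos // stripe_size) % cp_size].append(pos)
    return ranks


def causal_work(positions):
    """Attention work for a rank: the query at position p attends to p + 1 keys."""
    return sum(p + 1 for p in positions)
```

With 8 tokens, 2 ranks, and stripe size 2, the ranks receive positions [0, 1, 4, 5] and [2, 3, 6, 7], for causal work of 14 and 22; a contiguous split would give 10 and 26.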

Fixed Issues

  • [PyTorch] Fixed a numerical issue when a non-contiguous tensor was passed to the cross_entropy backward pass. (#2402)
  • [PyTorch] Fixed CUDA graph execution order for backward weight gradient computation when using chunked layers. (#2376)
  • [C] Fixed runtime library loading logic to properly handle missing dependencies and load order. (#2297)
  • [JAX] Stopped using a scan loop as the default for ring attention, improving performance. (#2503)

Breaking Changes in This Release

No breaking changes in this release.

Deprecated Features

No features deprecated in this release.

v2.10

11 Dec 21:29

Release v2.10

Key Features and Enhancements

  • [PyTorch] Added support for the NVFP4 training recipe for the GroupedLinear module.
  • [PyTorch] Added support for CUDA graphs when using quantized weights with Tensor Parallelism.
  • [PyTorch] Added support for CUDA graphs when using delay_wgrad_compute.
  • [PyTorch] Expanded debug tools to support more statistics.
  • [PyTorch] Reduced the overhead of using debug tools.
  • [PyTorch] Added support for clamped SwiGLU in the TransformerLayer module.
  • [PyTorch] Added backwards compatibility for older Megatron-Core versions by introducing a keep_columnwise parameter to cast_master_weights_to_fp8 and related helper functions.
  • [PyTorch] Added a reset interface to make_graphed_callables that clears internal CUDA graphs before distributed process group cleanup, preventing hangs.
  • [PyTorch] Added support for FSDP2 with quantized weights.
  • [PyTorch] Added support for Sliding Window Attention (SWA) with Context Parallelism with THD input format.
  • [PyTorch] Integrated Flash Attention's num_splits parameters into the attention backend.
  • [PyTorch] Made various improvements to mitigate CPU overhead, especially for the GroupedLinear module.
  • [C][PyTorch] Enabled RoPE (Rotary Position Embedding) application with position offsets during training, removing the previous restriction that start_positions could only be used with cp_size=1 (context parallelism disabled).
  • [Jax] Added options to disable Stochastic Rounding, Randomized Hadamard Transform, and 2D weight quantization in the NVFP4 training recipe.
  • [Jax] Improved performance by using Transformer Engine quantization when fused normalization or fused activation are disabled.
  • [Jax] Improved NVFP4 performance by using TE kernels for scaling-factor swizzling.
  • [Jax] Added support for checkpointing quantization operations in JAX.
  • [Jax] Added support for sink attention.
  • [Jax] Added support for concurrent use of Data Parallelism (DP) and Fully-Sharded Data Parallelism (FSDP).
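
The RoPE entry above allows position offsets (start_positions) during training. The plain-Python sketch below shows what an offset means: token i of a chunk is rotated as absolute position start + i, so rotating a later chunk with the right offset matches rotating the full sequence and slicing. It uses the interleaved pairing convention (x[2j], x[2j+1]) as an assumption; other implementations pair x[j] with x[j + dim/2], so results are not interchangeable. The function name and argument format are hypothetical, not TE's fused kernel API.

```python
import math


def apply_rope(x, start_position=0, base=10000.0):
    """Apply rotary position embeddings to a [seq_len][dim] list (dim even).

    Token i is rotated as absolute position start_position + i; element pair
    (2j, 2j+1) rotates by angle pos * base ** (-2j / dim).
    """
    out = []
    for i, vec in enumerate(x):
        pos = start_position + i
        dim = len(vec)
        rotated = []
        for j in range(0, dim, 2):  # j is the even element index, 2 * pair index
            theta = pos * base ** (-j / dim)
            c, s = math.cos(theta), math.sin(theta)
            a, b = vec[j], vec[j + 1]
            rotated += [a * c - b * s, a * s + b * c]
        out.append(rotated)
    return out
```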

Fixed Issues

  • Fixed an occasional crash when loading the cuDNN library at runtime.
  • [C] Fixed an out-of-bounds access in the NVFP4 dequantization kernel.
  • [C] Fixed a numerical error in the amax computation in normalization kernels.
  • [PyTorch] Fixed a crash in the permute kernel when using triton v3.5.
  • [PyTorch] Fixed a numerical issue when using gradient accumulation fusion with FSDP.
  • [PyTorch] Fixed a crash when exporting modules that use RMSNorm to ONNX.
  • [Jax] Fixed a partitioning issue for the NVFP4 training recipe with 1D Mesh.
  • [Jax] Fixed a bug where the bias parameter could be added twice when using unfused attention backend.
  • [Jax] Fixed a sharding bug in ring attention primitives when using packed sequences where segment position tensors were not properly sharded to match their corresponding segment ID tensors.
  • [PyTorch][Jax] Fixed various logical issues in the backend selection process for attention.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

  • [Jax] Default value for intermediate_dropout changed from 0.1 to 0.0.
  • [Jax] Default value for return_layernorm_output changed from True to False.
  • [Jax] Default activation changed from ReLU to GeLU.
  • [Jax] The default input format for DotProductAttention changed to BSHD.

Deprecated Features

No features are deprecated in this release.

v2.9

11 Nov 01:38

Release v2.9

Key Features and Enhancements

  • [PyTorch][Jax] Introduced recipe-agnostic functions and APIs in order to generalize to non-FP8 recipes. See the Deprecated Features section for a comprehensive list of affected APIs.
  • [C][PyTorch][Jax] Added support for the clamped SwiGLU activation function.
  • [C] Added precompiled wheels for CUDA 13 on PyPI.
  • [PyTorch] Added support for custom training recipes in the autocast context. Transformer Engine quantizers, quantized tensor classes, and storage dataclasses are now part of the public API.
  • [PyTorch] Added CPU offload support for all attention layouts.
  • [PyTorch] Added support for the FP8 block scaling recipe (as used in the DeepSeek v3 Technical Report) on NVIDIA Blackwell architecture (SM100 family).
  • [PyTorch] Added support for gradient accumulation fusion when using FSDP.
  • [PyTorch] Added support for CPU offloading when using GroupedLinear with distributed optimizer.
  • [PyTorch] Exposed the following utility functions as public API: is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, is_nvfp4_available, is_bf16_available, get_cudnn_version, get_device_compute_capability, and get_default_recipe.
  • [PyTorch] Added max_logit support for the MuonClip optimizer.
  • [PyTorch][Jax] Improved the logic for selecting the attention backend, addressing various unsupported cases and preventing errors.
  • [Jax] Added support for the NVFP4 training recipe.
  • [Jax] Improved the performance of the current scaling recipes by enabling fused amax calculation in normalization and activation kernels.
  • [Jax] Added support for bottom right causal mask for THD attention.
  • Improved documentation and tutorials for the NVFP4 recipe.

Fixed Issues

  • [Jax] Fixed a crash when using Context Parallelism with ring attention.
  • [Jax] Fixed an issue with incorrect sharding when get_all_mesh_axes is used.
  • [Jax] Fixed a numerical issue when using bias along with Tensor Parallelism.
  • [PyTorch] Fixed an integer overflow issue in the triton permute kernel.
  • [PyTorch] Fixed the known issue from the v2.8 release that degraded the performance of the FP8 current scaling recipe.
  • Fixed a build issue when cuDNN is installed into a custom location or a Python virtual environment.

Known Issues in This Release

  • [C][PyTorch] The cuDNN attention backend produces NaNs in the forward pass for cases using a non-causal mask with cuDNN 9.13 and cuDNN 9.14. As a workaround, please set the NVTE_FUSED_ATTN environment variable to 0 when using this configuration.
  • [C][PyTorch] The backward pass of cuDNN attention is incompatible with CUDA graphs for BSHD inputs where the sequence (S) dimension is not divisible by 128 when used with a non-padding mask. As a workaround, please set the NVTE_FUSED_ATTN environment variable to 0 when using this configuration.
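
Both workarounds above apply only to specific configurations. A launch script could encode those conditions so it disables the cuDNN attention backend only when needed. The helper below is a hypothetical sketch, not part of Transformer Engine; its parameters are illustrative stand-ins for the conditions in the two notes, and the environment variable should be set early, for example at the top of the training script, before the model runs.

```python
import os


def matches_v29_known_issue(cudnn_version, causal_mask, qkv_format, seq_len,
                            padding_mask, cuda_graphs):
    """True if a configuration matches one of the two v2.9 known issues.

    cudnn_version is a (major, minor) tuple, for example (9, 13).
    """
    # Issue 1: NaNs in the forward pass with a non-causal mask on cuDNN 9.13/9.14.
    nan_forward = cudnn_version in {(9, 13), (9, 14)} and not causal_mask
    # Issue 2: backward pass incompatible with CUDA graphs for BSHD inputs whose
    # sequence length is not divisible by 128, when a non-padding mask is used.
    graph_backward = (cuda_graphs and qkv_format == "bshd"
                      and seq_len % 128 != 0 and not padding_mask)
    return nan_forward or graph_backward


def disable_fused_attn(env=os.environ):
    """Apply the documented workaround: NVTE_FUSED_ATTN=0."""
    env["NVTE_FUSED_ATTN"] = "0"
```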

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

  • [PyTorch] The function fp8_autocast is deprecated in favor of autocast. The new autocast function uses the arguments recipe and amax_reduction_group instead of fp8_recipe and fp8_group, respectively.
  • [PyTorch] The function fp8_model_init is deprecated in favor of quantized_model_init.
  • [PyTorch] The arguments fp8_enabled, fp8_calibrating, fp8_recipe, fp8_group, and fp8_weight_caching of the function make_graphed_callables are deprecated in favor of enabled, calibrating, recipe, amax_reduction_group, and cache_quantized_params, respectively.
  • [Jax] The function fp8_autocast is deprecated in favor of autocast.
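
For code that wraps these calls, the argument renames listed above can be applied mechanically. The helper below is a hypothetical migration shim, not part of Transformer Engine; it rewrites keyword arguments using exactly the mappings in these notes and rejects calls that pass both an old and a new name. The function renames (fp8_autocast to autocast, fp8_model_init to quantized_model_init) still have to be changed at the call site.

```python
import warnings

# Old -> new argument names, as listed in the v2.9 deprecation notes.
AUTOCAST_RENAMES = {"fp8_recipe": "recipe", "fp8_group": "amax_reduction_group"}
MAKE_GRAPHED_CALLABLES_RENAMES = {
    "fp8_enabled": "enabled",
    "fp8_calibrating": "calibrating",
    "fp8_recipe": "recipe",
    "fp8_group": "amax_reduction_group",
    "fp8_weight_caching": "cache_quantized_params",
}


def migrate_kwargs(kwargs, renames):
    """Return a copy of kwargs with deprecated names replaced by the new ones."""
    new = {}
    for key, value in kwargs.items():
        target = renames.get(key, key)
        if target != key:
            warnings.warn(f"{key} is deprecated; use {target}",
                          DeprecationWarning, stacklevel=2)
        if target in new:
            raise TypeError(f"got both {key} and its replacement {target}")
        new[target] = value
    return new
```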

Miscellaneous:
None