diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/10_embedding_lookup_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/10_embedding_lookup_parallelkittens.py deleted file mode 100755 index 04a4341..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/10_embedding_lookup_parallelkittens.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -Strategy: -- **Device-Side Communication**: Instead of exchanging indices via NCCL and performing local lookups, we allocate the embedding shards symmetrically via `TKParallelTensor` so all GPUs have direct P2P memory access. A custom kernel directly reads the required remote embeddings over NVLink into the local output buffer. -- **Compute-Communication Overlap**: Peer NVLink loads are intrinsically executed asynchronously by the hardware. By vectorizing the memory operations and using warp-stride loops, the remote memory access latency is seamlessly hidden by the SM's warp scheduling, bypassing the latency penalties of `dist.all_to_all_single`. -- **Zero-Overhead Harness**: We use `get_or_create_parallel_tensor` to maintain the symmetric buffer. TK-native barriers guarantee safe copy-in and cross-GPU read semantics without repeatedly allocating or stalling via heavy PyTorch collectives. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (Embedding Lookup entrypoint + TK barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -struct globals { - static constexpr int NUM_DEVICES = 8; - nv_bfloat16* shards[NUM_DEVICES]; - const int64_t* indices; - nv_bfloat16* output; - int N; - int shard_size; - int embed_dim; -}; - -__global__ void lookup_kernel(globals G) { - // Each row of block handles one index lookup; x-dim provides vectorization - int idx = blockIdx.x * blockDim.y + threadIdx.y; - if (idx >= G.N) return; - - int64_t global_id = G.indices[idx]; - int target_rank = global_id / G.shard_size; - int local_offset = global_id % G.shard_size; - - // Safety clamps - if (target_rank < 0) target_rank = 0; - if (target_rank >= globals::NUM_DEVICES) target_rank = globals::NUM_DEVICES - 1; - if (local_offset < 0) local_offset = 0; - if (local_offset >= G.shard_size) local_offset = G.shard_size - 1; - - const nv_bfloat16* src = G.shards[target_rank] + (size_t)local_offset * (size_t)G.embed_dim; - nv_bfloat16* dst = G.output + (size_t)idx * (size_t)G.embed_dim; - - int tid = threadIdx.x; - int stride = blockDim.x; - - // Check 16-byte alignment and compatible vectorization size - bool aligned = ((reinterpret_cast(src) % 16) == 0) && - ((reinterpret_cast(dst) % 16) == 0) && - ((G.embed_dim % 8) == 0); - - if (aligned) { - int d_vec = G.embed_dim / 8; // 8 bf16s per uint4 - const uint4* src_v = reinterpret_cast(src); - uint4* dst_v = reinterpret_cast(dst); - for (int i = tid; i < d_vec; i += stride) { - dst_v[i] = src_v[i]; - } - } else { - // Fallback for unaligned or irregularly sized dimensions - for (int i = tid; i < G.embed_dim; i += stride) { - dst[i] = src[i]; - } - } -} - -namespace lookup_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - kittens::barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - kittens::barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace lookup_barrier - -void entrypoint( - kittens::py::TKParallelTensor &shard_tk, - kittens::py::TKParallelTensor &barrier, - torch::Tensor indices, - torch::Tensor output, - int shard_size, - int embed_dim -) { - TORCH_CHECK(indices.is_cuda(), "indices must be on CUDA"); - TORCH_CHECK(output.is_cuda(), "output must be on CUDA"); - - globals G; - for (int i = 0; i < globals::NUM_DEVICES; i++) { - G.shards[i] = reinterpret_cast(shard_tk.ptrs_[i]); - } - G.indices = indices.data_ptr(); - G.output = reinterpret_cast(output.data_ptr()); - G.N = indices.numel(); - G.shard_size = shard_size; - G.embed_dim = embed_dim; - - lookup_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Barrier 1: ensure all peers have finished copying their local_shard into the symmetric shard_tk - kittens::py::launch_kernel(barrier_G); - - // Launch lookup (32 threads per index, 8 indices per block) - dim3 block(32, 8); - dim3 grid((G.N + block.y - 1) / block.y); - if (G.N > 0) { - lookup_kernel<<>>(G); - } - - // Barrier 2: ensure no peer overwrites its shard_tk (in next iteration) before others finish reading - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_embedding_lookup", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_embedding_lookup_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - indices: torch.Tensor, - local_shard: torch.Tensor, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - - world_size = dist.get_world_size() - assert world_size == 8, f"This ThunderKittens kernel is built for NUM_DEVICES=8; got {world_size}" - - ext = _ensure_ext_jit() - - indices = indices.contiguous().to(torch.cuda.current_device()) - - shard_size, embed_dim = local_shard.shape - original_dtype = local_shard.dtype - - local_shard_bf16 = local_shard.to(torch.bfloat16).contiguous() - - # 1. Acquire symmetric memory for the table shards - shard_tk = get_or_create_parallel_tensor( - ext, (shard_size, embed_dim), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # 2. Copy our local shard to the symmetric memory. - # Must use copy_ to ensure writing into the underlying VMM buffer mapped for peers. - shard_tk.data_.copy_(local_shard_bf16) - - out_bf16 = torch.empty((indices.numel(), embed_dim), dtype=torch.bfloat16, device=indices.device) - - # 3. Launch unified kernel (barrier -> p2p lookup over NVLink -> barrier) - ext.tk_embedding_lookup( - shard_tk, - barrier_tk, - indices, - out_bf16, - shard_size, - embed_dim - ) - - return out_bf16.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/11_gemm_allgather_AT_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/11_gemm_allgather_AT_parallelkittens.py deleted file mode 100755 index 538b74b..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/11_gemm_allgather_AT_parallelkittens.py +++ /dev/null @@ -1,267 +0,0 @@ -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source: ThunderKittens Multicast All-Gather -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 8; // float4 -> 8 bf16 elements (16 bytes) - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout output; - size_t out_offset; - const bf16* input; - const int dev_idx; - const size_t numel_per_rank; - - __host__ inline dim3 grid() const { - return dim3((numel_per_rank + NUM_ELEMS_PER_BLOCK - 1) / NUM_ELEMS_PER_BLOCK); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t idx = globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - if (idx < G.numel_per_rank) { - // Load 16 bytes from local chunk - float4 tmp = reinterpret_cast(&G.input[idx])[0]; - - // Target index in the flat output buffer - const size_t out_idx = G.out_offset + G.dev_idx * G.numel_per_rank + idx; - - // Direct store to multicast pointer broadcasts the write across NVSwitch - reinterpret_cast(&G.output.mc_ptr[out_idx])[0] = tmp; - } -} - -} // namespace all_gather - -namespace all_gather_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_gather_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - size_t out_offset, - uintptr_t input_ptr, - size_t numel_per_rank, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(output, barrier); - - all_gather::globals all_gather_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .out_offset = out_offset, - .input = reinterpret_cast(input_ptr), - .dev_idx = output.local_rank_, - .numel_per_rank = numel_per_rank - }; - - all_gather_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // 1. Barrier ensures all devices are ready before overlapping writes - kittens::py::launch_kernel(barrier_G); - - // 2. Multicast broadcast - kittens::py::launch_kernel(all_gather_G); - - // 3. Barrier ensures all data landed globally before the stream unblocks cuBLAS - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_pipelined_allgather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(A_local: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert A_local.is_cuda and B.is_cuda, "Inputs must be CUDA tensors" - - world_size = dist.get_world_size() - M, K_local = A_local.shape - K_global, N = B.shape - - ext = _ensure_ext_jit() - - # We partition M into chunks to overlap ThunderKittens broadcast with cuBLAS Matmul - NUM_CHUNKS = 4 - if M < NUM_CHUNKS * 8: - NUM_CHUNKS = 1 - - pad_M = 0 - align = NUM_CHUNKS * 8 # Guarantee 8 bf16s alignment per chunk for `float4` vectorized loads - if M % align != 0: - pad_M = align - (M % align) - padded_A = torch.zeros((M + pad_M, K_local), dtype=A_local.dtype, device=A_local.device) - padded_A[:M, :] = A_local - A_local = padded_A - - M_padded = M + pad_M - M_chunk = M_padded // NUM_CHUNKS - numel_per_rank = K_local * M_chunk - - # Pre-slice and transpose the blocks to ensure they're fully contiguous in memory - A_chunks_t = [] - for c in range(NUM_CHUNKS): - chunk = A_local[c * M_chunk : (c+1) * M_chunk, :] - A_chunks_t.append(chunk.transpose(0, 1).contiguous()) - - B_t = B.transpose(0, 1).contiguous() - C_t = torch.empty((N, M_padded), device=A_local.device, dtype=A_local.dtype) - - # Pre-allocate TK buffers for a double-buffered pipeline - NUM_BUFFERS = min(2, NUM_CHUNKS) - total_buffer_numel = NUM_BUFFERS * world_size * numel_per_rank - - # Allocates unified VMM mapping with NVSwitch Multicast capability - buffer_tk = get_or_create_parallel_tensor(ext, (total_buffer_numel,), A_local.dtype, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - stream_comm = torch.cuda.Stream() - stream_comp = torch.cuda.current_stream() - - events_comm_done = [torch.cuda.Event() for _ in range(NUM_CHUNKS)] - events_comp_done = [torch.cuda.Event() for _ in range(NUM_CHUNKS)] - - # Stream schedule loop - for c in range(NUM_CHUNKS): - buf_idx = c % NUM_BUFFERS - out_offset = buf_idx * world_size * numel_per_rank - - with torch.cuda.stream(stream_comm): - if c >= NUM_BUFFERS: - stream_comm.wait_event(events_comp_done[c - NUM_BUFFERS]) - - # Launch async TK PGL multicast - ext.tk_all_gather( - buffer_tk, - out_offset, - A_chunks_t[c].data_ptr(), - numel_per_rank, - barrier_tk - ) - events_comm_done[c].record(stream_comm) - - with torch.cuda.stream(stream_comp): - stream_comp.wait_event(events_comm_done[c]) - - # Form standard contiguous PyTorch view spanning the globally gathered block - gathered_data = buffer_tk.data_[out_offset : out_offset + world_size * numel_per_rank] - A_global_chunk = gathered_data.view(world_size, K_local, M_chunk).reshape(K_global, M_chunk) - - # Standard tensor core matmul overlapping with the subsequent iteration's communication - C_t_chunk = torch.matmul(B_t, A_global_chunk) - - C_t[:, c * M_chunk : (c+1) * M_chunk] = C_t_chunk - events_comp_done[c].record(stream_comp) - - # Sync pipeline - stream_comp.wait_stream(stream_comm) - - # Strip padding and final transpose - C = C_t[:, :M].transpose(0, 1).contiguous() - - return C \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/12_gemm_allgather_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/12_gemm_allgather_parallelkittens.py deleted file mode 100755 index 67324a2..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/12_gemm_allgather_parallelkittens.py +++ /dev/null @@ -1,221 +0,0 @@ -""" -Strategy: -- **Device-side Multicast All-Gather**: Replaces stock `dist.all_gather` and `torch.cat` with a custom ThunderKittens kernel that uses Hopper's `multimem::st`. Each rank broadcasts its `A_local` chunk directly to the correct strided columns of the globally shared `A_global` tensor over NVLink in a single operation. -- **Zero-Copy Assembly**: Eliminates host-driven slice concatenations by natively writing into the strided offsets of the continuous `A_global` buffer. -- **Maximized Compute Efficiency**: Operating at maximum NVLink bandwidth via hardware multicast, we seamlessly deliver a contiguous `A_global` to a single, monolithic `torch.matmul(A_global, B)`, perfectly retaining cuBLAS wave quantization efficiency without the memory bandwidth overhead of chunked accumulations. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace all_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - - using parallel_layout = pgl, NUM_DEVICES, true>; - parallel_layout A_global; - const bf16* local_ptr; - - int M; - int K_local; - int K_global; - int dev_idx; - - __host__ inline dim3 grid() const { - long long total_vecs = ((long long)M * K_local) / 2; - int threads = config::NUM_THREADS; - int blocks = (total_vecs + threads - 1) / threads; - if (blocks > 108 * 4) blocks = 108 * 4; // Saturate GPU - if (blocks < 1) blocks = 1; - return dim3(blocks); - } -}; - -__device__ inline void kernel(const globals &G) { - long long total_vecs = ((long long)G.M * G.K_local) / 2; - if (total_vecs == 0) return; - - for (long long vec_idx = blockIdx.x * blockDim.x + threadIdx.x; - vec_idx < total_vecs; - vec_idx += blockDim.x * gridDim.x) { - - long long elem_idx = vec_idx * 2; - long long row = elem_idx / G.K_local; - long long col = elem_idx % G.K_local; - - // Map correctly to the strided columns of the global tensor - long long dst_elem_idx = row * G.K_global + (G.dev_idx * G.K_local) + col; - - bf16_2 val = *(reinterpret_cast(&G.local_ptr[elem_idx])); - multimem::st(reinterpret_cast(&G.A_global.mc_ptr[dst_elem_idx]), val); - } -} - -} // namespace all_gather - -namespace barrier_ns { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace barrier_ns - -void tk_all_gather( - kittens::py::TKParallelTensor &A_global_tk, - torch::Tensor A_local, - kittens::py::TKParallelTensor &barrier -) { - int M = A_local.size(0); - int K_local = A_local.size(1); - - TORCH_CHECK(A_local.is_contiguous(), "A_local must be contiguous"); - TORCH_CHECK(K_local % 2 == 0, "K_local must be even for bf16_2 operations"); - - all_gather::globals ag_G { - .A_global = kittens::py::parallel_tensor_to_pgl(A_global_tk), - .local_ptr = reinterpret_cast(A_local.data_ptr()), - .M = M, - .K_local = K_local, - .K_global = K_local * all_gather::globals::NUM_DEVICES, - .dev_idx = A_global_tk.local_rank_ - }; - - barrier_ns::globals bar_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Synchronize peers -> Hardware Multicast All-Gather -> Synchronize peers - kittens::py::launch_kernel(bar_G); - kittens::py::launch_kernel(ag_G); - kittens::py::launch_kernel(bar_G); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_gather", &tk_all_gather); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allgather_gemm_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(A_local: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert A_local.is_cuda and B.is_cuda, "Inputs must be CUDA tensors" - - world_size = dist.get_world_size() - M, K_local = A_local.shape - K_global = world_size * K_local - - # cuBLAS / Tensor Cores strongly prefer even alignments - assert K_local % 2 == 0, "K_local must be even for vectorized operations." - assert B.shape[0] == K_global, f"B must have K dimension = world_size * K_local" - - ext = _ensure_ext_jit() - - original_dtype = A_local.dtype - A_local_bf16 = A_local.to(torch.bfloat16).contiguous() - B_bf16 = B.to(torch.bfloat16).contiguous() - - # Obtain or create cached TKParallelTensor mapping NVSwitch multimem - A_global_tk = get_or_create_parallel_tensor( - ext, (M, K_global), torch.bfloat16, multicast=True - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # Lightning fast hardware-driven Multicast All-Gather - ext.tk_all_gather(A_global_tk, A_local_bf16, barrier_tk) - - # Retrieve identically-sized view of the globally mapped tensor - A_global = A_global_tk.data_.view(M, K_global) - - # High-efficiency monolithic GEMM using cuBLAS Tensor Cores - C = torch.matmul(A_global, B_bf16) - - # Match reference boundary logic - dist.barrier() - - return C.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/13_gemm_allreduce_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/13_gemm_allreduce_parallelkittens.py deleted file mode 100755 index d069f2f..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/13_gemm_allreduce_parallelkittens.py +++ /dev/null @@ -1,276 +0,0 @@ -""" -ThunderKittens Distributed GEMM with All-Reduce - -Strategy: -We optimize the GEMM + All-Reduce map-reduce pattern by pipelining local matrix -multiplications with device-side communication. Instead of computing the full -local GEMM and then launching a monolithic collective, we slice the M dimension -into chunks. The CPU submits the compute (cuBLAS) for chunk `i` and then asynchronously -dispatches a custom ThunderKittens multimem reduction for that chunk on a separate -communication stream. This hides the NVSwitch/NVLink network latency behind the dense -tensor-core math of the subsequent chunks, maximizing overlapping and throughput on Hopper. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (Pipelined all_reduce entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - int offset; - int chunk_numel; - - __host__ inline dim3 grid() const { - return dim3(chunk_numel / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_per_dev = G.chunk_numel / globals::NUM_DEVICES; - const size_t idx = G.offset + N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier, - int offset, - int chunk_numel -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(chunk_numel % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "chunk_numel must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_, - .offset = offset, - .chunk_numel = chunk_numel - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_reduce_chunk", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 # NUM_WARPGROUPS(2) * WARPGROUP_WARPS(4) * WARP_THREADS(32) -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 4096 - -_stream_comm = None -_events_compute = [] - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gemm_allreduce_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _get_comm_resources(chunks): - """Reuse CUDA streams and events to minimize per-iteration allocation overhead.""" - global _stream_comm, _events_compute - if _stream_comm is None: - _stream_comm = torch.cuda.Stream() - while len(_events_compute) < chunks: - _events_compute.append(torch.cuda.Event()) - return _stream_comm, _events_compute[:chunks] - - -@torch.no_grad() -def solution( - A_local: torch.Tensor, - B_local: torch.Tensor, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert A_local.is_cuda and B_local.is_cuda, "Inputs must be CUDA tensors" - - world = dist.get_world_size() - ext = _ensure_ext_jit() - - original_dtype = A_local.dtype - M, K = A_local.shape - K_B, N = B_local.shape - assert K == K_B, f"A_local and B_local must have matching K dimension: {K} != {K_B}" - - A_bf16 = A_local.to(torch.bfloat16).contiguous() - B_bf16 = B_local.to(torch.bfloat16).contiguous() - - # Pipelining: overlap chunk N+1 compute with chunk N all-reduce. - # Fallback to single chunk for tiny payloads to avoid slicing overhead. - chunks = min(2, M) - if M * N < 512 * 512: - chunks = 1 - - M_per_chunk = (M + chunks - 1) // chunks - chunk_numel = M_per_chunk * N - aligned_chunk_numel = ((chunk_numel + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - total_elements = aligned_chunk_numel * chunks - - tensor_tk = get_or_create_parallel_tensor(ext, (total_elements,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - stream_compute = torch.cuda.current_stream() - stream_comm, events_compute_done = _get_comm_resources(chunks) - - # 1. Pipeline schedule: Compute GEMM chunk -> record event -> asynchronously trigger multimem reduce. - for i in range(chunks): - start_M = i * M_per_chunk - end_M = min(start_M + M_per_chunk, M) - actual_M = end_M - start_M - actual_numel = actual_M * N - offset = i * aligned_chunk_numel - - if actual_M > 0: - # We construct a view directly into the symmetric VMM parallel tensor memory - out_view = tensor_tk.data_[offset : offset + actual_numel].view(actual_M, N) - torch.matmul(A_bf16[start_M:end_M], B_bf16, out=out_view) - - # Zero out any padded tail to avoid incorporating garbage in the sum reduction - if aligned_chunk_numel > actual_numel: - tensor_tk.data_[offset + actual_numel : offset + aligned_chunk_numel].zero_() - - events_compute_done[i].record(stream_compute) - - # Offload synchronization & reduction collective to comm stream - with torch.cuda.stream(stream_comm): - stream_comm.wait_event(events_compute_done[i]) - if actual_M > 0: - ext.tk_all_reduce_chunk(tensor_tk, barrier_tk, offset, aligned_chunk_numel) - - stream_compute.wait_stream(stream_comm) - - # 2. Gather results from the VMM buffer - C = torch.empty((M, N), dtype=original_dtype, device=A_local.device) - for i in range(chunks): - start_M = i * M_per_chunk - end_M = min(start_M + M_per_chunk, M) - actual_M = end_M - start_M - actual_numel = actual_M * N - offset = i * aligned_chunk_numel - - if actual_M > 0: - C[start_M:end_M] = tensor_tk.data_[offset : offset + actual_numel].view(actual_M, N).to(original_dtype) - - return C \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/14_gemm_allscatter_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/14_gemm_allscatter_parallelkittens.py deleted file mode 100755 index e37c04b..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/14_gemm_allscatter_parallelkittens.py +++ /dev/null @@ -1,299 +0,0 @@ -""" -Strategy: -- **Device-Side Communication & Overlap**: Instead of computing a full local GEMM shard and executing a host-driven `all_gather` collective, we fuse the GEMM and communication. The ThunderKittens kernel computes the local shard $C_{\text{local}} = A \times B_{\text{local}}$ in 128x128 tiles. As soon as a tile finishes computing via WGMMA, it uses asynchronous TMA stores to broadcast the result to its final offset in the symmetric `C` matrix on all peers over NVLink. -- **Pipelining**: Within the kernel, double-buffered TMA loads overlap with MMA compute. At the grid level, blocking on TMA stores natively overlaps with independent blocks executing on other SMs, maximizing compute and network saturation without extra buffers. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace gemm_allgather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 1; - static constexpr int NUM_WARPGROUPS = 1; - static constexpr int NUM_THREADS = NUM_WARPGROUPS * 128; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int BM = 128; - static constexpr int BN = 128; - static constexpr int BK = 64; - - using st_A = st_bf; - using st_B = st_bf; - using st_C = st_bf; - - using layout_A = pgl, NUM_DEVICES, false>; - using layout_B = pgl, NUM_DEVICES, false>; - using layout_C = pgl, NUM_DEVICES, false>; - - layout_A A; - layout_B B; - layout_C C; - - int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(B.cols() / BN, A.rows() / BM); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(2 * sizeof(st_A) + 2 * sizeof(st_B) + sizeof(st_C) + 2048); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - - globals::st_A (&smem_A)[2] = allocator.allocate(); - globals::st_B (&smem_B)[2] = allocator.allocate(); - globals::st_C &smem_C = allocator.allocate(); - - int i_n = blockIdx.x; - int i_m = blockIdx.y; - int K_blocks = G.A.cols() / globals::BK; - - __shared__ semaphore arrived_A[2]; - __shared__ semaphore arrived_B[2]; - - if (threadIdx.x == 0) { - init_semaphore(arrived_A[0], 0, 1); - init_semaphore(arrived_A[1], 0, 1); - init_semaphore(arrived_B[0], 0, 1); - init_semaphore(arrived_B[1], 0, 1); - } - __syncthreads(); - - int tic = 0, toc = 1; - - if (threadIdx.x == 0) { - tma::expect_bytes(arrived_A[tic], sizeof(globals::st_A)); - tma::expect_bytes(arrived_B[tic], sizeof(globals::st_B)); - tma::load_async(smem_A[tic], G.A[G.dev_idx], {0, 0, i_m, 0}, arrived_A[tic]); - tma::load_async(smem_B[tic], G.B[G.dev_idx], {0, 0, 0, i_n}, arrived_B[tic]); - } - - rt_bf accum; - zero(accum); - - for (int k = 0; k < K_blocks; k++, tic = toc, toc ^= 1) { - if (k < K_blocks - 1) { - if (threadIdx.x == 0) { - tma::expect_bytes(arrived_A[toc], sizeof(globals::st_A)); - tma::expect_bytes(arrived_B[toc], sizeof(globals::st_B)); - tma::load_async(smem_A[toc], G.A[G.dev_idx], {0, 0, i_m, k + 1}, arrived_A[toc]); - tma::load_async(smem_B[toc], G.B[G.dev_idx], {0, 0, k + 1, i_n}, arrived_B[toc]); - } - } - - int phase = k / 2; - wait(arrived_A[tic], phase); - wait(arrived_B[tic], phase); - - warpgroup::mma_AB(accum, smem_A[tic], smem_B[tic]); - } - - warpgroup::mma_async_wait(); - - // Write out the MAC result to shared memory - warpgroup::store(smem_C, accum); - __syncthreads(); - - // Broadcast the result to all peers directly into their symmetric memory blocks - if (threadIdx.x == 0) { - int N_local_blocks = G.B.cols() / globals::BN; - int dst_col_block = G.dev_idx * N_local_blocks + i_n; - - #pragma unroll - for(int d = 0; d < globals::NUM_DEVICES; d++) { - tma::store_async(G.C[d], smem_C, {0, 0, i_m, dst_col_block}); - } - asm volatile("cp.async.bulk.commit_group;"); - asm volatile("cp.async.bulk.wait_group 0;"); - } - __syncthreads(); -} - -} // namespace gemm_allgather - -namespace barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace barrier - -void entrypoint( - kittens::py::TKParallelTensor &A, - kittens::py::TKParallelTensor &B, - kittens::py::TKParallelTensor &C, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(A, B, C, barrier); - - gemm_allgather::globals G { - .A = kittens::py::parallel_tensor_to_pgl(A), - .B = kittens::py::parallel_tensor_to_pgl(B), - .C = kittens::py::parallel_tensor_to_pgl(C), - .dev_idx = A.local_rank_ - }; - - barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_gemm_allgather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gemmallgather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert A.is_cuda and B.is_cuda, "Inputs must be CUDA tensors" - - world_size = dist.get_world_size() - assert world_size == 8, "This ThunderKittens kernel is built for NUM_DEVICES=8" - - M, K = A.shape - K_B, N_local = B.shape - assert K == K_B, f"A and B must have matching K dimension: {K} != {K_B}" - - ext = _ensure_ext_jit() - original_dtype = A.dtype - - # Quantize blocks cleanly into ThunderKittens tile sizes - pad_M = (M + 127) // 128 * 128 - pad_K = (K + 63) // 64 * 64 - pad_N_local = (N_local + 127) // 128 * 128 - - # Simple hack to circumvent shape collision if `pad_M == pad_N_local` inside - # parallelkittens caching map. Adding an unused dummy tile column shifts - # uniquely without perturbing valid matrix bounds in memory. - if pad_M == pad_N_local: - pad_N_local += 128 - - pad_N_total = world_size * pad_N_local - - A_tk = get_or_create_parallel_tensor(ext, (1, 1, pad_M, pad_K), torch.bfloat16, multicast=False) - B_tk = get_or_create_parallel_tensor(ext, (1, 1, pad_K, pad_N_local), torch.bfloat16, multicast=False) - C_tk = get_or_create_parallel_tensor(ext, (1, 1, pad_M, pad_N_total), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # Clear padded borders to bypass contamination and fill tensors - A_flat_len = pad_M * pad_K - A_tk.data_.reshape(-1)[:A_flat_len].view(pad_M, pad_K).zero_() - A_tk.data_.reshape(-1)[:A_flat_len].view(pad_M, pad_K)[:M, :K].copy_(A) - - B_flat_len = pad_K * pad_N_local - B_tk.data_.reshape(-1)[:B_flat_len].view(pad_K, pad_N_local).zero_() - B_tk.data_.reshape(-1)[:B_flat_len].view(pad_K, pad_N_local)[:K, :N_local].copy_(B) - - ext.tk_gemm_allgather(A_tk, B_tk, C_tk, barrier_tk) - - C_flat_len = pad_M * pad_N_total - C_out_full = C_tk.data_.reshape(-1)[:C_flat_len].view(pad_M, pad_N_total) - - if pad_N_local == N_local: - C_final = C_out_full[:M, :world_size * N_local].clone() - else: - chunks = [] - for i in range(world_size): - chunks.append(C_out_full[:M, i * pad_N_local : i * pad_N_local + N_local]) - C_final = torch.cat(chunks, dim=1) - - return C_final.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/15_combined_sharded_gemms_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/15_combined_sharded_gemms_parallelkittens.py deleted file mode 100755 index ea59a60..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/15_combined_sharded_gemms_parallelkittens.py +++ /dev/null @@ -1,296 +0,0 @@ -""" -Strategy: -- Algorithmic Reduction: We eliminate the redundant global Matmuls and `reduce_scatter` by transforming the initial All-Gather into an All-to-All. Rank `r` only requests the `M_local` rows it needs for its sequence-parallel compute block, drastically reducing communication volume by 8x and compute by 8x. -- Custom TK All-to-All TMA: Data movement uses a custom ThunderKittens TMA kernel directly scattering input blocks into the destination's correctly-strided layout `[M_local, H]`, leveraging device-side UVA P2P without intermediate buffers or opaque PyTorch collectives. -- Compute-Communication Overlap: The `M_local` rows are split into chunks. A parallel CUDA stream runs the TK collective for the next chunk concurrently with PyTorch Tensor Core Matmuls and SiLU compute on the current chunk, completely hiding the reduced communication latency. -""" - -import os -import torch -import torch.distributed as dist -import torch.nn.functional as F - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for All-to-All Gather via TMA -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace tk_all_to_all_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int TILE_M = 16; - static constexpr int TILE_H = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - int m_offset_blocks; - int m_size_blocks; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(m_size_blocks * (input.cols() / globals::TILE_H) * NUM_DEVICES); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int dst_dev_idx = task_idx / (G.m_size_blocks * (G.input.cols() / globals::TILE_H)); - task_idx %= (G.m_size_blocks * (G.input.cols() / globals::TILE_H)); - int row_block_idx = G.m_offset_blocks + task_idx / (G.input.cols() / globals::TILE_H); - int col_block_idx = task_idx % (G.input.cols() / globals::TILE_H); - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - - // input shape: [W, 1, M_pad, H_pad] - tma::load_async(tile, G.input[G.dev_idx], {dst_dev_idx, 0, row_block_idx, col_block_idx}, arrived); - - wait(arrived, 0); - - // output shape: [1, 1, M_pad, W * H_pad] - int out_col_block_idx = G.dev_idx * (G.input.cols() / globals::TILE_H) + col_block_idx; - - tma::store_async(G.output[dst_dev_idx], tile, {0, 0, row_block_idx, out_col_block_idx}); -} - -} // namespace tk_all_to_all_gather - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int m_offset_blocks, - int m_size_blocks -) { - TORCH_CHECK(m_size_blocks > 0, "m_size_blocks must be positive"); - kittens::py::parallel_tensor_check(output, input); - - tk_all_to_all_gather::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .m_offset_blocks = m_offset_blocks, - .m_size_blocks = m_size_blocks, - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_to_all_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_alltoallgather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - x_local: torch.Tensor, - W1: torch.Tensor, - W2: torch.Tensor, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - world_size = dist.get_world_size() - assert world_size == NUM_DEVICES, f"This ThunderKittens kernel expects NUM_DEVICES={NUM_DEVICES}" - - ext = _ensure_ext_jit() - - M, H_local = x_local.shape - H, F_dim = W1.shape - - M_local = M // world_size - - original_dtype = x_local.dtype - x_local = x_local.to(torch.bfloat16).contiguous() - W1 = W1.to(torch.bfloat16) - W2 = W2.to(torch.bfloat16) - - y_local_full = torch.empty((M_local, H), dtype=torch.bfloat16, device=x_local.device) - - if M_local == 0: - dist.barrier() - return y_local_full.to(original_dtype) - - M_pad = ((M_local + 15) // 16) * 16 - H_pad = ((H_local + 127) // 128) * 128 - H_out_pad = world_size * H_pad - - # Shared pre-allocated UVA descriptors - input_tk = get_or_create_parallel_tensor( - ext, (world_size, 1, M_pad, H_pad), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, 1, M_pad, H_out_pad), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # Prepare interleaved slices for TMA: rank `i`'s input requires sending sequence chunk `i` to rank `i` - x_local_view = x_local.view(world_size, M_local, H_local) - padded_in = torch.zeros(world_size, 1, M_pad, H_pad, dtype=torch.bfloat16, device=x_local.device) - padded_in[:, 0, :M_local, :H_local] = x_local_view - - in_numel = world_size * M_pad * H_pad - input_tk.data_.reshape(-1)[:in_numel].copy_(padded_in.reshape(-1)) - - # Chunk M_local logic for optimal communication overlap - m_blocks = M_pad // 16 - if m_blocks >= 2: - blocks_1 = m_blocks // 2 - blocks_2 = m_blocks - blocks_1 - chunks = [(0, blocks_1), (blocks_1, blocks_2)] - else: - chunks = [(0, m_blocks)] - - s_main = torch.cuda.current_stream() - s_comm = torch.cuda.Stream() - - # Kickoff the first All-to-All block transfer - offset_0, size_0 = chunks[0] - with torch.cuda.stream(s_comm): - ext.tk_all_to_all_gather(output_tk, input_tk, barrier_tk, offset_0, size_0) - - for i in range(len(chunks)): - offset, size = chunks[i] - - # Wait for this chunk's TMA P2P delivery to wrap up - s_main.wait_stream(s_comm) - - # Pre-issue the next chunk's comm immediately - if i + 1 < len(chunks): - next_offset, next_size = chunks[i+1] - with torch.cuda.stream(s_comm): - ext.tk_all_to_all_gather(output_tk, input_tk, barrier_tk, next_offset, next_size) - - valid_start = min(offset * 16, M_local) - valid_end = min((offset + size) * 16, M_local) - if valid_start >= valid_end: - continue - - # Extract the gathered valid elements from the TK memory footprint - out_numel = M_pad * H_out_pad - out_view = output_tk.data_.reshape(-1)[:out_numel].view(M_pad, world_size, H_pad) - x_chunk_pad = out_view[valid_start:valid_end, :, :] - - # Slicing the tensor removes the internal H padding natively - x_chunk = x_chunk_pad[:, :, :H_local].contiguous().view(-1, H) - - # Reduced compute step: instead of `[M, H] @ W1`, only `[M_local, H] @ W1` - z_chunk = torch.matmul(x_chunk, W1) - a_chunk = F.silu(z_chunk) - block_chunk = torch.matmul(a_chunk, W2) - - y_local_full[valid_start:valid_end, :] = block_chunk - - dist.barrier() - return y_local_full.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/16_gemm_reducescatter_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/16_gemm_reducescatter_parallelkittens.py deleted file mode 100755 index 493f318..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/16_gemm_reducescatter_parallelkittens.py +++ /dev/null @@ -1,278 +0,0 @@ -""" -ThunderKittens pipelined GEMM + reduce-scatter. - -Strategy: -- Overlap Compute and Comm: We partition the M dimension into world_size chunks. PyTorch GEMMs are issued sequentially on the default stream, while a custom ThunderKittens P2P kernel executes on a comm stream, overlapping Tensor Core math with NVLink memory transfers. -- Device-side Data Movement: We allocate a symmetric TKParallelTensor buffer. Each rank explicitly pulls its designated output chunk from all peers' memory using direct one-sided vectorized loads, summing inline and bypassing intermediate buffers. -- Stream-Native Barriers: Fine-grained device-side barriers (barrier_all) are launched per-chunk on the communication stream, synchronizing memory access seamlessly without stalling the host CPU. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source: P2P Pull Reduce-Scatter -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace rs_chunk { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - // We treat the parallel tensors as 1D logical layouts and calculate offsets manually - using layout = pgl, NUM_DEVICES, false>; - - layout c_local; - layout c_partial; - barrier_t barrier; - - int dev_idx; - int chunk_idx; - int num_elems; - int padded_elems; - - __host__ inline dim3 grid() const { - return dim3((num_elems + config::NUM_THREADS * 8 - 1) / (config::NUM_THREADS * 8)); - } -}; - -__device__ inline void kernel(const globals& G) { - // Wait for all peers to finish the GEMM computation for this chunk - barrier_all(G.barrier, {0}, G.dev_idx); - - // Only the rank responsible for this chunk does the pull-and-reduce - if (G.dev_idx == G.chunk_idx) { - const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x; - const int base_idx = vec_idx * 8; - - if (base_idx < G.num_elems) { - float sum[8] = {0.0f}; - int valid = (G.num_elems - base_idx > 8) ? 8 : (G.num_elems - base_idx); - - #pragma unroll - for (int p = 0; p < globals::NUM_DEVICES; p++) { - bf16* peer_ptr = (bf16*)&G.c_partial[p](0,0,0,0); - - if (valid == 8) { - // Fast path: fully vectorized 16-byte load - uint4 vec = *(uint4*)(&peer_ptr[G.chunk_idx * G.padded_elems + base_idx]); - bf16* vals = (bf16*)&vec; - for(int i = 0; i < 8; ++i) { - sum[i] += __bfloat162float(vals[i]); - } - } else { - // Edge case path - for(int i = 0; i < valid; ++i) { - sum[i] += __bfloat162float(peer_ptr[G.chunk_idx * G.padded_elems + base_idx + i]); - } - } - } - - bf16* out_ptr = (bf16*)&G.c_local[G.dev_idx](0,0,0,0); - if (valid == 8) { - uint4 out_vec; - bf16* out_vals = (bf16*)&out_vec; - for(int i = 0; i < 8; ++i) { - out_vals[i] = __float2bfloat16(sum[i]); - } - *(uint4*)(&out_ptr[base_idx]) = out_vec; - } else { - for(int i = 0; i < valid; ++i) { - out_ptr[base_idx + i] = __float2bfloat16(sum[i]); - } - } - } - } -} - -} // namespace rs_chunk - -namespace sync_only { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 128; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - int dev_idx; - __host__ inline dim3 grid() const { return dim3(1); } -}; -__device__ inline void kernel(const globals& G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace sync_only - -void rs_entrypoint( - kittens::py::TKParallelTensor &c_local, - kittens::py::TKParallelTensor &c_partial, - kittens::py::TKParallelTensor &barrier, - int chunk_idx, - int num_elems, - int padded_elems -) { - kittens::py::parallel_tensor_check(c_local, c_partial, barrier); - - rs_chunk::globals G { - .c_local = kittens::py::parallel_tensor_to_pgl(c_local), - .c_partial = kittens::py::parallel_tensor_to_pgl(c_partial), - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = c_local.local_rank_, - .chunk_idx = chunk_idx, - .num_elems = num_elems, - .padded_elems = padded_elems - }; - - kittens::py::launch_kernel(G); -} - -void barrier_entrypoint( - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(barrier); - - sync_only::globals G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_reduce_scatter_chunk", &rs_entrypoint); - m.def("tk_barrier_only", &barrier_entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gemm_reducescatter_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(A_local: torch.Tensor, B_local: torch.Tensor) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert A_local.is_cuda and B_local.is_cuda, "Inputs must be CUDA tensors" - - world_size = dist.get_world_size() - assert world_size == NUM_DEVICES, f"Expected {NUM_DEVICES} ranks for this compiled kernel" - - M, K_local = A_local.shape - K_B, N = B_local.shape - assert K_local == K_B, "A_local and B_local inner dims mismatch" - assert M % world_size == 0, "M must be divisible by world_size" - - M_local = M // world_size - num_elems = M_local * N - - # Pad alignment to 2048 elements for safety with flat VMM allocations - padded_elems = ((num_elems + 2047) // 2048) * 2048 - - ext = _ensure_ext_jit() - - # Cached TKParallelTensor buffers - C_partial_tk = get_or_create_parallel_tensor(ext, (world_size, padded_elems), torch.bfloat16, multicast=False) - C_local_tk = get_or_create_parallel_tensor(ext, (padded_elems,), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - gemm_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream() - gemm_events = [torch.cuda.Event() for _ in range(world_size)] - - c_partial_views = [] - for c in range(world_size): - v = C_partial_tk.data_[c, :num_elems].view(M_local, N) - c_partial_views.append(v) - - # Pipelined Compute and Communication - for c in range(world_size): - # 1. Compute chunk c into symmetric buffer directly - A_chunk = A_local[c * M_local : (c + 1) * M_local, :] - torch.matmul(A_chunk, B_local, out=c_partial_views[c]) - gemm_events[c].record(gemm_stream) - - # 2. Launch TK pull-kernel to overlap with the next GEMM - comm_stream.wait_event(gemm_events[c]) - with torch.cuda.stream(comm_stream): - ext.tk_reduce_scatter_chunk( - C_local_tk, C_partial_tk, barrier_tk, c, num_elems, padded_elems - ) - - # Launch final device-side sync on the comm stream. This ensures no peer returns and - # executes its next loop call (which could overwrite C_partial) before all reads conclude. - with torch.cuda.stream(comm_stream): - ext.tk_barrier_only(barrier_tk) - - gemm_stream.wait_stream(comm_stream) - - # Return local result from the symmetrical buffer - C_local = torch.empty((M_local, N), dtype=A_local.dtype, device=A_local.device) - C_local.copy_(C_local_tk.data_[:num_elems].view(M_local, N)) - - return C_local \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/17_rope_allgather_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/17_rope_allgather_parallelkittens.py deleted file mode 100755 index bdf9c86..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/17_rope_allgather_parallelkittens.py +++ /dev/null @@ -1,308 +0,0 @@ -import os -import torch -import torch.distributed as dist -from typing import Tuple -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source: Fused RoPE + PGL All-Gather -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace rope_all_gather { - -struct config { - static constexpr int NUM_THREADS = 256; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - // 1D layout is sufficient, we calculate multi-dimensional indices manually - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout q_out; - parallel_layout k_out; - const bf16* q_local; - const bf16* k_local; - const bf16* cos_local; - const bf16* sin_local; - int B, S_local, H, D; - int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((B * S_local * H * (D / 2) + config::NUM_THREADS - 1) / config::NUM_THREADS); - } -}; - -__device__ inline void kernel(const globals &G) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = G.B * G.S_local * G.H * (G.D / 2); - if (idx >= total) return; - - // Each thread processes 2 contiguous feature elements (bf16_2) - int d_half = idx % (G.D / 2); - int d = d_half * 2; - int h = (idx / (G.D / 2)) % G.H; - int s = (idx / (G.D / 2 * G.H)) % G.S_local; - int b = idx / (G.D / 2 * G.H * G.S_local); - - // Compute exact memory bounds - int in_idx = b * (G.S_local * G.H * G.D) + s * (G.H * G.D) + h * G.D + d; - int cos_idx = b * (G.S_local * G.D) + s * G.D + d; - - // Partner indices for rotary math - int pair_d = (d < G.D / 2) ? (d + G.D / 2) : (d - G.D / 2); - int pair_idx = b * (G.S_local * G.H * G.D) + s * (G.H * G.D) + h * G.D + pair_d; - - // Load as bf16_2 mapping strictly to 4-byte aligned blocks - bf16_2 q_val = *reinterpret_cast(&G.q_local[in_idx]); - bf16_2 q_pair = *reinterpret_cast(&G.q_local[pair_idx]); - bf16_2 k_val = *reinterpret_cast(&G.k_local[in_idx]); - bf16_2 k_pair = *reinterpret_cast(&G.k_local[pair_idx]); - - bf16_2 cos_val = *reinterpret_cast(&G.cos_local[cos_idx]); - bf16_2 sin_val = *reinterpret_cast(&G.sin_local[cos_idx]); - - float2 q_f = __bfloat1622float2(q_val); - float2 q_pair_f = __bfloat1622float2(q_pair); - float2 k_f = __bfloat1622float2(k_val); - float2 k_pair_f = __bfloat1622float2(k_pair); - float2 cos_f = __bfloat1622float2(cos_val); - float2 sin_f = __bfloat1622float2(sin_val); - - float2 q_rot_f, k_rot_f; - if (d < G.D / 2) { - q_rot_f.x = -q_pair_f.x; q_rot_f.y = -q_pair_f.y; - k_rot_f.x = -k_pair_f.x; k_rot_f.y = -k_pair_f.y; - } else { - q_rot_f = q_pair_f; - k_rot_f = k_pair_f; - } - - float2 q_out_f; - q_out_f.x = q_f.x * cos_f.x + q_rot_f.x * sin_f.x; - q_out_f.y = q_f.y * cos_f.y + q_rot_f.y * sin_f.y; - - float2 k_out_f; - k_out_f.x = k_f.x * cos_f.x + k_rot_f.x * sin_f.x; - k_out_f.y = k_f.y * cos_f.y + k_rot_f.y * sin_f.y; - - bf16_2 q_out = __float22bfloat162_rn(q_out_f); - bf16_2 k_out = __float22bfloat162_rn(k_out_f); - - // Target global offset for gather-reconstruction - int s_out = s + G.dev_idx * G.S_local; - int S_global = G.S_local * globals::NUM_DEVICES; - int out_idx = b * (S_global * G.H * G.D) + s_out * (G.H * G.D) + h * G.D + d; - - // Broadcast computed slice directly to target locations matching sequence shards - kittens::multimem::st(reinterpret_cast(&G.q_out.mc_ptr[out_idx]), q_out); - kittens::multimem::st(reinterpret_cast(&G.k_out.mc_ptr[out_idx]), k_out); -} - -} // namespace rope_all_gather - -namespace rope_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace rope_barrier - -void entrypoint( - kittens::py::TKParallelTensor &q_out, - kittens::py::TKParallelTensor &k_out, - uintptr_t q_local_ptr, - uintptr_t k_local_ptr, - uintptr_t cos_local_ptr, - uintptr_t sin_local_ptr, - kittens::py::TKParallelTensor &barrier, - int B, int S_local, int H, int D -) { - kittens::py::parallel_tensor_check(q_out, k_out, barrier); - - rope_all_gather::globals rope_G { - .q_out = kittens::py::parallel_tensor_to_pgl(q_out), - .k_out = kittens::py::parallel_tensor_to_pgl(k_out), - .q_local = reinterpret_cast(q_local_ptr), - .k_local = reinterpret_cast(k_local_ptr), - .cos_local = reinterpret_cast(cos_local_ptr), - .sin_local = reinterpret_cast(sin_local_ptr), - .B = B, - .S_local = S_local, - .H = H, - .D = D, - .dev_idx = q_out.local_rank_ - }; - - rope_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Synchronize cluster setup -> run async computation & multicast out -> synchronize visibility - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(rope_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_rope_all_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_rope_allgather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - if not dist.is_initialized(): - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - q_local: torch.Tensor, - k_local: torch.Tensor, - cos_local: torch.Tensor, - sin_local: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - - B, S_local, H, D = q_local.shape - - # Kernel requires bfloat16 alignment matching 2 floats layout assumption (D % 4 == 0) and exactly NUM_DEVICES GPUs - if not dist.is_initialized() or dist.get_world_size() != NUM_DEVICES or D % 4 != 0: - cos = cos_local.unsqueeze(2) - sin = sin_local.unsqueeze(2) - - half_dim = D // 2 - q_x1, q_x2 = q_local[..., :half_dim], q_local[..., half_dim:] - q_rot = torch.cat((-q_x2, q_x1), dim=-1) - k_x1, k_x2 = k_local[..., :half_dim], k_local[..., half_dim:] - k_rot = torch.cat((-k_x2, k_x1), dim=-1) - - q_embed_local = (q_local * cos) + (q_rot * sin) - k_embed_local = (k_local * cos) + (k_rot * sin) - - if not dist.is_initialized(): - return q_embed_local, k_embed_local - - world_size = dist.get_world_size() - q_gather_list = [torch.empty_like(q_embed_local) for _ in range(world_size)] - k_gather_list = [torch.empty_like(k_embed_local) for _ in range(world_size)] - - dist.all_gather(q_gather_list, q_embed_local.contiguous()) - dist.all_gather(k_gather_list, k_embed_local.contiguous()) - - return torch.cat(q_gather_list, dim=1), torch.cat(k_gather_list, dim=1) - - world = dist.get_world_size() - S_global = S_local * world - n_out = B * S_global * H * D - - # Pad parallel dimensions for broker stability requirement - ALIGNMENT = 4096 - padded_out = ((n_out + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - q_local_c = q_local.to(torch.bfloat16).contiguous() - k_local_c = k_local.to(torch.bfloat16).contiguous() - cos_local_c = cos_local.to(torch.bfloat16).contiguous() - sin_local_c = sin_local.to(torch.bfloat16).contiguous() - - ext = _ensure_ext_jit() - - q_out_tk = get_or_create_parallel_tensor(ext, (padded_out,), torch.bfloat16, multicast=True) - k_out_tk = get_or_create_parallel_tensor(ext, (padded_out,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - ext.tk_rope_all_gather( - q_out_tk, - k_out_tk, - q_local_c.data_ptr(), - k_local_c.data_ptr(), - cos_local_c.data_ptr(), - sin_local_c.data_ptr(), - barrier_tk, - B, S_local, H, D - ) - - q_global = q_out_tk.data_[:n_out].view(B, S_global, H, D).clone() - k_global = k_out_tk.data_[:n_out].view(B, S_global, H, D).clone() - - # Clean fallback format logic - orig_dtype = q_local.dtype - if orig_dtype != torch.bfloat16: - q_global = q_global.to(orig_dtype) - k_global = k_global.to(orig_dtype) - - return q_global, k_global \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/18_tp_rms_norm_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/18_tp_rms_norm_parallelkittens.py deleted file mode 100755 index f29b781..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/18_tp_rms_norm_parallelkittens.py +++ /dev/null @@ -1,294 +0,0 @@ -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for fused TP RMSNorm with ThunderKittens all-reduce -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include -#include - -using namespace kittens; - -namespace tk_rms { - -// --------------------------------------------------------- -// 1. Local Sum of Squares Kernel -// --------------------------------------------------------- -__global__ void local_sum_kernel( - const __nv_bfloat16* __restrict__ x, - float* __restrict__ local_sums, - int D, int N) -{ - int row = blockIdx.x * blockDim.y + threadIdx.y; - if (row >= N) return; - - float sum = 0.0f; - for (int i = threadIdx.x; i < D; i += blockDim.x) { - float val = __bfloat162float(x[row * D + i]); - sum += val * val; - } - - // Warp-level reduction - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - - if (threadIdx.x == 0) { - local_sums[row] = sum; - } -} - -// --------------------------------------------------------- -// 2. Apply RMSNorm Kernel -// --------------------------------------------------------- -__global__ void apply_rmsnorm_kernel( - const __nv_bfloat16* __restrict__ x, - const __nv_bfloat16* __restrict__ weight, - __nv_bfloat16* __restrict__ y, - const float* __restrict__ global_sums, - float epsilon, - int D, int N, int global_D) -{ - int row = blockIdx.x * blockDim.y + threadIdx.y; - if (row >= N) return; - - float global_sum = global_sums[row]; - float variance = global_sum / global_D; - float rsqrt_var = rsqrtf(variance + epsilon); - - for (int i = threadIdx.x; i < D; i += blockDim.x) { - float val = __bfloat162float(x[row * D + i]); - float w = __bfloat162float(weight[i]); - float out = val * rsqrt_var * w; - y[row * D + i] = __float2bfloat16(out); - } -} - -// --------------------------------------------------------- -// 3. ThunderKittens Multimem All-Reduce and Barrier -// --------------------------------------------------------- -struct config_all_reduce { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 256; -}; - -struct globals_all_reduce { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 1; - static constexpr int NUM_ELEMS_PER_BLOCK = config_all_reduce::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel_all_reduce(const globals_all_reduce &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals_all_reduce::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals_all_reduce::NUM_ELEMS_PER_BLOCK * blockIdx.x + - threadIdx.x; - - float tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -struct config_barrier { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals_barrier { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel_barrier(const globals_barrier &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace tk_rms - -// --------------------------------------------------------- -// 4. Host Entrypoint -// --------------------------------------------------------- -void tk_fused_rms_norm( - torch::Tensor x, - torch::Tensor weight, - torch::Tensor y, - kittens::py::TKParallelTensor &sums, - kittens::py::TKParallelTensor &barrier, - float epsilon, - int global_D) -{ - int N = x.numel() / x.size(-1); - int D = x.size(-1); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - dim3 block(32, 8); // 8 warps, each processing a row - dim3 grid((N + 7) / 8); - - // Step 1: Compute local sums over the hidden dim - tk_rms::local_sum_kernel<<>>( - reinterpret_cast(x.data_ptr()), - reinterpret_cast(sums.data_.data_ptr()), - D, N - ); - - // Setup ThunderKittens globals - tk_rms::globals_barrier barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - tk_rms::globals_all_reduce all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(sums), - .dev_idx = sums.local_rank_ - }; - - // Step 2 & 3 & 4: In-place hardware multimem reduction wrapped with device-side barriers - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); - - // Step 5: Normalize and apply weight - tk_rms::apply_rmsnorm_kernel<<>>( - reinterpret_cast(x.data_ptr()), - reinterpret_cast(weight.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(y.data_ptr()), - reinterpret_cast(sums.data_.data_ptr()), - epsilon, D, N, global_D - ); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_fused_rms_norm", &tk_fused_rms_norm); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -# Layout constants ensuring correct padding for TK kernel -NUM_DEVICES = 8 -NUM_ELEMS_PER_BLOCK = 256 -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 2048 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_rmsnorm_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - local_hidden_states: torch.Tensor, - local_weight: torch.Tensor, - variance_epsilon: float -) -> torch.Tensor: - - assert local_hidden_states.is_cuda and local_hidden_states.is_contiguous() - assert local_weight.is_cuda and local_weight.is_contiguous() - assert local_hidden_states.dtype == torch.bfloat16, "Kernel optimized for BF16 hidden states" - assert local_weight.dtype == torch.bfloat16, "Kernel optimized for BF16 weight" - - world = dist.get_world_size() - assert world == NUM_DEVICES, f"This ThunderKittens kernel targets {NUM_DEVICES} devices, got {world}" - - ext = _ensure_ext_jit() - - original_shape = local_hidden_states.shape - D = original_shape[-1] - N = local_hidden_states.numel() // D - - # Pad symmetric buffer size so TK can safely divide workload amongst all SMs - padded_N = ((N + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - x_flat = local_hidden_states.view(N, D) - out_flat = torch.empty_like(x_flat) - - # Cached ThunderKittens arrays (VMM + NVSwitch multicast context bindings) - sums_tk = get_or_create_parallel_tensor(ext, (padded_N,), torch.float32, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Initialize symmetric padding to 0 - sums_tk.data_.zero_() - - # Launch entire pipeline as one unified stream of C++ kernels - ext.tk_fused_rms_norm( - x_flat, - local_weight, - out_flat, - sums_tk, - barrier_tk, - float(variance_epsilon), - D * world - ) - - return out_flat.view(original_shape) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/19_blocked_fp8_quantize_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/19_blocked_fp8_quantize_parallelkittens.py deleted file mode 100755 index e1d2b59..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/19_blocked_fp8_quantize_parallelkittens.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -Strategy: -1. Fuse the local block FP8 quantization with the communication step. -2. Utilize ThunderKittens' `multimem::st` (NVSwitch multicast) to perform an O(1) broadcast of the locally - quantized blocks and scales directly into the global all-gathered tensors of all peers. -3. By treating the all-gathered target buffer as a symmetric parallel tensor and computing appropriate - rank-based offsets, each GPU natively "all-gathers" simply by storing its own slice to the multicast - address, completely hiding communication behind the quantization math. -4. Use `__shfl_down_sync` warp reductions to quickly compute block scales and `__nv_fp8_e4m3` intrinsics - for single-pass natively saturated conversion. -""" - -import os -import torch -import torch.distributed as dist -import triton -import triton.language as tl -from typing import Tuple - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for fused TK quantization and multicast broadcast -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include -#include - -using namespace kittens; - -namespace quantize { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - const __nv_bfloat16* input; - - using layout = pgl, NUM_DEVICES, true>; - layout y_global; - layout s_global; - - int64_t N; - int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((N + config::NUM_WARPS - 1) / config::NUM_WARPS); - } -}; - -__device__ inline void kernel(const globals &G) { - int64_t block_idx = (int64_t)blockIdx.x * config::NUM_WARPS + threadIdx.x / WARP_THREADS; - if (block_idx >= G.N) return; - - int warp_tid = threadIdx.x % WARP_THREADS; - int64_t local_offset = block_idx * 128 + warp_tid * 4; - - const uint16_t* p_u16 = reinterpret_cast(&G.input[local_offset]); - float x[4]; - - #pragma unroll - for(int i=0; i<4; ++i) { - uint32_t f_u32 = ((uint32_t)p_u16[i]) << 16; - x[i] = *reinterpret_cast(&f_u32); - } - - float local_max = 0.0f; - #pragma unroll - for(int i=0; i<4; ++i) { - local_max = fmaxf(local_max, fabsf(x[i])); - } - - unsigned int mask = 0xffffffff; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - local_max = fmaxf(local_max, __shfl_down_sync(mask, local_max, offset)); - } - - float block_max = __shfl_sync(mask, local_max, 0); - float s = block_max / 448.0f; - float s_safe = (s == 0.0f) ? 1.0f : s; - - if (warp_tid == 0) { - int64_t global_s_idx = (int64_t)G.dev_idx * G.N + block_idx; - bf16_2 s_val = *reinterpret_cast(&s_safe); - // Multiply by 2 because mc_ptr points to bf16 (2 bytes) and we write 4 bytes - kittens::multimem::st(reinterpret_cast(&G.s_global.mc_ptr[global_s_idx * 2]), s_val); - } - - uint32_t y_u32 = 0; - uint8_t* p_y = reinterpret_cast(&y_u32); - #pragma unroll - for(int i=0; i<4; ++i) { - float x_scaled = x[i] / s_safe; - __nv_fp8_e4m3 y_fp8(x_scaled); - p_y[i] = *reinterpret_cast(&y_fp8); - } - - int64_t global_y_u32_idx = ((int64_t)G.dev_idx * G.N * 128) / 4 + block_idx * 32 + warp_tid; - bf16_2 y_val = *reinterpret_cast(&y_u32); - // Multiply by 2 because mc_ptr points to bf16 (2 bytes) and we write 4 bytes - kittens::multimem::st(reinterpret_cast(&G.y_global.mc_ptr[global_y_u32_idx * 2]), y_val); -} - -} // namespace quantize - -namespace quantize_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace quantize_barrier - -void entrypoint( - torch::Tensor local_input, - kittens::py::TKParallelTensor &y_tk, - kittens::py::TKParallelTensor &s_tk, - kittens::py::TKParallelTensor &barrier -) { - TORCH_CHECK(local_input.is_contiguous(), "Input must be contiguous"); - TORCH_CHECK(local_input.numel() % 128 == 0, "Elements must be multiple of 128"); - - kittens::py::parallel_tensor_check(y_tk, s_tk, barrier); - - int64_t N = local_input.numel() / 128; - - quantize::globals G { - .input = reinterpret_cast(local_input.data_ptr()), - .y_global = kittens::py::parallel_tensor_to_pgl(y_tk), - .s_global = kittens::py::parallel_tensor_to_pgl(s_tk), - .N = N, - .dev_idx = y_tk.local_rank_ - }; - - quantize_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_quantize", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_quantize_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - -# Fallback Triton kernel for single-GPU / missing distributed scenarios -@triton.jit -def block_fp8_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448.0 - s_safe = tl.where(s == 0.0, 1.0, s) - y = (x / s_safe).to(y_ptr.dtype.element_ty) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + pid, s) - -@torch.no_grad() -def solution(local_tensor: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert local_tensor.is_contiguous(), "Input tensor must be contiguous" - assert local_tensor.size(-1) % block_size == 0, "Last dimension must be divisible by block_size" - assert block_size == 128, "This optimized kernel requires block_size=128" - - world_size = dist.get_world_size() if dist.is_initialized() else 1 - - if world_size == 1: - # Fallback for local testing or single GPU runs - y_local = torch.empty_like(local_tensor, dtype=torch.float8_e4m3fn) - s_local = local_tensor.new_empty( - *local_tensor.size()[:-1], local_tensor.size(-1) // block_size, dtype=torch.float32 - ) - grid = (triton.cdiv(local_tensor.numel(), block_size),) - block_fp8_quant_kernel[grid](local_tensor, y_local, s_local, BLOCK_SIZE=block_size) - return y_local, s_local - - assert world_size == 8, "ThunderKittens kernel built for 8 GPUs" - - ext = _ensure_ext_jit() - - original_shape = local_tensor.shape - local_tensor_bf16 = local_tensor.to(torch.bfloat16) - - L = local_tensor_bf16.numel() - N = L // 128 - - # y_tk needs W * L bytes of output (float8) - # We allocate it as bfloat16, meaning we allocate (W * L // 2) bf16 elements. - y_tk = get_or_create_parallel_tensor(ext, (world_size * L // 2,), torch.bfloat16, multicast=True) - - # s_tk needs W * N floats of output (4 bytes each). - # We allocate it as bfloat16, meaning we allocate (W * N * 2) bf16 elements. - s_tk = get_or_create_parallel_tensor(ext, (world_size * N * 2,), torch.bfloat16, multicast=True) - - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # Run the fused quantization + multicast broadcast - ext.tk_quantize(local_tensor_bf16, y_tk, s_tk, barrier_tk) - - # Shape of reference output equivalent to `torch.cat(gather_list, dim=0)` - target_shape = list(original_shape) - if len(target_shape) > 0: - target_shape[0] *= world_size - else: - target_shape = [world_size] - - s_target_shape = list(original_shape) - s_target_shape[-1] = s_target_shape[-1] // 128 - if len(s_target_shape) > 0: - s_target_shape[0] *= world_size - else: - s_target_shape = [world_size] - - # Cast symmetrical buffers back into desired view shapes - y_global = y_tk.data_.view(-1)[:world_size * L // 2].view(torch.uint8).view(torch.float8_e4m3fn) - y_global = y_global.reshape(target_shape) - - s_global = s_tk.data_.view(-1)[:world_size * N * 2].view(torch.float32) - s_global = s_global.reshape(s_target_shape) - - return y_global, s_global \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/1_allreduce_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/1_allreduce_parallelkittens.py deleted file mode 100755 index 85bd0d2..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/1_allreduce_parallelkittens.py +++ /dev/null @@ -1,211 +0,0 @@ -""" -ThunderKittens All-Reduce (SUM) using PGL with NVSwitch multicast. - -Uses TKParallelTensor for VMM allocation, IPC handle exchange, and -multicast setup. The kernel performs in-switch reduction via -multimem.ld_reduce and broadcasts the result via multimem.st. - -Requires: ThunderKittens headers at $THUNDERKITTENS_ROOT/include. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_reduce", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 # NUM_WARPGROUPS(2) * WARPGROUP_WARPS(4) * WARP_THREADS(32) -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 4096 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allreduce_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(tensor: torch.Tensor) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - assert world == NUM_DEVICES, f"This kernel is optimized for {NUM_DEVICES} devices" - - ext = _ensure_ext_jit() - - original_shape = tensor.shape - original_dtype = tensor.dtype - - flat = tensor.to(torch.bfloat16).reshape(-1).contiguous() - n = flat.numel() - - # Pad to kernel alignment (NUM_DEVICES * NUM_ELEMS_PER_BLOCK) - padded = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - # Cached TKParallelTensor (VMM + multicast) — steady-state like NCCL benchmarks. - tensor_tk = get_or_create_parallel_tensor(ext, (padded,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Copy input into the VMM-allocated parallel tensor - tensor_tk.data_[:n] = flat - if padded > n: - tensor_tk.data_[n:].zero_() - - # Run the TK all-reduce (barrier → reduce → barrier) - ext.tk_all_reduce(tensor_tk, barrier_tk) - - result = tensor_tk.data_[:n].clone() - return result.to(original_dtype).reshape(original_shape) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/20_blocked_fp8_dequantize_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/20_blocked_fp8_dequantize_parallelkittens.py deleted file mode 100755 index 34e165b..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/20_blocked_fp8_dequantize_parallelkittens.py +++ /dev/null @@ -1,378 +0,0 @@ -""" -Strategy: -We bypass the standard PyTorch NCCL `all_to_all_single` and intermediate allocations by integrating custom Triton kernels with ThunderKittens' native TMA all-to-all. -1. **Fused Dequantization & Padding:** A Triton kernel (`block_fp8_dequant_pad_kernel`) performs FP8-to-BF16 block dequantization while directly scattering the output into the 16x128 tiled memory layout required by ThunderKittens. This overlaps the scale arithmetic with the padding layout transformations. -2. **Device-Side TMA Exchange:** We utilize a custom ThunderKittens kernel (`tk_all_to_all`) to perform personalized all-to-all communication. This leverages Hopper's Tensor Memory Accelerator (TMA) to move blocks asynchronously over NVLink without host-driven NCCL overhead. -3. **Fused Unpadding & Cast:** A second Triton kernel slices out the valid payload from the 16x128 TMA tiles and casts the received BF16 data to FP32 directly into the final contiguous output tensor, merging memory movement and data formatting on-device. -""" - -import os -import torch -import torch.distributed as dist -import triton -import triton.language as tl -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_to_all entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - """Return (R, C, padded_rest) with R=16, C multiple of 128, R*C >= rest_elems.""" - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - -# --------------------------------------------------------------------------- -# Fused Triton Kernels -# --------------------------------------------------------------------------- - -@triton.jit -def block_fp8_dequant_pad_kernel( - y_ptr, s_ptr, out_ptr, - num_elements, chunk_size, padded_chunk_size, - scale_block_size, - BLOCK_SIZE: tl.constexpr -): - pid = tl.program_id(axis=0) - offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < num_elements - - chunk_idx = offs // chunk_size - idx_in_chunk = offs % chunk_size - - scale_idx = offs // scale_block_size - s = tl.load(s_ptr + scale_idx, mask=mask) - y = tl.load(y_ptr + offs, mask=mask).to(tl.float32) - - val = y * s - out_idx = chunk_idx * padded_chunk_size + idx_in_chunk - tl.store(out_ptr + out_idx, val.to(tl.bfloat16), mask=mask) - -@triton.jit -def unpad_cast_fp32_kernel( - in_ptr, out_ptr, - num_elements, chunk_size, padded_chunk_size, - BLOCK_SIZE: tl.constexpr -): - pid = tl.program_id(axis=0) - offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < num_elements - - chunk_idx = offs // chunk_size - idx_in_chunk = offs % chunk_size - - in_idx = chunk_idx * padded_chunk_size + idx_in_chunk - val = tl.load(in_ptr + in_idx, mask=mask).to(tl.float32) - tl.store(out_ptr + offs, val, mask=mask) - -# --------------------------------------------------------------------------- -# Main Implementation -# --------------------------------------------------------------------------- - -@torch.no_grad() -def solution( - local_y: torch.Tensor, - local_s: torch.Tensor, - block_size: int = 128, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - - world_size = dist.get_world_size() - assert local_y.dim() >= 1 and local_y.shape[0] == world_size, \ - f"local_y first dimension must equal world_size ({world_size}), got {local_y.shape[0]}" - assert world_size == NUM_DEVICES, \ - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; got {world_size}" - - assert local_y.is_contiguous(), "Input tensor local_y must be contiguous" - assert local_s.is_contiguous(), "Scale tensor local_s must be contiguous" - - chunk_shape = local_y.shape[1:] - chunk_size = local_y.numel() // world_size - num_elements = local_y.numel() - - if num_elements == 0: - return torch.empty(world_size, *chunk_shape, device=local_y.device, dtype=torch.float32) - - assert chunk_size % block_size == 0, \ - f"Chunk size {chunk_size} must be divisible by block_size ({block_size})" - - ext = _ensure_ext_jit() - - r, c, padded_chunk_size = _padded_row_col(chunk_size) - - # Pre-allocate TK tensors (using caching under the hood) - input_tk = get_or_create_parallel_tensor( - ext, (world_size, 1, r, c), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, world_size, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - in_tk_flat = input_tk.data_.view(-1) - out_tk_flat = output_tk.data_.view(-1) - - # Zero buffer to prevent transferring garbage within unused padding regions - in_tk_flat.zero_() - - triton_block_size = 1024 - grid = (triton.cdiv(num_elements, triton_block_size),) - - # 1. Fused dequantization + pad to ThunderKittens input layout - block_fp8_dequant_pad_kernel[grid]( - local_y.view(-1), local_s.view(-1), in_tk_flat, - num_elements, chunk_size, padded_chunk_size, - block_size, - BLOCK_SIZE=triton_block_size - ) - - # 2. ThunderKittens device-side TMA Exchange - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - # 3. Fused unpad + FP32 cast directly to final destination - out = torch.empty(world_size, *chunk_shape, device=local_y.device, dtype=torch.float32) - unpad_cast_fp32_kernel[grid]( - out_tk_flat, out.view(-1), - num_elements, chunk_size, padded_chunk_size, - BLOCK_SIZE=triton_block_size - ) - - return out \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/21_clip_grad_norm_no_ep_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/21_clip_grad_norm_no_ep_parallelkittens.py deleted file mode 100755 index dba57a2..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/21_clip_grad_norm_no_ep_parallelkittens.py +++ /dev/null @@ -1,309 +0,0 @@ -""" -Standalone L2 clip_grad_norm: FSDP2 path WITHOUT EP. -Optimized with ThunderKittens PGL multicast and custom fused CUDA reductions. -""" - -import math -import os -from typing import List, Optional - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for fused computation and TK all-reduce -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include -#include -#include - -using namespace kittens; - -// ============================================================================ -// 1. Local Norm Squared Reduction (Fused across tensors) -// ============================================================================ - -__global__ void local_norm_sq_kernel(const __nv_bfloat16* g, int numel, float* acc) { - // Accumulate locally into double to ensure complete numerical stability - double thread_sum = 0; - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - for (int i = idx; i < numel; i += blockDim.x * gridDim.x) { - float v = __bfloat162float(g[i]); - thread_sum += (double)v * (double)v; - } - - float sum = static_cast(thread_sum); - - // Warp reduction - for (int offset = 16; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - - // Block reduction - extern __shared__ float shared[]; - int lane = threadIdx.x % 32; - int wid = threadIdx.x / 32; - if (lane == 0) shared[wid] = sum; - __syncthreads(); - - if (wid == 0) { - float val = (lane < (blockDim.x / 32)) ? shared[lane] : 0.0f; - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (lane == 0) { - atomicAdd(acc, val); - } - } -} - -void compute_local_norm_sq(std::vector tensors, at::Tensor acc) { - float* acc_ptr = acc.data_ptr(); - cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream(); - - // Zero out the entire buffer (including trailing TK alignment paddings) - cudaMemsetAsync(acc_ptr, 0, acc.numel() * sizeof(float), stream); - - for (const auto& t : tensors) { - if (!t.defined() || t.numel() == 0) continue; - int numel = t.numel(); - int threads = 256; - int blocks = std::min((numel + threads - 1) / threads, 1024); - int shared_mem = (threads / 32) * sizeof(float); - - local_norm_sq_kernel<<>>( - reinterpret_cast(t.data_ptr()), numel, acc_ptr - ); - } -} - -// ============================================================================ -// 2. ThunderKittens Hopper Float All-Reduce (Multicast) -// ============================================================================ - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 1; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - float tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void tk_all_reduce( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "Total tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -// ============================================================================ -// 3. Batched Scale Pass -// ============================================================================ - -__global__ void scale_tensors_kernel(__nv_bfloat16* g, int numel, float coef) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = idx; i < numel; i += blockDim.x * gridDim.x) { - float v = __bfloat162float(g[i]); - g[i] = __float2bfloat16(v * coef); - } -} - -void scale_tensors(std::vector tensors, float coef) { - cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream(); - for (const auto& t : tensors) { - if (!t.defined() || t.numel() == 0) continue; - int numel = t.numel(); - int threads = 256; - int blocks = std::min((numel + threads - 1) / threads, 1024); - scale_tensors_kernel<<>>( - reinterpret_cast<__nv_bfloat16*>(t.data_ptr()), numel, coef - ); - } -} - -// ============================================================================ -// PyBind Initialization -// ============================================================================ - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("compute_local_norm_sq", &compute_local_norm_sq); - m.def("tk_all_reduce", &tk_all_reduce); - m.def("scale_tensors", &scale_tensors); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -# Hardcoded node size corresponding to standard H100 cluster limits -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_clip_grad_norm_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call barriers in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - grad_tensors: List[torch.Tensor], - max_norm: float, - norm_type: float = 2.0, - fsdp_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - assert float(norm_type) == 2.0, "Only L2 norm is supported by this optimized path" - - world = dist.get_world_size(group=fsdp_group) if fsdp_group is not None else dist.get_world_size() - assert world == NUM_DEVICES, f"This ThunderKittens kernel is compiled for NUM_DEVICES={NUM_DEVICES}; got world_size={world}" - - ext = _ensure_ext_jit() - - # TK requires the local buffer space to neatly split into thread blocks (ALIGNMENT elements) - ALIGNMENT = world * 256 - tensor_tk = get_or_create_parallel_tensor(ext, (ALIGNMENT,), torch.float32, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - valid_tensors = [t for t in grad_tensors if t is not None] - - # 1. Pipeline local fused accumulation directly into symmetric tensor slot [0] - ext.compute_local_norm_sq(valid_tensors, tensor_tk.data_) - - # 2. ThunderKittens NVSwitch barrier + load/reduce multicast + barrier - ext.tk_all_reduce(tensor_tk, barrier_tk) - - # 3. Read accumulated scalar norms and globally trigger scalings - total_norm_sq = tensor_tk.data_[0].clone() - total_norm = total_norm_sq.sqrt() - - if total_norm > max_norm: - coef = max_norm / total_norm.item() - ext.scale_tensors(valid_tensors, coef) - - return total_norm \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/22_clip_grad_norm_ep_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/22_clip_grad_norm_ep_parallelkittens.py deleted file mode 100755 index b8ca01a..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/22_clip_grad_norm_ep_parallelkittens.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -Strategy: -- **Math Fusion & Compute-Communication Overlap**: Combines the previously independent `fsdp_group` and `ep_group` gradient reductions into a **single ThunderKittens global all-reduce**. We leverage FSDP replica symmetry to mathematically map sub-group sums to a full 8-GPU node sum, fully eliminating the multiple serialized sub-group collectives. -- **Device-Side Symmetric Reductions**: Replaces opaque NCCL host launches with `TKParallelTensor` NVSwitch multicast and `multimem.ld_reduce` for peer-to-peer 8-way reduction of the local norm scalars directly in GPU memory. -- **MultiTensorApply Fast Paths**: Strips out Python loops for local $L^2$ norm accumulations, fusing all gradient reads (both EP and non-EP simultaneously) into a single highly optimized `torch._foreach_norm` dispatch to maximize memory bandwidth. -""" - -import math -import os -from typing import List, Optional - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded ThunderKittens All-Reduce Kernel for fast symmetric reduction -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_reduce", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 4096 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allreduce_ext_clipgrad", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - non_ep_grad_tensors: List[torch.Tensor], - ep_grad_tensors: List[torch.Tensor], - max_norm: float, - norm_type: float = 2.0, - ep_size: int = 1, - fsdp_group: Optional[dist.ProcessGroup] = None, - ep_fsdp_group: Optional[dist.ProcessGroup] = None, - ep_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - world = dist.get_world_size() - assert world == NUM_DEVICES, f"ThunderKittens all-reduce kernel built for NUM_DEVICES={NUM_DEVICES}" - - ext = _ensure_ext_jit() - - # 1. EP gradients scaling via fast inplace MultiTensorApply - valid_ep = [g for g in ep_grad_tensors if g is not None] - if ep_size > 1 and valid_ep: - scale = 1.0 / float(ep_size) - torch._foreach_mul_(valid_ep, scale) - - # 2. Local L2 squared norm calculation - p = float(norm_type) - valid_non_ep = [g for g in non_ep_grad_tensors if g is not None] - all_valid = valid_non_ep + valid_ep - - if all_valid: - # Fuse ALL gradient norm calculations into a single _foreach read operation - fp32_all = [g.detach().to(torch.float32, copy=False) for g in all_valid] - norms_all = torch._foreach_norm(fp32_all, p) - - num_non_ep = len(valid_non_ep) - norms_non_ep = norms_all[:num_non_ep] - norms_ep = norms_all[num_non_ep:] - - if norms_non_ep: - non_ep_local = torch.sum(torch.stack(norms_non_ep) ** p) - else: - non_ep_local = torch.tensor(0.0, dtype=torch.float32, device=torch.cuda.current_device()) - - if norms_ep: - ep_local = torch.sum(torch.stack(norms_ep) ** p) - else: - ep_local = torch.tensor(0.0, dtype=torch.float32, device=torch.cuda.current_device()) - else: - non_ep_local = torch.tensor(0.0, dtype=torch.float32, device=torch.cuda.current_device()) - ep_local = torch.tensor(0.0, dtype=torch.float32, device=torch.cuda.current_device()) - - # Pack into a 2-element contiguous block for the TK all-reduce - local_sums = torch.stack([non_ep_local, ep_local]).to(torch.bfloat16) - - # 3. Fast device-side ThunderKittens 8-way combined All-Reduce - tensor_tk = get_or_create_parallel_tensor(ext, (ALIGNMENT,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - tensor_tk.data_[:2] = local_sums - if ALIGNMENT > 2: - tensor_tk.data_[2:ALIGNMENT].zero_() - - ext.tk_all_reduce(tensor_tk, barrier_tk) - - reduced_sums = tensor_tk.data_[:2].to(torch.float32) - - # 4. Math Translation: Model replications allow mapping sub-group sums out of the 8-way sum - fsdp_sz = dist.get_world_size(fsdp_group) if fsdp_group is not None else 1 - non_ep_div = float(world) / float(fsdp_sz) - - ep_fsdp_sz = dist.get_world_size(ep_fsdp_group) if ep_fsdp_group is not None else 1 - ep_sz_group = dist.get_world_size(ep_group) if ep_group is not None else 1 - ep_div = float(world) / float(ep_fsdp_sz * ep_sz_group) - - global_non_ep = reduced_sums[0] / non_ep_div - global_ep = reduced_sums[1] / ep_div - - total_norm = (global_non_ep + global_ep) ** (1.0 / p) - - # 5. Conditional clipping - if total_norm > max_norm: - coef_val = float(max_norm / total_norm) - if valid_non_ep: - torch._foreach_mul_(valid_non_ep, coef_val) - if valid_ep: - torch._foreach_mul_(valid_ep, coef_val) - - return total_norm \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/23_grad_acc_loss_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/23_grad_acc_loss_parallelkittens.py deleted file mode 100755 index 5596465..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/23_grad_acc_loss_parallelkittens.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -ThunderKittens Grad Acc Loss Integration - -Replaces the CPU-syncing `.item()` check with a device-side mask and utilizes -a high-speed ThunderKittens PGL multicast all-reduce for the scalar aggregation. -""" - -import os -import torch -import torch.distributed as dist -from typing import Tuple, Optional - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source: ThunderKittens All-Reduce (SUM) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_reduce", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 # NUM_WARPGROUPS(2) * WARPGROUP_WARPS(4) * WARP_THREADS(32) -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 4096 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allreduce_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _tk_all_reduce_sum(tensor: torch.Tensor, ext) -> torch.Tensor: - """Wrapper to map PyTorch arbitrary tensors through the fixed TK buffer alignment.""" - world = dist.get_world_size() - original_shape = tensor.shape - original_dtype = tensor.dtype - - flat = tensor.to(torch.bfloat16).reshape(-1).contiguous() - n = flat.numel() - - # Pad to kernel alignment (NUM_DEVICES * NUM_ELEMS_PER_BLOCK) - padded = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - # Cached TKParallelTensor (VMM + multicast) - tensor_tk = get_or_create_parallel_tensor(ext, (padded,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Copy input into the VMM-allocated parallel tensor - tensor_tk.data_[:n] = flat - if padded > n: - tensor_tk.data_[n:].zero_() - - # Run the TK all-reduce (barrier → reduce → barrier) - ext.tk_all_reduce(tensor_tk, barrier_tk) - - result = tensor_tk.data_[:n].clone() - return result.to(original_dtype).reshape(original_shape) - - -@torch.no_grad() -def solution( - loss: torch.Tensor, - local_valid_tokens: torch.Tensor, - global_valid_tokens: torch.Tensor, - grad_normalized_loss: torch.Tensor, - grad_loss_sum: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Evaluates explicitly configured forward/backward metrics utilizing - device-side masks and TK multicaset kernels for minimum synchronization. - """ - ext = _ensure_ext_jit() - - # --------------------------------------------------------- - # Forward Pass - # --------------------------------------------------------- - - # Eliminate host-device synchronization by replacing .item() with device logic - loss_clean = torch.nan_to_num(loss) - loss_val = torch.where(local_valid_tokens == 0, loss_clean, loss) - - loss_sum = loss_val * local_valid_tokens - - # TK All-Reduce Sum Collective replacing NCCL dist.all_reduce - loss_sum = _tk_all_reduce_sum(loss_sum, ext) - - normalized_loss = loss_sum / global_valid_tokens - - # --------------------------------------------------------- - # Backward Pass - # --------------------------------------------------------- - grad_from_normalized = grad_normalized_loss * local_valid_tokens / global_valid_tokens - - if grad_loss_sum is not None: - grad_from_sum = grad_loss_sum * local_valid_tokens - else: - grad_from_sum = torch.zeros_like(grad_normalized_loss) - - grad_loss = grad_from_normalized + grad_from_sum - - return normalized_loss, loss_sum, grad_loss \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/24_load_balancing_loss_fn_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/24_load_balancing_loss_fn_parallelkittens.py deleted file mode 100755 index 37a4165..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/24_load_balancing_loss_fn_parallelkittens.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -Strategy: -1. Optimize the local device-side computation by replacing memory-heavy standard ops - (like `one_hot` followed by `mean` and `unsqueeze`, which allocate huge intermediates - of shape [N, top_k, num_experts]) with fast `bincount` and matrix-vector multiplications - (`routing_weights.T @ am_repeated`). This eliminates massive memory spikes and bandwidth - bottlenecks on the hot path before communication even begins. -2. Replace opaque `torch.distributed.all_reduce` with a custom ThunderKittens PGL kernel - using Hopper's multimem and NVSwitch multicast. The scalar loss is mapped directly - into symmetric memory and all-reduced via `multimem.ld_reduce`, keeping - synchronization entirely device-side and minimizing host overhead for the reduction. -""" - -import os -import torch -import torch.distributed as dist -from typing import Union, Tuple, Optional -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (ThunderKittens All-Reduce via Multicast) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_reduce", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allreduce_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...]], - num_experts: int, - top_k: int = 2, - attention_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - - if isinstance(gate_logits, (tuple, list)): - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat( - [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 - ) - else: - compute_device = gate_logits.device - concatenated_gate_logits = gate_logits - - # Fast local soft-max and top-k router selection - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) - _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - - N = concatenated_gate_logits.shape[0] - - # Fast local binning & token probabilities, bypassing huge `one_hot` mask allocations - if attention_mask is None: - tokens_per_expert = torch.zeros((top_k, num_experts), dtype=torch.float32, device=compute_device) - for k in range(top_k): - counts = torch.bincount(selected_experts[:, k], minlength=num_experts) - tokens_per_expert[k] = counts.float() - tokens_per_expert /= N - - router_prob_per_expert = routing_weights.mean(dim=0) - else: - am_1d = attention_mask.reshape(-1).to(dtype=torch.float32, device=compute_device) - num_hidden_layers = N // am_1d.shape[0] - am_repeated = am_1d.repeat(num_hidden_layers) - - sum_am = am_repeated.sum() - - tokens_per_expert = torch.zeros((top_k, num_experts), dtype=torch.float32, device=compute_device) - for k in range(top_k): - counts = torch.bincount(selected_experts[:, k], weights=am_repeated, minlength=num_experts) - tokens_per_expert[k] = counts.float() - tokens_per_expert /= sum_am - - # Matrix-vector mult eliminates expanding/allocating massive attention masks - router_prob_per_expert = (routing_weights.T @ am_repeated) / sum_am - - # Local scalar loss (equivalent to summing over all expert elements after multiplication) - tokens_sum = tokens_per_expert.sum(dim=0) - overall_loss = torch.dot(tokens_sum, router_prob_per_expert) * num_experts - - # Device-side all-reduce over ThunderKittens PGL NVSwitch Multicast - if dist.is_available() and dist.is_initialized(): - world = dist.get_world_size() - assert world == NUM_DEVICES, f"ThunderKittens kernel built for {NUM_DEVICES} devices, but got {world}" - - ext = _ensure_ext_jit() - - flat = overall_loss.to(torch.bfloat16).view(-1).contiguous() - n = flat.numel() - padded = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - tensor_tk = get_or_create_parallel_tensor(ext, (padded,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - tensor_tk.data_[:n] = flat - if padded > n: - tensor_tk.data_[n:padded].zero_() - - ext.tk_all_reduce(tensor_tk, barrier_tk) - - result = tensor_tk.data_[:n].clone() - overall_loss = result.to(overall_loss.dtype).view(overall_loss.shape) / world - - return overall_loss \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/25_importance_sampling_loss_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/25_importance_sampling_loss_parallelkittens.py deleted file mode 100755 index 50a3506..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/25_importance_sampling_loss_parallelkittens.py +++ /dev/null @@ -1,376 +0,0 @@ -import os -import torch -import torch.nn.functional as F -import torch.distributed as dist -from typing import Tuple, Any -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include -#include -#include - -using namespace kittens; - -// Custom atomic min/max for float mappings -__device__ void atomicMinFloat(float* address, float val) { - int* address_as_int = (int*)address; - int old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, __float_as_int(fminf(val, __int_as_float(assumed)))); - } while (assumed != old); -} - -__device__ void atomicMaxFloat(float* address, float val) { - int* address_as_int = (int*)address; - int old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, __float_as_int(fmaxf(val, __int_as_float(assumed)))); - } while (assumed != old); -} - -// Fused kernel for importance sampling pointwise ops and local stats reduction -__global__ void local_compute_kernel( - const __nv_bfloat16* __restrict__ per_token_ce, - const __nv_bfloat16* __restrict__ old_logprobs, - const __nv_bfloat16* __restrict__ advantages, - const int64_t* __restrict__ labels, - int ignore_index, - int N, - __nv_bfloat16* __restrict__ per_token_pg, - __nv_bfloat16* __restrict__ per_token_logprobs, - __nv_bfloat16* __restrict__ w, - float* __restrict__ local_stats -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - float sum_n_valid = 0; - float sum_pg = 0; - float sum_surrogate = 0; - float sum_ratio = 0; - float min_ratio = INFINITY; - float max_ratio = -INFINITY; - float sum_k3 = 0; - float sum_entropy = 0; - - for (int i = idx; i < N; i += blockDim.x * gridDim.x) { - int64_t label = labels[i]; - float ce = __bfloat162float(per_token_ce[i]); - float old_lp = __bfloat162float(old_logprobs[i]); - float adv = __bfloat162float(advantages[i]); - - float new_lp = -ce; - per_token_logprobs[i] = __float2bfloat16(new_lp); - - if (label != ignore_index) { - sum_n_valid += 1.0f; - - float delta = new_lp - old_lp; - delta = fmaxf(-20.0f, fminf(20.0f, delta)); - float ratio = expf(delta); - - float pg = -(ratio * adv); - per_token_pg[i] = __float2bfloat16(pg); - sum_pg += pg; - - float w_val = ratio * adv; - w[i] = __float2bfloat16(w_val); - sum_surrogate += w_val * ce; - - sum_ratio += ratio; - min_ratio = fminf(min_ratio, ratio); - max_ratio = fmaxf(max_ratio, ratio); - - float k3 = ratio - delta - 1.0f; - sum_k3 += k3; - - sum_entropy += ce; - } else { - per_token_pg[i] = __float2bfloat16(0.0f); - w[i] = __float2bfloat16(0.0f); - } - } - - // Warp-level reduction - unsigned int mask = 0xffffffff; - for (int offset = 16; offset > 0; offset /= 2) { - sum_n_valid += __shfl_down_sync(mask, sum_n_valid, offset); - sum_pg += __shfl_down_sync(mask, sum_pg, offset); - sum_surrogate += __shfl_down_sync(mask, sum_surrogate, offset); - sum_ratio += __shfl_down_sync(mask, sum_ratio, offset); - min_ratio = fminf(min_ratio, __shfl_down_sync(mask, min_ratio, offset)); - max_ratio = fmaxf(max_ratio, __shfl_down_sync(mask, max_ratio, offset)); - sum_k3 += __shfl_down_sync(mask, sum_k3, offset); - sum_entropy += __shfl_down_sync(mask, sum_entropy, offset); - } - - // Leader threads commit to global shared buffer safely - if (idx % 32 == 0) { - atomicAdd(&local_stats[0], sum_n_valid); - atomicAdd(&local_stats[1], sum_pg); - atomicAdd(&local_stats[2], sum_surrogate); - atomicAdd(&local_stats[3], sum_ratio); - atomicMinFloat(&local_stats[4], min_ratio); - atomicMaxFloat(&local_stats[5], max_ratio); - atomicAdd(&local_stats[6], sum_k3); - atomicAdd(&local_stats[7], sum_entropy); - } -} - -namespace grpo_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace grpo_barrier - -// Compact P2P Read kernel targeting symmetric memory via peer pointers -namespace grpo_reduce { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - float* peer_ptrs[NUM_DEVICES]; - float* global_out; - int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - float sums[8] = {0}; - sums[4] = INFINITY; - sums[5] = -INFINITY; - - #pragma unroll - for (int i = 0; i < G.NUM_DEVICES; i++) { - sums[0] += G.peer_ptrs[i][0]; // n_valid - sums[1] += G.peer_ptrs[i][1]; // pg_sum - // [2] local_surrogate_sum not cross-device reduced - sums[3] += G.peer_ptrs[i][3]; // sum_ratio - sums[4] = fminf(sums[4], G.peer_ptrs[i][4]); // min_ratio - sums[5] = fmaxf(sums[5], G.peer_ptrs[i][5]); // max_ratio - sums[6] += G.peer_ptrs[i][6]; // k3 - sums[7] += G.peer_ptrs[i][7]; // entropy - } - - // Store back strictly to local results - G.global_out[0] = sums[0]; - G.global_out[1] = sums[1]; - G.global_out[2] = G.peer_ptrs[G.dev_idx][2]; - G.global_out[3] = sums[3]; - G.global_out[4] = sums[4]; - G.global_out[5] = sums[5]; - G.global_out[6] = sums[6]; - G.global_out[7] = sums[7]; -} -} // namespace grpo_reduce - -void entrypoint( - const torch::Tensor& per_token_ce, - const torch::Tensor& old_logprobs, - const torch::Tensor& advantages, - const torch::Tensor& labels, - int ignore_index, - torch::Tensor& per_token_pg, - torch::Tensor& per_token_logprobs, - torch::Tensor& w, - kittens::py::TKParallelTensor& local_stats_tk, - torch::Tensor& global_stats, - kittens::py::TKParallelTensor& barrier -) { - int N = per_token_ce.numel(); - int num_threads = 256; - int num_blocks = std::min((N + num_threads - 1) / num_threads, 1024); - - float* local_stats_ptr = reinterpret_cast(local_stats_tk.data_.data_ptr()); - - // 1. Compute local values into local arrays natively - local_compute_kernel<<>>( - reinterpret_cast(per_token_ce.data_ptr()), - reinterpret_cast(old_logprobs.data_ptr()), - reinterpret_cast(advantages.data_ptr()), - labels.data_ptr(), - ignore_index, - N, - reinterpret_cast<__nv_bfloat16*>(per_token_pg.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(per_token_logprobs.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(w.data_ptr()), - local_stats_ptr - ); - - // 2. Safely barrier for visibility - grpo_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - kittens::py::launch_kernel(barrier_G); - - // 3. One pass 8-float reduction fetching via fast symmetric peer pointers - grpo_reduce::globals reduce_G; - reduce_G.global_out = global_stats.data_ptr(); - reduce_G.dev_idx = local_stats_tk.local_rank_; - for(int i = 0; i < grpo_reduce::globals::NUM_DEVICES; i++) { - reduce_G.peer_ptrs[i] = reinterpret_cast(local_stats_tk.ptrs_[i]); - } - kittens::py::launch_kernel(reduce_G); - - // 4. Safely barrier before exit to shield the TK array overwrites on the next loop iteration - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_grpo_step", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False -_init_vals = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_grpo_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def solution( - hidden_states: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - old_logprobs: torch.Tensor, - advantages: torch.Tensor, - ignore_index: int = -100, -) -> Tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor]: - - ext = _ensure_ext_jit() - world_size = dist.get_world_size() - - # 1. Compute massive dense arrays explicitly in highly-optimized PyTorch components - logits = F.linear(hidden_states, weight) - logits_flat = logits.view(-1, logits.size(-1)) - labels_flat = labels.to(torch.int64).contiguous().view(-1) - - # Needs to require grad to pipe correctly back into logits - per_token_ce = F.cross_entropy(logits_flat, labels_flat, ignore_index=ignore_index, reduction='none') - - # 2. Allocate variables - old_logprobs_flat = old_logprobs.contiguous().view(-1) - advantages_flat = advantages.contiguous().view(-1) - - per_token_pg = torch.empty_like(per_token_ce) - per_token_logprobs = torch.empty_like(per_token_ce) - w = torch.empty_like(per_token_ce) - - # 3. Setup ThunderKittens reduction structure and zero buffers natively async - tk_stats = get_or_create_parallel_tensor(ext, (8,), torch.float32, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - global_stats = torch.empty(8, dtype=torch.float32, device=hidden_states.device) - - global _init_vals - if _init_vals is None or _init_vals.device != hidden_states.device: - _init_vals = torch.tensor([0.0, 0.0, 0.0, 0.0, float('inf'), float('-inf'), 0.0, 0.0], - dtype=torch.float32, device=hidden_states.device) - tk_stats.data_.copy_(_init_vals) - - # 4. Fused device-side operation - ext.tk_grpo_step( - per_token_ce, - old_logprobs_flat, - advantages_flat, - labels_flat, - ignore_index, - per_token_pg, - per_token_logprobs, - w, - tk_stats, - global_stats, - barrier_tk - ) - - # 5. Connect autograd graph - n_valid_global = global_stats[0].clamp(min=1.0) - global_pg_sum = global_stats[1] - - # Recover autograd graph cleanly; w is implicitly detached - surrogate = (w * per_token_ce).sum() / n_valid_global - true_pg = global_pg_sum / n_valid_global - - # Forward pass equals `true_pg`, backward is `surrogate` gradient - loss = true_pg + surrogate - surrogate.detach() - - # Metrics - ratio_mean = global_stats[3] / n_valid_global - min_ratio = global_stats[4] - max_ratio = global_stats[5] - k3_mean = global_stats[6] / n_valid_global - entropy_mean = global_stats[7] / n_valid_global - metrics = torch.stack([ratio_mean, min_ratio, max_ratio, k3_mean, entropy_mean]) - - return loss, None, per_token_logprobs.view_as(labels), per_token_pg.view_as(labels), metrics \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/26_moe_token_preprocess_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/26_moe_token_preprocess_parallelkittens.py deleted file mode 100755 index bbb67b4..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/26_moe_token_preprocess_parallelkittens.py +++ /dev/null @@ -1,265 +0,0 @@ -""" -ThunderKittens MoE Token Preprocess (Expert Parallel). -Replaces NCCL all_gather with TMA-based peer-to-peer all-gather. -Defers blocking synchronizations to overlap compute and communication. -""" - -import os -from typing import List, Optional, Tuple - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (TMA all-gather entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace tk_all_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; // TMA driven -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(NUM_DEVICES * (input.rows() / ROW_BLOCK_SIZE) * (input.cols() / COL_BLOCK_SIZE)); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int src_dev_idx = task_idx / ((G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= ((G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx % (G.input.cols() / globals::COL_BLOCK_SIZE); - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - - // Load tile from src_dev_idx's input tensor - tma::load_async(tile, G.input[src_dev_idx], {0, 0, row_block_idx, col_block_idx}, arrived); - wait(arrived, 0); - - // Store tile into local output tensor at the appropriate depth (source rank) - tma::store_async(G.output[G.dev_idx], tile, {0, src_dev_idx, row_block_idx, col_block_idx}); -} - -} // namespace tk_all_gather - -namespace tk_all_gather_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace tk_all_gather_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(output, input); - - tk_all_gather::globals all_gather_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - tk_all_gather_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Symmetric sync -> TMA Gather -> Symmetric sync - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_gather_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allgather_preprocess_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - """Return (R, C, padded_rest) properly matching TK TMA tile semantics.""" - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -def solution( - expert_mask: torch.Tensor, - num_experts: int, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[List[int], List[int], torch.Tensor, torch.Tensor]: - """ - Compute splits and routing token totals using overlapped TMA kernels - for communication instead of NCCL blocks. - """ - group = group or dist.group.WORLD - ep_size = dist.get_world_size(group) - rank = dist.get_rank(group) - num_local_experts = num_experts // ep_size - - assert ep_size == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={ep_size}" - ) - - ext = _ensure_ext_jit() - - # 1. Local reductions (hot math) - num_local_tokens_per_expert = expert_mask.sum(dim=(1, 2)) - - # 2. Async setup for input splits. We compute on GPU and hold the tensor, - # deliberately delaying .tolist() which forces CPU blocks, so the GPU can overlap. - input_splits_tensor = num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(dim=1) - - # 3. Setup TK Parallel Tensors (Symmetric memory buffers for peer-to-peer copies) - r, c, padded_rest = _padded_row_col(num_experts) - input_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - output_tk = get_or_create_parallel_tensor(ext, (1, ep_size, r, c), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=ep_size) - - # Push to symmetric allocation in optimized bfloat16 - padded = torch.zeros(padded_rest, dtype=torch.bfloat16, device=expert_mask.device) - padded[:num_experts] = num_local_tokens_per_expert.to(torch.bfloat16) - input_tk.data_.view(-1)[:padded_rest].copy_(padded) - - # 4. Asynchronous peer-to-peer gathering using TMA natively on device - ext.tk_all_gather(output_tk, input_tk, barrier_tk) - - # 5. Extract flat counts and recover original datatype - out_flat = output_tk.data_.view(ep_size, padded_rest)[:, :num_experts].contiguous() - num_global_tokens_per_expert = out_flat.to(num_local_tokens_per_expert.dtype) - - # 6. Extract bounds for local experts handling - start_idx, end_idx = rank * num_local_experts, (rank + 1) * num_local_experts - num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, start_idx:end_idx].contiguous() - - output_splits_tensor = num_global_tokens_per_local_expert.sum(dim=1) - - # Launch non-blocking CPU transfers asynchronously - num_global_sum_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=0).to( - "cpu", non_blocking=True - ) - num_global_tokens_per_local_expert_cpu = num_global_tokens_per_local_expert.view(-1, num_local_experts).to( - "cpu", non_blocking=True - ) - - # 7. Drain and sync everything exactly at the end. At this point the GPU TK TMA kernel - # and local math likely already finished underneath the Python CPU execution time. - input_splits = input_splits_tensor.tolist() - output_splits = output_splits_tensor.tolist() - - return input_splits, output_splits, num_global_tokens_per_local_expert_cpu, num_global_sum_tokens_per_local_expert \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/27_moe_all2all_primitive_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/27_moe_all2all_primitive_parallelkittens.py deleted file mode 100755 index 7d00a67..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/27_moe_all2all_primitive_parallelkittens.py +++ /dev/null @@ -1,284 +0,0 @@ -""" -Strategy: We exploit device-side P2P PULL via TKParallelTensor to bypass NCCL and avoid any CPU overhead. -1. We allocate a persistent TKParallelTensor for the inputs and an 8-element split sizes array, pre-exchanging NVLink handles. -2. We map a device-side sizes-gather and a vectorized PULL kernel to the default stream, executing purely on the device. -3. Each block is assigned to an output row, computing its remote offset locally using the gathered sizes and pulling 16-byte (float4) vectorized chunks directly into the dynamically allocated PyTorch output tensor. -4. ThunderKittens barriers wrap the kernel to ensure safe NVLink accesses without CPU-side synchronization. -""" - -import os -from typing import List, Optional, Union - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include -#include - -using namespace kittens; - -struct globals { - static constexpr int NUM_DEVICES = 8; - bf16* input_ptrs[NUM_DEVICES]; - bf16* local_output_ptr; - int* splits_ptrs[NUM_DEVICES]; - int* local_splits; - int hidden_dim; - int dev_idx; - int total_output_rows; - barrier_t barrier; -}; - -__global__ void gather_splits_kernel(globals G) { - // 1 block, 64 threads fetches the 8x8 size matrix into a local array - if (threadIdx.x < 64) { - int i = threadIdx.x / 8; - int j = threadIdx.x % 8; - G.local_splits[threadIdx.x] = G.splits_ptrs[i][j]; - } -} - -__global__ void all_to_all_kernel(globals G) { - int row = blockIdx.x; - if (row >= G.total_output_rows) return; - - // Find which remote rank this row comes from - int src_rank = G.NUM_DEVICES - 1; - int row_offset_in_chunk = row; - for (int i = 0; i < G.NUM_DEVICES; ++i) { - int size = G.local_splits[i * G.NUM_DEVICES + G.dev_idx]; - if (row_offset_in_chunk < size) { - src_rank = i; - break; - } - row_offset_in_chunk -= size; - } - - // Compute the base row offset on the remote rank - int src_base_row = 0; - for (int j = 0; j < G.dev_idx; ++j) { - src_base_row += G.local_splits[src_rank * G.NUM_DEVICES + j]; - } - - int src_row = src_base_row + row_offset_in_chunk; - - bf16* src_row_ptr = G.input_ptrs[src_rank] + src_row * G.hidden_dim; - bf16* dst_row_ptr = G.local_output_ptr + row * G.hidden_dim; - - // Vectorized copy using float4 (16 bytes = 8 bf16s) if perfectly aligned - if (G.hidden_dim % 8 == 0) { - float4* s = reinterpret_cast(src_row_ptr); - float4* d = reinterpret_cast(dst_row_ptr); - int cols = G.hidden_dim / 8; - for (int i = threadIdx.x; i < cols; i += blockDim.x) { - d[i] = s[i]; - } - } else if (G.hidden_dim % 4 == 0) { - float2* s = reinterpret_cast(src_row_ptr); - float2* d = reinterpret_cast(dst_row_ptr); - int cols = G.hidden_dim / 4; - for (int i = threadIdx.x; i < cols; i += blockDim.x) { - d[i] = s[i]; - } - } else if (G.hidden_dim % 2 == 0) { - float* s = reinterpret_cast(src_row_ptr); - float* d = reinterpret_cast(dst_row_ptr); - int cols = G.hidden_dim / 2; - for (int i = threadIdx.x; i < cols; i += blockDim.x) { - d[i] = s[i]; - } - } else { - // Fallback for odd hidden dims - for (int i = threadIdx.x; i < G.hidden_dim; i += blockDim.x) { - dst_row_ptr[i] = src_row_ptr[i]; - } - } -} - -__global__ void barrier_kernel(globals G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -void entrypoint( - kittens::py::TKParallelTensor &input_tk, - torch::Tensor &output_tensor, - kittens::py::TKParallelTensor &splits_tk, - kittens::py::TKParallelTensor &barrier_tk, - torch::Tensor &local_splits, - int hidden_dim, - int total_output_rows -) { - globals G; - G.hidden_dim = hidden_dim; - G.dev_idx = input_tk.local_rank_; - G.total_output_rows = total_output_rows; - G.local_output_ptr = reinterpret_cast(output_tensor.data_ptr()); - G.local_splits = local_splits.data_ptr(); - - for (int i = 0; i < globals::NUM_DEVICES; ++i) { - G.input_ptrs[i] = reinterpret_cast(input_tk.ptrs_[i]); - G.splits_ptrs[i] = reinterpret_cast(splits_tk.ptrs_[i]); - } - - G.barrier = kittens::py::parallel_tensor_to_pgl>(barrier_tk); - - // Ensure all peers have completed their host->device split sizes and D2D tensor copies - barrier_kernel<<<1, 1>>>(G); - - // Pull the 8x8 splits matrix into local GPU memory for fast access by the blocks - gather_splits_kernel<<<1, 64>>>(G); - - // Launch one block per output row - int num_blocks = total_output_rows; - if (num_blocks > 0) { - int threads_per_block = 256; - all_to_all_kernel<<>>(G); - } - - // Sync to ensure everyone is done pulling from our input_tk before proceeding - barrier_kernel<<<1, 1>>>(G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_moe_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -# Fallback allocation guard: allocate up to 128M elements (~256MB) persistently for input payloads -MAX_ELEMS = 128 * 1024 * 1024 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_moe_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def solution( - local_tensor: torch.Tensor, - input_split_sizes: Optional[Union[List[int], torch.Tensor]] = None, - output_split_sizes: Optional[Union[List[int], torch.Tensor]] = None, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - if world_size == 1: - return local_tensor.contiguous() - - local_tensor = local_tensor.contiguous() - hidden_dim = local_tensor.size(1) - - if output_split_sizes is None: - out_size = local_tensor.size(0) - else: - out_size = sum(output_split_sizes) if isinstance(output_split_sizes, list) else int(output_split_sizes.sum().item()) - - output = torch.empty( - (out_size, hidden_dim), - dtype=local_tensor.dtype, - device=local_tensor.device, - ) - - # Check alignment + bounds; fallback dynamically if size or environment deviates - if (world_size != 8 or - local_tensor.dtype != torch.bfloat16 or - local_tensor.numel() > MAX_ELEMS): - dist.all_to_all_single( - output, - local_tensor, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - ) - return output - - ext = _ensure_ext_jit() - - if input_split_sizes is None: - in_splits = torch.tensor([local_tensor.size(0) // world_size] * world_size, dtype=torch.int32, device=local_tensor.device) - elif isinstance(input_split_sizes, list): - in_splits = torch.tensor(input_split_sizes, dtype=torch.int32, device=local_tensor.device) - else: - in_splits = input_split_sizes.to(torch.int32).to(local_tensor.device) - - # Request our persistent symmetric memory blocks - input_tk = get_or_create_parallel_tensor(ext, (MAX_ELEMS,), torch.bfloat16, multicast=False) - splits_tk = get_or_create_parallel_tensor(ext, (8,), torch.int32, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - local_splits = torch.empty((8, 8), dtype=torch.int32, device=local_tensor.device) - - # Queue input tensor copies async onto the default stream - numel_in = local_tensor.numel() - if numel_in > 0: - input_tk.data_[:numel_in].copy_(local_tensor.view(-1)) - - splits_tk.data_[:world_size].copy_(in_splits) - - # Launch fused P2P PULL kernel - ext.tk_moe_all_to_all( - input_tk, - output, - splits_tk, - barrier_tk, - local_splits, - hidden_dim, - out_size - ) - - return output \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/28_moe_pre_all2all_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/28_moe_pre_all2all_parallelkittens.py deleted file mode 100755 index 1ee670c..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/28_moe_pre_all2all_parallelkittens.py +++ /dev/null @@ -1,288 +0,0 @@ -import os -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for fused permute + P2P scatter -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace fused_scatter { - -using output_layout = pgl, 8, false>; - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 128; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - output_layout output; - const __nv_bfloat16* hidden_states; - const int64_t* sorted_indices; - const int64_t* sorted_expert_ids; - const int32_t* my_dest_offsets; - const int32_t* my_src_offsets; - int num_local_experts; - int hidden_dim; - int total_routed; - - __host__ inline dim3 grid() const { - return dim3(total_routed); - } -}; - -__device__ inline void kernel(const globals &G) { - int token_idx = blockIdx.x; - if (token_idx >= G.total_routed) return; - - int64_t src_token_id = G.sorted_indices[token_idx]; - int64_t expert_id = G.sorted_expert_ids[token_idx]; - - int owner_rank = expert_id / G.num_local_experts; - int relative_idx = token_idx - G.my_src_offsets[expert_id]; - int dest_token_id = G.my_dest_offsets[expert_id] + relative_idx; - - const __nv_bfloat16* src_row = G.hidden_states + src_token_id * G.hidden_dim; - __nv_bfloat16* dest_row = G.output[owner_rank].data + dest_token_id * G.hidden_dim; - - // Vectorized copy via 128-bit float4 (8 __nv_bfloat16 elements) - int vec_len = G.hidden_dim / 8; - const float4* src_vec = reinterpret_cast(src_row); - float4* dest_vec = reinterpret_cast(dest_row); - - for (int i = threadIdx.x; i < vec_len; i += blockDim.x) { - dest_vec[i] = src_vec[i]; - } - - // Remainder - int remainder_start = vec_len * 8; - for (int i = remainder_start + threadIdx.x; i < G.hidden_dim; i += blockDim.x) { - dest_row[i] = src_row[i]; - } -} - -} // namespace fused_scatter - -namespace barrier_ns { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; - }; - __device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); - } -} - -void entrypoint( - torch::Tensor hidden_states, - torch::Tensor sorted_indices, - torch::Tensor sorted_expert_ids, - torch::Tensor my_dest_offsets, - torch::Tensor my_src_offsets, - kittens::py::TKParallelTensor &output_tk, - kittens::py::TKParallelTensor &barrier_tk, - int num_local_experts -) { - int total_routed = sorted_indices.size(0); - int hidden_dim = hidden_states.size(1); - - fused_scatter::globals G { - .output = kittens::py::parallel_tensor_to_pgl(output_tk), - .hidden_states = reinterpret_cast(hidden_states.data_ptr()), - .sorted_indices = sorted_indices.data_ptr(), - .sorted_expert_ids = sorted_expert_ids.data_ptr(), - .my_dest_offsets = my_dest_offsets.data_ptr(), - .my_src_offsets = my_src_offsets.data_ptr(), - .num_local_experts = num_local_experts, - .hidden_dim = hidden_dim, - .total_routed = total_routed - }; - - barrier_ns::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier_tk), - .dev_idx = barrier_tk.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - if (total_routed > 0) { - kittens::py::launch_kernel(G); - } - kittens::py::launch_kernel(barrier_G); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_fused_scatter", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_fused_moe_scatter_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - hidden_states: torch.Tensor, - expert_mask: torch.Tensor, - num_experts: int, - input_splits: Union[List[int], torch.Tensor], - output_splits: Union[List[int], torch.Tensor], - num_global_tokens_per_local_expert: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size]: - device = hidden_states.device - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - my_rank = dist.get_rank(group) - - assert world_size == 8, "This ThunderKittens parallel layout specifies NUM_DEVICES=8" - - ext = _ensure_ext_jit() - - hidden_dim = hidden_states.size(-1) - org_hidden_states_shape = hidden_states.shape - original_dtype = hidden_states.dtype - - hidden_states_2d = hidden_states.to(torch.bfloat16).view(-1, hidden_dim).contiguous() - num_tokens = hidden_states_2d.size(0) - - # Fast token routing & mask extraction using built-ins - routing_map = expert_mask.sum(dim=1) - routing_map_bool = routing_map.bool() - - token_indices = torch.arange(num_tokens, device=device).unsqueeze(0).expand(num_experts, -1) - sorted_indices = token_indices.masked_select(routing_map_bool) - - expert_indices = torch.arange(num_experts, device=device).unsqueeze(1).expand(-1, num_tokens) - sorted_expert_ids = expert_indices.masked_select(routing_map_bool) - - expected_tokens = sum(input_splits) if isinstance(input_splits, list) else int(input_splits.sum().item()) - actual_tokens = sorted_indices.size(0) - if expected_tokens != actual_tokens: - raise RuntimeError( - f"EP split mismatch: input_splits sum ({expected_tokens}) != permuted tokens ({actual_tokens})" - ) - - # 1. Exchange the tiny expert token counts to compute perfect destination buffer offsets natively - my_tokens_per_expert = routing_map_bool.sum(dim=1).to(torch.int32) - global_tokens_per_expert_flat = torch.empty(world_size * num_experts, dtype=torch.int32, device=device) - dist.all_gather_into_tensor(global_tokens_per_expert_flat, my_tokens_per_expert, group=group) - global_tokens_per_expert = global_tokens_per_expert_flat.view(world_size, num_experts) - - num_local_experts = num_experts // world_size - - # Calculate absolute offset per expert within the receiver's target buffer - expert_totals = global_tokens_per_expert.sum(dim=0) - expert_totals_2d = expert_totals.view(world_size, num_local_experts) - expert_base_offsets_2d = torch.zeros_like(expert_totals_2d) - expert_base_offsets_2d[:, 1:] = expert_totals_2d[:, :-1].cumsum(dim=1) - expert_base_offsets = expert_base_offsets_2d.view(num_experts) - - if my_rank > 0: - my_rank_offset_within_expert = global_tokens_per_expert[:my_rank, :].sum(dim=0) - else: - my_rank_offset_within_expert = torch.zeros(num_experts, dtype=torch.int32, device=device) - - # The exact placement offset our SM will scatter into the peer's SM symmetric buffer - my_dest_offsets = expert_base_offsets + my_rank_offset_within_expert - - my_src_offsets = torch.zeros(num_experts, dtype=torch.int32, device=device) - my_src_offsets[1:] = my_tokens_per_expert[:-1].cumsum(dim=0) - - # 2. Setup symmetrical receive buffers (padding max size consistently across group) - local_out_size = int(sum(output_splits) if isinstance(output_splits, list) else output_splits.sum().item()) - local_out_size_t = torch.tensor([local_out_size], dtype=torch.int32, device=device) - max_out_size_t = local_out_size_t.clone() - dist.all_reduce(max_out_size_t, op=dist.ReduceOp.MAX, group=group) - max_out_size = max(1, max_out_size_t.item()) - - # We use (1, 1, R, C) layout mapping for a generic pgl tensor match - output_tk = get_or_create_parallel_tensor( - ext, (1, 1, max_out_size, hidden_dim), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # 3. Direct NVLink P2P DMA dispatch bypassing chunk assembly, serialization, and concatenation splits - ext.tk_fused_scatter( - hidden_states_2d, - sorted_indices, - sorted_expert_ids, - my_dest_offsets, - my_src_offsets, - output_tk, - barrier_tk, - num_local_experts - ) - - global_permuted = output_tk.data_.view(-1, hidden_dim)[:local_out_size].to(original_dtype).clone() - - return global_permuted, routing_map, sorted_indices, org_hidden_states_shape \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/29_moe_post_all2all_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/29_moe_post_all2all_parallelkittens.py deleted file mode 100755 index f1895f4..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/29_moe_post_all2all_parallelkittens.py +++ /dev/null @@ -1,325 +0,0 @@ -""" -ThunderKittens MoE Post-All2All using P2P direct pull and Fused Unpermute. - -Strategy: -1. Device-side P2P Communication: Instead of using `dist.all_to_all_single` and allocating intermediate receive buffers, ranks expose their sorted outgoing tokens in a symmetric `TKParallelTensor` send buffer. -2. Operator Fusion: We replace the memory-bound `scatter_add_` unpermute by having each rank directly pull its assigned tokens from peers over NVLink, multiply by the routing weights, and atomically add into the final output in a single fused Hopper kernel. -3. Compute-Communication Overlap: The small host-side metadata exchange (prefix sums of send chunks) is overlapped asynchronously with the local token sorting phase. -""" - -import os -from typing import List, Optional, Union - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for Fused Pull and Unpermute -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace fused_pull { - -struct globals { - static constexpr int NUM_DEVICES = 8; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout send_buffer; - int* peer_send_offsets; - int* recv_offsets; - __nv_bfloat16* tokens_weight; - int64_t* permutation_mapping; - __nv_bfloat16* unpermuted_tokens; - - int hidden_dim; - int num_received_tokens; -}; - -__global__ void kernel(globals G) { - int token_idx = blockIdx.x; - if (token_idx >= G.num_received_tokens) return; - - // Find which rank this token comes from based on the received prefix sums - int src_rank = 0; - for (int r = 1; r < globals::NUM_DEVICES; ++r) { - if (token_idx >= G.recv_offsets[r]) { - src_rank = r; - } - } - - int local_offset = token_idx - G.recv_offsets[src_rank]; - int peer_offset = G.peer_send_offsets[src_rank] + local_offset; - - // Direct P2P read from the peer's symmetric send buffer - kittens::bf16* base_ptr = G.send_buffer[src_rank].data; - __nv_bfloat16* src_ptr = reinterpret_cast<__nv_bfloat16*>(base_ptr) + peer_offset * G.hidden_dim; - - int64_t dest_idx = G.permutation_mapping[token_idx]; - __nv_bfloat16* dst_ptr = &G.unpermuted_tokens[dest_idx * G.hidden_dim]; - - float weight_f = __bfloat162float(G.tokens_weight[token_idx]); - - int tid = threadIdx.x; - int num_float4 = G.hidden_dim / 8; // float4 reads 16 bytes = 8 x bfloat16 - float4* src_ptr_4 = reinterpret_cast(src_ptr); - - for (int i = tid; i < num_float4; i += blockDim.x) { - float4 val4 = src_ptr_4[i]; - __nv_bfloat16* vals = reinterpret_cast<__nv_bfloat16*>(&val4); - - #pragma unroll - for (int j = 0; j < 8; ++j) { - float v = __bfloat162float(vals[j]); - float res = v * weight_f; - atomicAdd(&dst_ptr[8 * i + j], __float2bfloat16(res)); - } - } - - // Handle remainder if hidden_dim is not a multiple of 8 - for (int idx = num_float4 * 8 + tid; idx < G.hidden_dim; idx += blockDim.x) { - float v = __bfloat162float(src_ptr[idx]); - float res = v * weight_f; - atomicAdd(&dst_ptr[idx], __float2bfloat16(res)); - } -} - -} // namespace fused_pull - -namespace all_reduce_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace all_reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &send_buffer, - kittens::py::TKParallelTensor &barrier, - torch::Tensor peer_send_offsets, - torch::Tensor recv_offsets, - torch::Tensor tokens_weight, - torch::Tensor permutation_mapping, - torch::Tensor unpermuted_tokens, - int hidden_dim, - int num_received_tokens -) { - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // First barrier: wait for all ranks to populate their send buffers - kittens::py::launch_kernel(barrier_G); - - fused_pull::globals G; - G.send_buffer = kittens::py::parallel_tensor_to_pgl(send_buffer); - G.peer_send_offsets = peer_send_offsets.data_ptr(); - G.recv_offsets = recv_offsets.data_ptr(); - G.tokens_weight = reinterpret_cast<__nv_bfloat16*>(tokens_weight.data_ptr()); - G.permutation_mapping = permutation_mapping.data_ptr(); - G.unpermuted_tokens = reinterpret_cast<__nv_bfloat16*>(unpermuted_tokens.data_ptr()); - G.hidden_dim = hidden_dim; - G.num_received_tokens = num_received_tokens; - - if (num_received_tokens > 0) { - // One block per received token - fused_pull::kernel<<>>(G); - } - - // Second barrier: wait until everyone has read from our send buffer before we safely exit - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_fused_pull_unpermute", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_fused_moe_pull_unpermute", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _sort_chunks_by_idxs( - input: torch.Tensor, - split_sizes: Union[torch.Tensor, List[int]], - sorted_idxs: List[int], -) -> torch.Tensor: - if isinstance(split_sizes, torch.Tensor): - split_sizes = split_sizes.tolist() - chunks = torch.split(input, split_sizes, dim=0) - return torch.cat([chunks[i] for i in sorted_idxs], dim=0) - - -def _generate_weights_idx( - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, - num_experts: int, -) -> torch.Tensor: - num_tokens, topk = routing_weights.shape - weights_idx = torch.zeros( - (num_tokens, num_experts), dtype=routing_weights.dtype, device=routing_weights.device - ) - weights_idx.scatter_add_(1, selected_experts, routing_weights) - return weights_idx - - -@torch.no_grad() -def solution( - expert_outputs: torch.Tensor, - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, - num_experts: int, - input_splits: Union[List[int], torch.Tensor], - output_splits: Union[List[int], torch.Tensor], - num_global_tokens_per_local_expert: torch.Tensor, - routing_map: torch.Tensor, - local_input_permutation_mapping: torch.Tensor, - org_hidden_states_shape: torch.Size, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - assert world_size == 8, "This ThunderKittens kernel is strictly built for 8 GPUs." - - rank = dist.get_rank(group) - device = expert_outputs.device - ext = _ensure_ext_jit() - - # 1. Asynchronously gather the chunk sizes to map out peer data structures - if isinstance(output_splits, torch.Tensor): - send_splits_tensor = output_splits.to(dtype=torch.int32, device=device) - else: - send_splits_tensor = torch.tensor(output_splits, dtype=torch.int32, device=device) - - all_send_splits = torch.empty((world_size, world_size), dtype=torch.int32, device=device) - gather_list = [all_send_splits[i] for i in range(world_size)] - handle = dist.all_gather(gather_list, send_splits_tensor, group=group, async_op=True) - - # 2. Sort the locally processed expert outputs while the host `all_gather` progresses - num_local_experts = num_experts // world_size - unpermute_order = torch.arange(num_experts).reshape(num_local_experts, -1).T.ravel().tolist() - - sorted_expert_outputs = _sort_chunks_by_idxs( - expert_outputs, - num_global_tokens_per_local_expert.T.ravel(), - unpermute_order, - ) - - # 3. Secure Symmetric Buffer Allocation - hidden_dim = org_hidden_states_shape[-1] - topk = routing_weights.size(1) - # The max tokens a rank could possibly process across ALL experts globally - max_tokens = org_hidden_states_shape[0] * world_size * topk - - send_buffer_tk = get_or_create_parallel_tensor( - ext, (1, 1, max_tokens, hidden_dim), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - num_elements = sorted_expert_outputs.size(0) - if num_elements > 0: - send_buffer_tk.data_[0, 0, :num_elements, :].copy_(sorted_expert_outputs.to(torch.bfloat16)) - - handle.wait() - - # 4. Resolve Peer offsets mapping - peer_send_offsets = torch.zeros((world_size, world_size), dtype=torch.int32, device=device) - peer_send_offsets[:, 1:] = torch.cumsum(all_send_splits, dim=1)[:, :-1] - - my_recv_splits = all_send_splits[:, rank].contiguous() - recv_offsets = torch.zeros(world_size, dtype=torch.int32, device=device) - recv_offsets[1:] = torch.cumsum(my_recv_splits, dim=0)[:-1] - - num_received_tokens = my_recv_splits.sum().item() - - # 5. Extract strictly necessary weighting scalars mapping to the incoming stream - weights_idx = _generate_weights_idx(routing_weights, selected_experts, num_experts) - tokens_weight = weights_idx.T.contiguous().masked_select(routing_map.bool()) - - unpermuted_tokens = torch.zeros(org_hidden_states_shape, dtype=torch.bfloat16, device=device) - local_input_permutation_mapping = local_input_permutation_mapping.to(torch.int64).contiguous() - - # 6. Fire Unified Fused Pull-and-Unpermute over NVLink - ext.tk_fused_pull_unpermute( - send_buffer_tk, - barrier_tk, - peer_send_offsets[:, rank].contiguous(), - recv_offsets.contiguous(), - tokens_weight.to(torch.bfloat16), - local_input_permutation_mapping, - unpermuted_tokens, - hidden_dim, - num_received_tokens - ) - - return unpermuted_tokens \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/2_allgather_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/2_allgather_parallelkittens.py deleted file mode 100755 index f4867db..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/2_allgather_parallelkittens.py +++ /dev/null @@ -1,233 +0,0 @@ -""" -Strategy: -- **Device-side NVLink Multicast**: By compiling with ThunderKittens' `pgl` layout and fetching the `mc_ptr`, we utilize Hopper's native NVLink multicast to perform all-gather as an O(1) store operation per element. -- **Zero Host Overhead / Custom Scheduling**: We bypass stock NCCL and its host-side dispatch latency. Each GPU just loads its local tensor chunk and fires a 16-byte vectorized cache-bypassing store directly into the multicast fabric. -- **Hardware-accelerated Output Delivery**: The hardware switch broadcasts these writes to the symmetric physical offset in the VMM-allocated result tensor on all peers simultaneously, fully saturating the node's bisection bandwidth without an explicit software tree/ring. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_gather entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 8; // 8 bf16s = 16 bytes = uint4 - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - // Output mapped across all ranks with multicast enabled - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout output; - const bf16* input; - const int dev_idx; - - __host__ inline dim3 grid() const { - size_t chunk_size = output.numel() / NUM_DEVICES; - int blocks = (chunk_size + NUM_ELEMS_PER_BLOCK - 1) / NUM_ELEMS_PER_BLOCK; - if (blocks == 0) blocks = 1; - return dim3(blocks); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t chunk_size = G.output.numel() / globals::NUM_DEVICES; - const size_t idx = globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + globals::NUM_ELEMS_PER_INST * threadIdx.x; - - if (idx < chunk_size) { - // Vectorized 16-byte load from local input - uint4 tmp = *reinterpret_cast(&G.input[idx]); - - // Target physical offset in the assembled tensor - const size_t out_idx = chunk_size * G.dev_idx + idx; - - // Cache-bypassing (.cg) multicast store via inline PTX - asm volatile("st.global.cg.v4.u32 [%0], {%1, %2, %3, %4};" - :: "l"(reinterpret_cast(&G.output.mc_ptr[out_idx])), - "r"(tmp.x), "r"(tmp.y), "r"(tmp.z), "r"(tmp.w) : "memory"); - } -} - -} // namespace all_gather - -namespace all_gather_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_gather_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - const torch::Tensor &input, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(output, barrier); - TORCH_CHECK(input.is_contiguous(), "Input must be contiguous"); - - all_gather::globals all_gather_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = reinterpret_cast(input.data_ptr()), - .dev_idx = output.local_rank_ - }; - - all_gather_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Synchronization barrier before mapping phase (if overlapping usages exist) - kittens::py::launch_kernel(barrier_G); - - // Core payload: parallel NVLink multicasts - kittens::py::launch_kernel(all_gather_G); - - // Trailing barrier ensures completion of network-bound cache bypass stores globally - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 -NUM_ELEMS_PER_INST = 8 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST # 2048 -ALIGNMENT = NUM_ELEMS_PER_BLOCK - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allgather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(tensor: torch.Tensor) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is scaled for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - - ext = _ensure_ext_jit() - - original_shape = tensor.shape - original_dtype = tensor.dtype - - flat = tensor.to(torch.bfloat16).reshape(-1).contiguous() - n = flat.numel() - - # Align chunk size to 16-byte bounds safely handled by 2048-elem blocks - padded_n = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - if padded_n == 0: - padded_n = ALIGNMENT - - # VMM allocation cached across calls; maps the symmetric target layout for multicast writes - output_tk = get_or_create_parallel_tensor( - ext, (world, padded_n), torch.bfloat16, multicast=True - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - if padded_n > n: - padded_inp = torch.zeros(padded_n, dtype=torch.bfloat16, device=tensor.device) - padded_inp[:n] = flat - else: - padded_inp = flat - - # Dispatch device kernel. Returns when symmetric multicast is globally visible. - ext.tk_all_gather(output_tk, padded_inp, barrier_tk) - - # Slice out precisely the logical layout natively on the current device - out_flat = output_tk.data_.view(world, padded_n)[:, :n].clone() - - return out_flat.to(original_dtype).reshape((world,) + original_shape) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/30_moe_epgroupgemm_lora_backward_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/30_moe_epgroupgemm_lora_backward_parallelkittens.py deleted file mode 100755 index a75d6be..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/30_moe_epgroupgemm_lora_backward_parallelkittens.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -Strategy: -- **Flatten & Fuse:** Instead of three separate `dist.all_reduce` calls, we pack the three LoRA gradient tensors into a single contiguous buffer. This mitigates repeated kernel and collective launch overheads. -- **Device-Side Communication:** We deploy ThunderKittens' `pgl` layout with NVSwitch multicast via symmetric memory (`TKParallelTensor`). -- **In-Network Reduction:** Our custom CUDA kernel exploits Hopper's `multimem::ld_reduce` for a direct, low-latency sum across all 8 devices on the node—drastically shrinking the footprint compared to a stock NCCL hot path. -""" - -import os -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for Fused Multimem All-Reduce -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_reduce", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 # NUM_WARPGROUPS(2) * WARPGROUP_WARPS(4) * WARP_THREADS(32) -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 4096 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_lora_allreduce_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - if dist.is_initialized() and dist.get_rank() == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - grad_fc1_1_lora_A: torch.Tensor, - grad_fc1_2_lora_A: torch.Tensor, - grad_fc2_lora_B: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if not dist.is_initialized(): - return grad_fc1_1_lora_A, grad_fc1_2_lora_A, grad_fc2_lora_B - - world = dist.get_world_size(group) - - # Fallback to pure NCCL if not on a full 8-GPU node arrangement - if world != NUM_DEVICES: - group = group or dist.group.WORLD - dist.all_reduce(grad_fc1_1_lora_A, op=dist.ReduceOp.SUM, group=group) - dist.all_reduce(grad_fc1_2_lora_A, op=dist.ReduceOp.SUM, group=group) - dist.all_reduce(grad_fc2_lora_B, op=dist.ReduceOp.SUM, group=group) - return grad_fc1_1_lora_A, grad_fc1_2_lora_A, grad_fc2_lora_B - - ext = _ensure_ext_jit() - - # Create 1D views to fuse into a single contiguous TK array - f1 = grad_fc1_1_lora_A.view(-1) - f2 = grad_fc1_2_lora_A.view(-1) - f3 = grad_fc2_lora_B.view(-1) - - n1, n2, n3 = f1.numel(), f2.numel(), f3.numel() - n = n1 + n2 + n3 - - # Pad out to mults of (NUM_DEVICES * block_alignment) - padded = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - tensor_tk = get_or_create_parallel_tensor(ext, (padded,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Coalesced scatter into symmetric memory - tensor_tk.data_[:n1].copy_(f1) - tensor_tk.data_[n1 : n1 + n2].copy_(f2) - tensor_tk.data_[n1 + n2 : n].copy_(f3) - if padded > n: - tensor_tk.data_[n:].zero_() - - # Single-shot all-reduce kernel (barrier -> multimem reduce/store -> barrier) - ext.tk_all_reduce(tensor_tk, barrier_tk) - - # In-place gather - grad_fc1_1_lora_A.copy_(tensor_tk.data_[:n1].view(grad_fc1_1_lora_A.shape)) - grad_fc1_2_lora_A.copy_(tensor_tk.data_[n1 : n1 + n2].view(grad_fc1_2_lora_A.shape)) - grad_fc2_lora_B.copy_(tensor_tk.data_[n1 + n2 : n].view(grad_fc2_lora_B.shape)) - - return grad_fc1_1_lora_A, grad_fc1_2_lora_A, grad_fc2_lora_B \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/31_fused_moe_fwd_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/31_fused_moe_fwd_parallelkittens.py deleted file mode 100755 index 6b08621..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/31_fused_moe_fwd_parallelkittens.py +++ /dev/null @@ -1,553 +0,0 @@ -""" -Strategy: -1. Custom TMA-based All-to-All: Replaced NCCL's `dist.all_to_all_single` with a custom ThunderKittens - kernel `tk_jagged_all_to_all`. It uses asynchronous TMA loads/stores to move variable-sized token - chunks directly between GPUs via symmetric memory (PGL). -2. Dynamic Tile Loading: Extracted the global `max_tokens` from the existing `_preprocess` step, - allowing us to allocate stable TKParallelTensors. The host passes a `send_tiles` array to the kernel - so it only TMA transfers valid data chunks, minimizing NVLink bandwidth waste. -3. Device-Side Synchronization: Fused `barrier_all` into the entrypoint to guarantee memory visibility - across the cluster without stalling the host. -4. Fast Packing/Casting: Pack/unpack steps use fast device-to-device `.copy_()`, which seamlessly - handles FP32 <-> BF16 conversions while staging data for the asynchronous PGL TMA transfers. -""" - -import os -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for Jagged TMA All-To-All -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace jagged_all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 64; - - using shared_tile = st_bf; - // Layout: [batch, depth, rows, cols] - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - int send_tiles[NUM_DEVICES]; - - __host__ inline dim3 grid() const { - return dim3(NUM_DEVICES, input.rows() / ROW_BLOCK_SIZE); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - int dst_dev = blockIdx.x; - int tile_idx = blockIdx.y; - - if (tile_idx >= G.send_tiles[dst_dev]) return; - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - - // Load from local input buffer: [dst_dev, 0, tile_idx, 0] - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {dst_dev, 0, tile_idx, 0}, arrived); - wait(arrived, 0); - - // Store to remote output buffer on dst_dev: [0, G.dev_idx, tile_idx, 0] - tma::store_async(G.output[dst_dev], tile, {0, G.dev_idx, tile_idx, 0}); -} - -} // namespace jagged_all_to_all - -namespace jagged_all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace jagged_all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - std::vector send_tiles -) { - kittens::py::parallel_tensor_check(output, input); - - jagged_all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - for(int i = 0; i < jagged_all_to_all::globals::NUM_DEVICES; i++) { - all_to_all_G.send_tiles[i] = send_tiles[i]; - } - - jagged_all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(all_to_all_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_jagged_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", "--use_fast_math", "--expt-extended-lambda", "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", "-Xcompiler=-fno-strict-aliasing", "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_jagged_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -# ----- Custom TK AllToAll Autograd Function ----- - -class TKJaggedAllToAll(torch.autograd.Function): - @staticmethod - def forward(ctx, group, input, output_split_sizes, input_split_sizes, max_tokens, ext): - ctx.group = group - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.max_tokens = max_tokens - ctx.ext = ext - - W = dist.get_world_size(group=group) - if W == 1: - return input.contiguous() - - input = input.contiguous() - H = input.shape[-1] - - assert W == 8, "This ThunderKittens kernel is compiled for NUM_DEVICES=8" - assert H == 64, "This ThunderKittens kernel is compiled for COL_BLOCK_SIZE=64" - - ROW_BLOCK_SIZE = 16 - # Compute padding boundaries across all messages - MAX_TILES = max(1, (max_tokens + ROW_BLOCK_SIZE - 1) // ROW_BLOCK_SIZE) - PADDED_LEN = MAX_TILES * ROW_BLOCK_SIZE - - # Asymmetric shape prevents PGL aliasing and maps correctly to TMA coordinates - # input_tk: [batch=W, depth=1] -> local read [dst_dev, 0, tile_idx, 0] - # output_tk: [batch=1, depth=W] -> remote write [0, src_dev, tile_idx, 0] - input_tk = get_or_create_parallel_tensor(ext, (W, 1, PADDED_LEN, H), torch.bfloat16, multicast=False) - output_tk = get_or_create_parallel_tensor(ext, (1, W, PADDED_LEN, H), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=W) - - # Pack device-side splits dynamically (and implicitly cast FP32 -> BF16) - offset = 0 - send_tiles = [] - for j in range(W): - length = input_split_sizes[j] - if length > 0: - input_tk.data_[j, 0, :length, :].copy_(input[offset : offset + length]) - offset += length - send_tiles.append((length + ROW_BLOCK_SIZE - 1) // ROW_BLOCK_SIZE) - - # Execute ThunderKittens workload - ext.tk_jagged_all_to_all(output_tk, input_tk, barrier_tk, send_tiles) - - # Unpack remote chunks (and implicitly cast BF16 -> FP32) - output = torch.empty((sum(output_split_sizes), H), dtype=input.dtype, device=input.device) - offset = 0 - for j in range(W): - length = output_split_sizes[j] - if length > 0: - output[offset : offset + length].copy_(output_tk.data_[0, j, :length, :]) - offset += length - - return output - - @staticmethod - def backward(ctx, grad_output): - # Swaps splits symmetrically - grad_input = TKJaggedAllToAll.apply( - ctx.group, grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.max_tokens, ctx.ext - ) - return None, grad_input, None, None, None, None - - -# ----- MoE Operations ----- - -def _preprocess( - expert_mask: torch.Tensor, - num_experts: int, - ep_group: dist.ProcessGroup, -) -> Tuple[List[int], List[int], torch.Tensor, torch.Tensor, int]: - ep_size = ep_group.size() - num_local_experts = num_experts // ep_size - rank = dist.get_rank(ep_group) - - num_local_tokens_per_expert = expert_mask.sum(dim=(1, 2)) - input_splits = ( - num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(dim=1).tolist() - ) - - num_local_tokens_per_expert_flat = num_local_tokens_per_expert.contiguous().view(-1) - output_size = ep_size * num_local_tokens_per_expert_flat.numel() - num_global_tokens_per_expert_flat = torch.empty( - output_size, - dtype=num_local_tokens_per_expert.dtype, - device=num_local_tokens_per_expert.device, - ) - dist.all_gather_into_tensor( - num_global_tokens_per_expert_flat, num_local_tokens_per_expert_flat, group=ep_group - ) - - num_global_tokens_per_expert = num_global_tokens_per_expert_flat.view( - ep_size, num_local_tokens_per_expert.size(0) - ) - start_idx, end_idx = rank * num_local_experts, (rank + 1) * num_local_experts - num_global_tokens_per_local_expert = num_global_tokens_per_expert[ - :, start_idx:end_idx - ].contiguous() - output_splits = num_global_tokens_per_local_expert.sum(dim=1).tolist() - - # Calculate exactly what max_tokens needs to be via the global map - send_matrix = torch.zeros((ep_size, ep_size), dtype=torch.int32, device=expert_mask.device) - for j in range(ep_size): - send_matrix[:, j] = num_global_tokens_per_expert[:, j * num_local_experts : (j + 1) * num_local_experts].sum(dim=1) - max_tokens = send_matrix.max().item() - - num_global_sum_tokens_per_local_expert = num_global_tokens_per_local_expert.sum( - dim=0 - ).to(torch.device("cpu"), non_blocking=True) - num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view( - -1, num_local_experts - ).to(torch.device("cpu"), non_blocking=True) - - return ( - input_splits, - output_splits, - num_global_tokens_per_local_expert, - num_global_sum_tokens_per_local_expert, - max_tokens - ) - - -def _permute( - tokens: torch.Tensor, routing_map: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - num_tokens, _ = tokens.shape - num_experts = routing_map.shape[0] - routing_map = routing_map.bool() - token_indices = ( - torch.arange(num_tokens, device=routing_map.device) - .unsqueeze(0) - .expand(num_experts, -1) - ) - sorted_indices = token_indices.masked_select(routing_map) - permuted_input = tokens.index_select(0, sorted_indices) - return permuted_input, sorted_indices - - -def _sort_chunks_by_idxs( - input: torch.Tensor, - split_sizes: Union[torch.Tensor, List[int]], - sorted_idxs: List[int], -) -> torch.Tensor: - if isinstance(split_sizes, torch.Tensor): - split_sizes = split_sizes.tolist() - chunks = torch.split(input, split_sizes, dim=0) - return torch.cat([chunks[i] for i in sorted_idxs], dim=0) - - -def _generate_weights_idx( - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, - num_experts: int, -) -> torch.Tensor: - num_tokens, topk = routing_weights.shape - weights_idx = torch.zeros( - (num_tokens, num_experts), - dtype=routing_weights.dtype, - device=routing_weights.device, - ) - weights_idx.scatter_add_(1, selected_experts, routing_weights) - return weights_idx - - -def _unpermute( - tokens: torch.Tensor, - routing_weights: torch.Tensor, - hidden_states_shape: torch.Size, - permutation_mapping: torch.Tensor, - routing_map: torch.Tensor, -) -> torch.Tensor: - tokens_weight = routing_weights.T.contiguous().masked_select(routing_map.bool()) - tokens = tokens * tokens_weight.unsqueeze(-1) - hidden_dim = hidden_states_shape[-1] - unpermuted_tokens = torch.zeros( - hidden_states_shape, device=tokens.device, dtype=tokens.dtype - ) - expanded_mapping = permutation_mapping.unsqueeze(1).expand(-1, hidden_dim) - unpermuted_tokens.scatter_add_(0, expanded_mapping, tokens) - return unpermuted_tokens - - -def token_pre_all2all( - hidden_states: torch.Tensor, - expert_mask: torch.Tensor, - num_experts: int, - input_splits: List[int], - output_splits: List[int], - num_global_tokens_per_local_expert: torch.Tensor, - max_tokens: int, - ext, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size]: - group = group or dist.group.WORLD - hidden_dim = hidden_states.size(-1) - hidden_states = hidden_states.reshape(-1, hidden_dim) - org_hidden_states_shape = hidden_states.shape - routing_map = expert_mask.sum(dim=1) - - local_permuted_hidden_states, local_input_permutation_mapping = _permute( - hidden_states, routing_map - ) - expected_tokens = sum(input_splits) - actual_tokens = local_permuted_hidden_states.shape[0] - if expected_tokens != actual_tokens: - raise RuntimeError( - f"EP split mismatch: input_splits sum ({expected_tokens}) != " - f"permuted tokens ({actual_tokens})" - ) - - # Use parallelkittens optimized TMA jagged communication - global_permuted_hidden_states = TKJaggedAllToAll.apply( - group, local_permuted_hidden_states, output_splits, input_splits, max_tokens, ext - ) - - num_local_experts = num_experts // dist.get_world_size(group) - permute_order = ( - torch.arange(num_experts).reshape(-1, num_local_experts).T.ravel().tolist() - ) - split_sizes = num_global_tokens_per_local_expert.ravel().tolist() - global_permuted_hidden_states = _sort_chunks_by_idxs( - global_permuted_hidden_states, split_sizes, permute_order - ) - return ( - global_permuted_hidden_states, - routing_map, - local_input_permutation_mapping, - org_hidden_states_shape, - ) - - -def tokens_post_all2all( - expert_outputs: torch.Tensor, - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, - num_experts: int, - input_splits: List[int], - output_splits: List[int], - num_global_tokens_per_local_expert: torch.Tensor, - routing_map: torch.Tensor, - local_input_permutation_mapping: torch.Tensor, - org_hidden_states_shape: torch.Size, - max_tokens: int, - ext, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - num_local_experts = num_experts // dist.get_world_size(group) - unpermute_order = ( - torch.arange(num_experts).reshape(num_local_experts, -1).T.ravel().tolist() - ) - split_sizes = num_global_tokens_per_local_expert.T.ravel().tolist() - expert_outputs = _sort_chunks_by_idxs( - expert_outputs, split_sizes, unpermute_order - ) - - # Use parallelkittens optimized TMA jagged communication - unpermute_outputs = TKJaggedAllToAll.apply( - group, expert_outputs, input_splits, output_splits, max_tokens, ext - ) - - weights_idx = _generate_weights_idx(routing_weights, selected_experts, num_experts) - unpermute_outputs = _unpermute( - unpermute_outputs, - weights_idx, - org_hidden_states_shape, - local_input_permutation_mapping, - routing_map, - ) - return unpermute_outputs - - -def expert_forward( - x: torch.Tensor, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, -) -> torch.Tensor: - gate = torch.nn.functional.silu(gate_proj(x)) - up = up_proj(x) - return down_proj(gate * up) - - -def solution( - hidden_states: torch.Tensor, - gate_weight: torch.Tensor, - gate_bias: Optional[torch.Tensor], - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - num_experts: int, - top_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - """ - One MoE forward. Returns combined expert output. Handles `.backward()` properly. - Uses native ParallelKittens direct PGL operations for dynamic routing bottlenecks. - """ - group = group or dist.group.WORLD - hidden_dim = hidden_states.size(-1) - - ext = _ensure_ext_jit() - - # Router - router_logits = torch.nn.functional.linear( - hidden_states.reshape(-1, hidden_dim), gate_weight, gate_bias - ) - routing_weights, selected_experts = torch.topk( - torch.softmax(router_logits, dim=-1), top_k, dim=-1 - ) - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=num_experts - ).permute(2, 1, 0) - - # Preprocess - input_splits, output_splits, num_global_tokens_per_local_expert, _, max_tokens = _preprocess( - expert_mask, num_experts, group - ) - - # Token pre all2all (fused with ThunderKittens TKJaggedAllToAll) - ( - global_permuted_hidden_states, - routing_map, - local_input_permutation_mapping, - org_hidden_states_shape, - ) = token_pre_all2all( - hidden_states, - expert_mask, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - max_tokens, - ext, - group, - ) - - # Local expert (shared MLP) - expert_outputs = expert_forward( - global_permuted_hidden_states, gate_proj, up_proj, down_proj - ) - - # Tokens post all2all (fused with ThunderKittens TKJaggedAllToAll) - out = tokens_post_all2all( - expert_outputs, - routing_weights, - selected_experts, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - routing_map, - local_input_permutation_mapping, - org_hidden_states_shape, - max_tokens, - ext, - group, - ) - - return out \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/32_fused_moe_fwd_lora_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/32_fused_moe_fwd_lora_parallelkittens.py deleted file mode 100755 index d689f69..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/32_fused_moe_fwd_lora_parallelkittens.py +++ /dev/null @@ -1,676 +0,0 @@ -""" -Strategy: -1. Replaced opaque `dist.all_to_all_single` with a custom ParallelKittens TMA-based all-to-all. We dynamically pad the variable-sized token splits to a uniform multiple to fit the ThunderKittens TMA swizzle layout [W, 1, R, C]. The device-side data movement runs fully asynchronously over NVLink without dropping back to host NCCL calls for the variable lengths. -2. Fused the shared expert's LoRA adapters directly into the `gate_proj`, `up_proj`, and `down_proj` weights before running the linear layers. This preserves mathematically identical outputs but completely eliminates the separate `x @ A^T`, `+ B^T`, and `+ base_proj` operations, reducing memory bandwidth by issuing a single dense matmul `F.linear(x, W + B@A)` for all local tokens at once. -3. Kept token routing and index sorting on PyTorch natively, cleanly decoupling the metadata ops from the heavy lifting device-side communication and computation. -""" - -import os -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_to_all entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - if num_tiles == 0: - num_tiles = 1 - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - -def tk_all_to_all_variable( - group: dist.ProcessGroup, - ext, - input_tensor: torch.Tensor, - output_split_sizes: List[int], - input_split_sizes: List[int], -) -> torch.Tensor: - world = dist.get_world_size(group) - if world == 1: - return input_tensor.contiguous() - - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - - hidden_dim = input_tensor.size(-1) - - local_max = max(max(input_split_sizes), max(output_split_sizes)) if len(input_split_sizes) > 0 else 0 - local_max_t = torch.tensor([local_max], dtype=torch.int32, device=input_tensor.device) - dist.all_reduce(local_max_t, op=dist.ReduceOp.MAX, group=group) - M = local_max_t.item() - - if M == 0: - return torch.empty((0, hidden_dim), dtype=input_tensor.dtype, device=input_tensor.device) - - rest = M * hidden_dim - r, c, padded_rest = _padded_row_col(rest) - - padded_send = torch.zeros(world, padded_rest, dtype=torch.bfloat16, device=input_tensor.device) - - offset = 0 - for i, size in enumerate(input_split_sizes): - if size > 0: - flat_chunk = input_tensor[offset : offset + size].contiguous().view(-1) - padded_send[i, :flat_chunk.numel()] = flat_chunk.to(torch.bfloat16) - offset += size - - inp_4 = padded_send.view(world, 1, r, c) - - input_tk = get_or_create_parallel_tensor(ext, (world, 1, r, c), torch.bfloat16, multicast=False) - output_tk = get_or_create_parallel_tensor(ext, (1, world, r, c), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - n = inp_4.numel() - input_tk.data_.reshape(-1)[:n].copy_(inp_4.reshape(-1)) - - # Scatter axis 0 (batch), gather axis 1 (depth) - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - out_flat = output_tk.data_.reshape(-1)[:n].view(1, world, r, c)[0].reshape(world, padded_rest) - - out_chunks = [] - for i, size in enumerate(output_split_sizes): - if size > 0: - numel = size * hidden_dim - chunk = out_flat[i, :numel].contiguous().view(size, hidden_dim) - out_chunks.append(chunk) - - if len(out_chunks) > 0: - return torch.cat(out_chunks, dim=0).to(input_tensor.dtype) - return torch.empty((0, hidden_dim), dtype=input_tensor.dtype, device=input_tensor.device) - - -def _preprocess( - expert_mask: torch.Tensor, - num_experts: int, - ep_group: dist.ProcessGroup, -) -> Tuple[List[int], List[int], torch.Tensor, torch.Tensor]: - ep_size = ep_group.size() - num_local_experts = num_experts // ep_size - rank = dist.get_rank(ep_group) - num_local_tokens_per_expert = expert_mask.sum(dim=(1, 2)) - input_splits = ( - num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(dim=1).tolist() - ) - num_local_tokens_per_expert_flat = num_local_tokens_per_expert.contiguous().view(-1) - output_size = ep_size * num_local_tokens_per_expert_flat.numel() - num_global_tokens_per_expert_flat = torch.empty( - output_size, - dtype=num_local_tokens_per_expert.dtype, - device=num_local_tokens_per_expert.device, - ) - dist.all_gather_into_tensor( - num_global_tokens_per_expert_flat, num_local_tokens_per_expert_flat, group=ep_group - ) - num_global_tokens_per_expert = num_global_tokens_per_expert_flat.view( - ep_size, num_local_tokens_per_expert.size(0) - ) - start_idx, end_idx = rank * num_local_experts, (rank + 1) * num_local_experts - num_global_tokens_per_local_expert = num_global_tokens_per_expert[ - :, start_idx:end_idx - ].contiguous() - output_splits = num_global_tokens_per_local_expert.sum(dim=1).tolist() - num_global_sum_tokens_per_local_expert = num_global_tokens_per_local_expert.sum( - dim=0 - ).to(torch.device("cpu"), non_blocking=True) - num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view( - -1, num_local_experts - ).to(torch.device("cpu"), non_blocking=True) - return ( - input_splits, - output_splits, - num_global_tokens_per_local_expert, - num_global_sum_tokens_per_local_expert, - ) - - -def _permute( - tokens: torch.Tensor, routing_map: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - num_tokens, _ = tokens.shape - num_experts = routing_map.shape[0] - routing_map = routing_map.bool() - token_indices = ( - torch.arange(num_tokens, device=routing_map.device) - .unsqueeze(0) - .expand(num_experts, -1) - ) - sorted_indices = token_indices.masked_select(routing_map) - permuted_input = tokens.index_select(0, sorted_indices) - return permuted_input, sorted_indices - - -def _sort_chunks_by_idxs( - input: torch.Tensor, - split_sizes: Union[torch.Tensor, List[int]], - sorted_idxs: List[int], -) -> torch.Tensor: - if isinstance(split_sizes, torch.Tensor): - split_sizes = split_sizes.tolist() - chunks = torch.split(input, split_sizes, dim=0) - return torch.cat([chunks[i] for i in sorted_idxs], dim=0) - - -def _generate_weights_idx( - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, - num_experts: int, -) -> torch.Tensor: - num_tokens, topk = routing_weights.shape - weights_idx = torch.zeros( - (num_tokens, num_experts), - dtype=routing_weights.dtype, - device=routing_weights.device, - ) - weights_idx.scatter_add_(1, selected_experts, routing_weights) - return weights_idx - - -def _unpermute( - tokens: torch.Tensor, - routing_weights: torch.Tensor, - hidden_states_shape: torch.Size, - permutation_mapping: torch.Tensor, - routing_map: torch.Tensor, -) -> torch.Tensor: - tokens_weight = routing_weights.T.contiguous().masked_select(routing_map.bool()) - tokens = tokens * tokens_weight.unsqueeze(-1) - hidden_dim = hidden_states_shape[-1] - unpermuted_tokens = torch.zeros( - hidden_states_shape, device=tokens.device, dtype=tokens.dtype - ) - expanded_mapping = permutation_mapping.unsqueeze(1).expand(-1, hidden_dim) - unpermuted_tokens.scatter_add_(0, expanded_mapping, tokens) - return unpermuted_tokens - - -def token_pre_all2all( - hidden_states: torch.Tensor, - expert_mask: torch.Tensor, - num_experts: int, - input_splits: List[int], - output_splits: List[int], - num_global_tokens_per_local_expert: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size]: - group = group or dist.group.WORLD - hidden_dim = hidden_states.size(-1) - hidden_states = hidden_states.reshape(-1, hidden_dim) - org_hidden_states_shape = hidden_states.shape - routing_map = expert_mask.sum(dim=1) - - local_permuted_hidden_states, local_input_permutation_mapping = _permute( - hidden_states, routing_map - ) - - ext = _ensure_ext_jit() - global_permuted_hidden_states = tk_all_to_all_variable( - group, ext, local_permuted_hidden_states, output_splits, input_splits - ) - - num_local_experts = num_experts // dist.get_world_size(group) - permute_order = ( - torch.arange(num_experts).reshape(-1, num_local_experts).T.ravel().tolist() - ) - split_sizes = num_global_tokens_per_local_expert.ravel().tolist() - global_permuted_hidden_states = _sort_chunks_by_idxs( - global_permuted_hidden_states, split_sizes, permute_order - ) - return ( - global_permuted_hidden_states, - routing_map, - local_input_permutation_mapping, - org_hidden_states_shape, - ) - - -def tokens_post_all2all( - expert_outputs: torch.Tensor, - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, - num_experts: int, - input_splits: List[int], - output_splits: List[int], - num_global_tokens_per_local_expert: torch.Tensor, - routing_map: torch.Tensor, - local_input_permutation_mapping: torch.Tensor, - org_hidden_states_shape: torch.Size, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - num_local_experts = num_experts // dist.get_world_size(group) - unpermute_order = ( - torch.arange(num_experts).reshape(num_local_experts, -1).T.ravel().tolist() - ) - split_sizes = num_global_tokens_per_local_expert.T.ravel().tolist() - expert_outputs = _sort_chunks_by_idxs( - expert_outputs, split_sizes, unpermute_order - ) - - ext = _ensure_ext_jit() - # input/output splits are swapped here since we're returning the tokens. - unpermute_outputs = tk_all_to_all_variable( - group, ext, expert_outputs, input_splits, output_splits - ) - - weights_idx = _generate_weights_idx(routing_weights, selected_experts, num_experts) - unpermute_outputs = _unpermute( - unpermute_outputs, - weights_idx, - org_hidden_states_shape, - local_input_permutation_mapping, - routing_map, - ) - return unpermute_outputs - - -def expert_forward_lora( - x: torch.Tensor, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - lora_gate_A: torch.Tensor, - lora_gate_B: torch.Tensor, - lora_up_A: torch.Tensor, - lora_up_B: torch.Tensor, - lora_down_A: torch.Tensor, - lora_down_B: torch.Tensor, -) -> torch.Tensor: - """ - Highly optimized shared expert MLP with LoRA rank adapters fused perfectly into weights before - the main linear passes. Replaces a sequence of 3 disconnected GEMMs per linear structure - with 1 single contiguous dense GEMM over fused parameters per step. - """ - W_gate = gate_proj.weight + torch.matmul(lora_gate_B, lora_gate_A) - W_up = up_proj.weight + torch.matmul(lora_up_B, lora_up_A) - W_down = down_proj.weight + torch.matmul(lora_down_B, lora_down_A) - - gate_x = torch.nn.functional.linear(x, W_gate, gate_proj.bias) - up = torch.nn.functional.linear(x, W_up, up_proj.bias) - y = torch.nn.functional.silu(gate_x) * up - out = torch.nn.functional.linear(y, W_down, down_proj.bias) - return out - - -def solution( - hidden_states: torch.Tensor, - gate_weight: torch.Tensor, - gate_bias: Optional[torch.Tensor], - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - lora_gate_A: torch.Tensor, - lora_gate_B: torch.Tensor, - lora_up_A: torch.Tensor, - lora_up_B: torch.Tensor, - lora_down_A: torch.Tensor, - lora_down_B: torch.Tensor, - num_experts: int, - top_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - hidden_dim = hidden_states.size(-1) - - # Router - router_logits = torch.nn.functional.linear( - hidden_states.reshape(-1, hidden_dim), gate_weight, gate_bias - ) - routing_weights, selected_experts = torch.topk( - torch.softmax(router_logits, dim=-1), top_k, dim=-1 - ) - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=num_experts - ).permute(2, 1, 0) - - # Preprocess - input_splits, output_splits, num_global_tokens_per_local_expert, _ = _preprocess( - expert_mask, num_experts, group - ) - - # Token pre all2all - ( - global_permuted_hidden_states, - routing_map, - local_input_permutation_mapping, - org_hidden_states_shape, - ) = token_pre_all2all( - hidden_states, - expert_mask, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - group, - ) - - expert_outputs = expert_forward_lora( - global_permuted_hidden_states, - gate_proj, - up_proj, - down_proj, - lora_gate_A, - lora_gate_B, - lora_up_A, - lora_up_B, - lora_down_A, - lora_down_B, - ) - - # Tokens post all2all - out = tokens_post_all2all( - expert_outputs, - routing_weights, - selected_experts, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - routing_map, - local_input_permutation_mapping, - org_hidden_states_shape, - group, - ) - return out - - -def main() -> None: - dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo") - group = dist.group.WORLD - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu") - - num_experts = 8 - top_k = 2 - hidden_dim = 64 - intermediate_dim = 128 - batch, seq = 2, 16 - num_tokens = batch * seq - assert num_experts % world_size == 0, "num_experts must be divisible by world_size" - - # Synthetic inputs and parameters - torch.manual_seed(42 + rank) - hidden_states = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.float32) - gate_weight = torch.randn(num_experts, hidden_dim, device=device, dtype=torch.float32) - gate_bias = torch.randn(num_experts, device=device, dtype=torch.float32) - gate_proj = torch.nn.Linear(hidden_dim, intermediate_dim).to(device) - up_proj = torch.nn.Linear(hidden_dim, intermediate_dim).to(device) - down_proj = torch.nn.Linear(intermediate_dim, hidden_dim).to(device) - lora_r = 8 - lora_gate_A = torch.randn(lora_r, hidden_dim, device=device, dtype=torch.float32) - lora_gate_B = torch.randn(intermediate_dim, lora_r, device=device, dtype=torch.float32) - lora_up_A = torch.randn(lora_r, hidden_dim, device=device, dtype=torch.float32) - lora_up_B = torch.randn(intermediate_dim, lora_r, device=device, dtype=torch.float32) - lora_down_A = torch.randn(lora_r, intermediate_dim, device=device, dtype=torch.float32) - lora_down_B = torch.randn(hidden_dim, lora_r, device=device, dtype=torch.float32) - - out = solution( - hidden_states, - gate_weight, - gate_bias, - gate_proj, - up_proj, - down_proj, - lora_gate_A, - lora_gate_B, - lora_up_A, - lora_up_B, - lora_down_A, - lora_down_B, - num_experts=num_experts, - top_k=top_k, - group=group, - ) - - if rank == 0: - print("MoE + LoRA forward OK", out.shape) - dist.destroy_process_group() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/34_ulysses_all_to_all_tensor_primitive_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/34_ulysses_all_to_all_tensor_primitive_parallelkittens.py deleted file mode 100755 index b3d889b..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/34_ulysses_all_to_all_tensor_primitive_parallelkittens.py +++ /dev/null @@ -1,358 +0,0 @@ -""" -Optimized Ulysses all_to_all_tensor primitive using ThunderKittens TMA. - -Strategy: -1. Device-Side Data Movement: Replaces host-driven `torch.tensor_split` + `dist.all_to_all` + `torch.cat` - with a direct, fused TMA scatter-gather kernel mapping directly to the underlying NVSwitch/NVLink fabric. -2. Overlap Communication and Memory Permutations: By permuting and casting into a flat buffer view, we - pack the scattered chunks safely. The TK kernel then asynchronously routes them (`tma::store_async`) - to their destinations simultaneously across all 8 GPUs. -3. Eliminate Redundant Collectives: The customized parallel tensor allocation and unified compilation - avoids unnecessary Python loop overheads, completing the all-to-all purely through device-side - cooperative transfers and barriers. -""" - -import os -from typing import Optional - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_to_all entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - """Return (R, C, padded_rest) with R=16, C multiple of 128, R*C >= rest_elems.""" - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ulysses_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - x: torch.Tensor, - scatter_dim: int, - gather_dim: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - w = dist.get_world_size(group) - if w == 1: - return x.contiguous() - - assert w == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={w}" - ) - - x = x.contiguous() - original_shape = x.shape - original_dtype = x.dtype - - scatter_dim = scatter_dim % x.dim() - gather_dim = gather_dim % x.dim() - - if x.numel() == 0: - out_shape = list(original_shape) - out_shape[scatter_dim] //= w - out_shape[gather_dim] *= w - return torch.empty(out_shape, dtype=original_dtype, device=x.device) - - # 1. Prepare scattered dimensions and chunks - shape = list(original_shape) - assert shape[scatter_dim] % w == 0, f"scatter_dim size {shape[scatter_dim]} not divisible by world_size {w}" - shape[scatter_dim] //= w - - new_shape = shape.copy() - new_shape.insert(scatter_dim, w) - - x_view = x.view(new_shape) - - perm = list(range(len(new_shape))) - perm.remove(scatter_dim) - perm.insert(0, scatter_dim) - - # Chunk preparation directly maps identically to `dist.all_to_all_single` - x_contig = x_view.permute(perm).contiguous() - chunk_shape = x_contig.shape[1:] - rest = x_contig[0].numel() - - r, c, padded_rest = _padded_row_col(rest) - - ext = _ensure_ext_jit() - - input_tk = get_or_create_parallel_tensor( - ext, (w, 1, r, c), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, w, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=w) - - # Explicit single memory transfer formatting to padded chunks mapped via IPC / NVSwitch - input_tk.data_.reshape(w, padded_rest)[:, :rest].copy_(x_contig.reshape(w, rest)) - - # Launch ThunderKittens TMA All-To-All mapping `batch_idx -> rank`, `depth_idx -> source` - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - # Gather chunks directly from target buffer slicing away padding safely - out_flat = output_tk.data_.reshape(w, padded_rest)[:, :rest] - - # Form fully contiguous output allocating to original memory precision concurrently - out_contig = torch.empty((w, *chunk_shape), dtype=original_dtype, device=x.device) - out_contig.reshape(w, rest).copy_(out_flat) - - # 2. Post-process received chunks via gathering axes semantics - out_view = out_contig.view(w, *shape) - - perm2 = list(range(1, len(shape) + 1)) - perm2.insert(gather_dim, 0) - - out_perm = out_view.permute(perm2).contiguous() - - final_shape = shape.copy() - final_shape[gather_dim] *= w - - return out_perm.view(final_shape) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/35_ulysses_all_gather_into_tensor_primitive_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/35_ulysses_all_gather_into_tensor_primitive_parallelkittens.py deleted file mode 100755 index 2b5217a..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/35_ulysses_all_gather_into_tensor_primitive_parallelkittens.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -Strategy: -To eliminate the overhead of stock `torch.distributed.all_gather_into_tensor`, we use a custom ThunderKittens multicast kernel. -- **Device-Side Multicast**: Instead of moving data through NCCL's ring/tree abstractions, the kernel directly maps a shared parallel tensor across all Hopper GPUs in the node. Each rank reads its local input chunk and uses `st.global.cs` to write directly to the multicast address space. -- **Hardware-Level Broadcast**: The NVLink fabric broadcasts the store to all peer GPUs simultaneously, completing the all-gather in a single memory phase. -- **Zero Host Overhead**: We cache the VMM parallel tensor and barrier allocations, avoiding repetitive IPC exchanges on the hot path, and we vectorize memory operations (using `int4` for 16-byte chunks) to saturate the memory bus. -""" - -import os -from typing import Optional - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_gather entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace all_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - bf16* output_mc_ptr; - const bf16* input_ptr; - size_t chunk_elems; - int dev_idx; - - __host__ inline dim3 grid() const { - size_t int4_count = chunk_elems / 8; // 8 bf16s per int4 - int blocks = (int4_count + config::NUM_THREADS - 1) / config::NUM_THREADS; - // Cap blocks to saturate GPU but allow grid-stride loop - if (blocks > 1024) blocks = 1024; - return dim3(blocks > 0 ? blocks : 1); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t int4_count = G.chunk_elems / 8; - - // Grid-stride loop for arbitrary chunk sizes - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - idx < int4_count; - idx += gridDim.x * blockDim.x) { - - // Read 16 bytes (8 bf16 elements) from local input - int4 val = reinterpret_cast(G.input_ptr)[idx]; - - // Write to multicast pointer via streaming store (bypasses L1, hits L2/fabric) - int4* dst = reinterpret_cast(G.output_mc_ptr + G.dev_idx * G.chunk_elems) + idx; - asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" - : : "l"(dst), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) : "memory"); - } -} - -} // namespace all_gather - -namespace all_gather_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_gather_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - torch::Tensor input, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(output, barrier); - TORCH_CHECK(input.is_contiguous(), "Input must be contiguous"); - TORCH_CHECK(input.dtype() == torch::kBFloat16, "Input must be BF16"); - - size_t chunk_elems = input.numel(); - TORCH_CHECK(chunk_elems % 8 == 0, "Chunk size must be multiple of 8 for int4 vectorization"); - - using parallel_layout = pgl, all_gather::globals::NUM_DEVICES, true>; - auto output_pgl = kittens::py::parallel_tensor_to_pgl(output); - - all_gather::globals ag_G { - .output_mc_ptr = output_pgl.mc_ptr, - .input_ptr = reinterpret_cast(input.data_ptr()), - .chunk_elems = chunk_elems, - .dev_idx = output.local_rank_ - }; - - all_gather_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // 1. Sync before writing to avoid stomping on previous iterations - kittens::py::launch_kernel(barrier_G); - - // 2. Multicast broadcast - if (chunk_elems > 0) { - kittens::py::launch_kernel(ag_G); - } - - // 3. Sync after writing to ensure all L2s are updated before Python consumes - kittens::py::launch_kernel(barrier_G); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_allgather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call dist.barrier() in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - x: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - - if world_size == 1: - return x.contiguous() - - assert world_size == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world_size}" - ) - - if x.numel() == 0: - dim_size = list(x.size()) - dim_size[0] = dim_size[0] * world_size - return torch.empty(dim_size, dtype=x.dtype, device=x.device) - - x = x.contiguous() - original_dtype = x.dtype - if original_dtype != torch.bfloat16: - x = x.to(torch.bfloat16) - - ext = _ensure_ext_jit() - - dim_size = list(x.size()) - dim_size[0] = dim_size[0] * world_size - - # Flatten the per-rank chunk to process uniformly - flat = x.view(-1) - n = flat.numel() - - # Kernel relies on 16-byte (8 bf16s) vectorization - pad_n = ((n + 7) // 8) * 8 - - if pad_n > n: - padded_x = torch.zeros(pad_n, dtype=x.dtype, device=x.device) - padded_x[:n] = flat - else: - padded_x = flat - - # Request/Re-use VMM mapping spanning all devices in group - output_tk = get_or_create_parallel_tensor( - ext, (world_size * pad_n,), torch.bfloat16, multicast=True - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # Perform device-side multicast exchange - ext.tk_all_gather(output_tk, padded_x, barrier_tk) - - if pad_n > n: - # Re-pack padded shards before reshaping to strict primitive spec - out = torch.empty((world_size, pad_n), dtype=x.dtype, device=x.device) - out.copy_(output_tk.data_[: world_size * pad_n].view(world_size, pad_n)) - out_compact = out[:, :n].contiguous().view(dim_size) - - if original_dtype != torch.bfloat16: - return out_compact.to(original_dtype) - return out_compact - else: - # Directly view contiguous array block - out = output_tk.data_[: world_size * n].view(dim_size) - - if original_dtype != torch.bfloat16: - return out.to(original_dtype) - # We must clone because output_tk.data_ is a static pool mapping - # and may be overwritten on the next call to `solution`. - return out.clone() \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/36_ulysses_all_gather_variable_primitive_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/36_ulysses_all_gather_variable_primitive_parallelkittens.py deleted file mode 100755 index f94966b..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/36_ulysses_all_gather_variable_primitive_parallelkittens.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -Strategy: -- **Device-Side Multicast**: Instead of PyTorch's two-phase all-gather (which bounces metadata through host lists and executes repeated peer copies followed by `torch.cat`), we map a single symmetric TK parallel tensor on device via NVSwitch multicast. -- **Compute-Communication Fused Mapping**: Each rank computes the direct destination offset for its chunk and pushes local data to the global multicast pointer (`mc_ptr`). This fuses the network transmission and the `torch.cat` concatenation into a single broadcast step. -- **Vectorized Stores**: Data movement is heavily vectorized using `uint4` (128-bit) stores when the inner dimensions are aligned, maximizing fabric bandwidth and bypassing L1 cache to saturate the NVLink multicast window. -- **Minimal Host Overhead**: We only sync once on the host to negotiate exact shapes, then use device-side barriers (`barrier_all`) to ensure completion before and after the fabric writes. -""" - -import os -import math -import torch -import torch.distributed as dist -from typing import Optional - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for Ulysses variable all-gather via Multicast -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace ulysses_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout output; - const bf16* local_input; - - int outer_size; - int rank_dim; - int inner_size; - int total_dim; - int rank_offset; - - __host__ inline dim3 grid() const { - size_t numel = (size_t)outer_size * rank_dim * (inner_size % 8 == 0 ? inner_size / 8 : inner_size); - size_t g = (numel + config::NUM_THREADS - 1) / config::NUM_THREADS; - return dim3(g > 0 ? std::min((size_t)65536, g) : 1); - } -}; - -__device__ inline void kernel_vec8(const globals &G) { - size_t vec_inner = G.inner_size / 8; - size_t chunk_size = (size_t)G.rank_dim * vec_inner; - size_t total_elems = (size_t)G.outer_size * chunk_size; - - using vec_t = uint4; // 16 bytes = 8 bf16 - const vec_t* in_vec = reinterpret_cast(G.local_input); - vec_t* out_vec = reinterpret_cast(G.output.mc_ptr); - - // Grid-stride loop mapped directly over 128-bit boundary offsets - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_elems; tid += gridDim.x * blockDim.x) { - size_t o = tid / chunk_size; - size_t idx_in_chunk = tid % chunk_size; - - size_t in_idx = o * chunk_size + idx_in_chunk; - size_t out_idx = o * ((size_t)G.total_dim * vec_inner) + ((size_t)G.rank_offset * vec_inner) + idx_in_chunk; - - out_vec[out_idx] = in_vec[in_idx]; - } -} - -__device__ inline void kernel_scalar(const globals &G) { - size_t chunk_size = (size_t)G.rank_dim * G.inner_size; - size_t total_elems = (size_t)G.outer_size * chunk_size; - - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_elems; tid += gridDim.x * blockDim.x) { - size_t o = tid / chunk_size; - size_t idx_in_chunk = tid % chunk_size; - - size_t in_idx = o * chunk_size + idx_in_chunk; - size_t out_idx = o * ((size_t)G.total_dim * G.inner_size) + ((size_t)G.rank_offset * G.inner_size) + idx_in_chunk; - - G.output.mc_ptr[out_idx] = G.local_input[in_idx]; - } -} - -} // namespace ulysses_gather - -namespace gather_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace gather_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - long long local_input_ptr, - kittens::py::TKParallelTensor &barrier, - int outer_size, - int rank_dim, - int inner_size, - int total_dim, - int rank_offset -) { - kittens::py::parallel_tensor_check(output, barrier); - - ulysses_gather::globals G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .local_input = reinterpret_cast(local_input_ptr), - .outer_size = outer_size, - .rank_dim = rank_dim, - .inner_size = inner_size, - .total_dim = total_dim, - .rank_offset = rank_offset - }; - - gather_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Synchronize to ensure safety of symmetric buffers - kittens::py::launch_kernel(barrier_G); - - if (rank_dim > 0) { - if (inner_size % 8 == 0) { - kittens::py::launch_kernel(G); - } else { - kittens::py::launch_kernel(G); - } - } - - // Synchronize to ensure fabric writes have completed before the PyTorch stream continues - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_ulysses_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ulysses_gather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def solution( - x: torch.Tensor, - gather_dim: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - if world_size == 1: - return x.contiguous() - - assert world_size == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world_size}" - ) - - device = x.device - original_dtype = x.dtype - - # Negative dim support and contiguous memory layout - if gather_dim < 0: - gather_dim += x.ndim - x = x.contiguous() - - # 1. Gather sizes to negotiate the symmetric tensor bounds - # (Required host sync to compute proper shapes and offsets) - x_size = torch.tensor([x.shape[gather_dim]], dtype=torch.int64, device=device) - size_list = [torch.zeros(1, dtype=torch.int64, device=device) for _ in range(world_size)] - dist.all_gather(size_list, x_size, group=group) - sizes = [s.item() for s in size_list] - - # Calculate flat dimensions equivalent to the block copy - outer_size = math.prod(x.shape[:gather_dim]) if gather_dim > 0 else 1 - inner_size = math.prod(x.shape[gather_dim+1:]) if gather_dim < x.ndim - 1 else 1 - - total_dim = sum(sizes) - rank_dim = sizes[rank] - rank_offset = sum(sizes[:rank]) - - # Pre-calculate intended output shape - out_shape = list(x.shape) - out_shape[gather_dim] = total_dim - - ext = _ensure_ext_jit() - - # Cast strictly to kernel execution type - x_bf16 = x.to(torch.bfloat16) - - # 2. ThunderKittens allocation / mapping - total_elems = outer_size * total_dim * inner_size - padded_elems = ((total_elems + 1023) // 1024) * 1024 - if padded_elems == 0: - padded_elems = 1024 - - output_tk = get_or_create_parallel_tensor( - ext, - (padded_elems,), - torch.bfloat16, - multicast=True - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # 3. Kernel launch mapping direct local chunk offset into the multicast peer layout - ext.tk_ulysses_gather( - output_tk, - x_bf16.data_ptr(), - barrier_tk, - outer_size, - rank_dim, - inner_size, - total_dim, - rank_offset - ) - - # Slice output precisely based on exact elements written dynamically - result = output_tk.data_.view(-1)[:total_elems].view(*out_shape).clone() - - return result.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/37_ulysses_gather_seq_scatter_heads_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/37_ulysses_gather_seq_scatter_heads_parallelkittens.py deleted file mode 100755 index c9cf977..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/37_ulysses_gather_seq_scatter_heads_parallelkittens.py +++ /dev/null @@ -1,363 +0,0 @@ -""" -ThunderKittens Ulysess All-to-All using TMA between devices. - -Replaces the NCCL / stock PyTorch `dist.all_to_all_single` with a dedicated -device-resident ThunderKittens all-to-all. By extracting the exact chunking -dimensions for seq and heads and preparing them via fast device `.movedim()`, -we flatten the communication into a single asynchronous TMA personalized -routing step across symmetric memory on the NVLink peers. - -Strategy: -1. Reshape and `movedim` to hoist the chunked `scatter_dim` into the batch dimension (dim 0). -2. Use TMA-based `tk_all_to_all` to stream chunks asynchronously to peers without host NCCL queues. -3. Inverse `movedim` to collapse the received chunks seamlessly into `gather_dim`. -4. Keeps orchestration strictly on the device for lower latency and better scaling. -""" - -import os -from typing import Optional - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_to_all entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ulysses_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -def prep_for_all_to_all(x: torch.Tensor, scatter_dim: int, w: int) -> torch.Tensor: - scatter_dim = scatter_dim if scatter_dim >= 0 else x.dim() + scatter_dim - shape = list(x.shape) - assert shape[scatter_dim] % w == 0, f"scatter_dim {scatter_dim} not divisible by {w}" - shape.insert(scatter_dim + 1, shape[scatter_dim] // w) - shape[scatter_dim] = w - x_reshaped = x.reshape(shape) - x_moved = x_reshaped.movedim(scatter_dim, 0) - return x_moved.contiguous() - - -def post_for_all_to_all(out_moved: torch.Tensor, gather_dim: int, orig_dim: int) -> torch.Tensor: - gather_dim = gather_dim if gather_dim >= 0 else orig_dim + gather_dim - out_shifted = out_moved.movedim(0, gather_dim) - final_shape = list(out_shifted.shape) - final_shape[gather_dim] = final_shape[gather_dim] * final_shape[gather_dim + 1] - final_shape.pop(gather_dim + 1) - return out_shifted.reshape(final_shape) - - -@torch.no_grad() -def solution( - x: torch.Tensor, - seq_dim: int, - head_dim: int, - group: Optional[ProcessGroup] = None, - unpadded_dim_size: int = 0, -) -> torch.Tensor: - if group is None: - return x - - assert x.is_cuda and x.is_contiguous() - - world = dist.get_world_size(group) - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - - ext = _ensure_ext_jit() - - original_dtype = x.dtype - orig_dim = x.dim() - - # Scatter and chunk prep via device pointer reshaping - x_moved = prep_for_all_to_all(x, scatter_dim=head_dim, w=world) - chunk_shape = x_moved.shape[1:] - rest = x_moved[0].numel() - r, c, padded_rest = _padded_row_col(rest) - - # Convert to BF16 for Kernel alignment - x_bf16 = x_moved.to(torch.bfloat16) - padded = torch.zeros(world, padded_rest, dtype=torch.bfloat16, device=x.device) - padded[:, :rest] = x_bf16.reshape(world, rest) - inp_4 = padded.view(world, 1, r, c) - - # Use TK ParallelTensors for device-side peer transfers - input_tk = get_or_create_parallel_tensor( - ext, (world, 1, r, c), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, world, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Assign and run fast TMA all-to-all - n = inp_4.numel() - input_tk.data_.reshape(-1)[:n].copy_(inp_4.reshape(-1)) - - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - # Recover gathered output and restore shapes correctly - out_flat = ( - output_tk.data_.reshape(-1)[:n] - .view(1, world, r, c)[0] - .reshape(world, padded_rest)[:, :rest] - .contiguous() - ) - out_moved = out_flat.reshape(world, *chunk_shape).to(original_dtype) - x = post_for_all_to_all(out_moved, gather_dim=seq_dim, orig_dim=orig_dim) - - # Apply unpad logic corresponding precisely to the original Ulysses seq-parallel code - if unpadded_dim_size and unpadded_dim_size % world != 0: - padding_size = x.size(seq_dim) - unpadded_dim_size - slc = [slice(None)] * x.dim() - slc[seq_dim] = slice(0, -padding_size) - x = x[tuple(slc)] - - return x \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/38_ulysses_gather_heads_scatter_seq_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/38_ulysses_gather_heads_scatter_seq_parallelkittens.py deleted file mode 100755 index 6be2614..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/38_ulysses_gather_heads_scatter_seq_parallelkittens.py +++ /dev/null @@ -1,282 +0,0 @@ -import os -from typing import Optional - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded C++ / CUDA source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -struct Pointers { - __nv_bfloat16* ptrs[8]; -}; - -template -__global__ void p2p_alltoall_kernel( - Pointers out_ptrs, - const __nv_bfloat16* in_ptr, - int D0, int D1, int D2, int D3, int D4, - int W, int src_rank, bool S_less_than_G, - size_t total_vecs -) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_vecs) return; - - // Convert 1D idx to 5D indices - int i4 = idx % D4; - size_t tmp = idx / D4; - int i3 = tmp % D3; - tmp /= D3; - int i2 = tmp % D2; - tmp /= D2; - int i1 = tmp % D1; - int i0 = tmp / D1; - - int dst_rank; - size_t out_idx; - - if (S_less_than_G) { - // D1 is Scatter, D3 is Gather - int chunk_s = D1 / W; - dst_rank = i1 / chunk_s; - int i1_out = i1 % chunk_s; - int i3_out = src_rank * D3 + i3; - - // out_shape = [D0, chunk_s, D2, D3 * W, D4] - out_idx = ((((size_t)i0 * chunk_s + i1_out) * D2 + i2) * (D3 * W) + i3_out) * D4 + i4; - } else { - // D1 is Gather, D3 is Scatter - int chunk_s = D3 / W; - dst_rank = i3 / chunk_s; - int i3_out = i3 % chunk_s; - int i1_out = src_rank * D1 + i1; - - // out_shape = [D0, D1 * W, D2, chunk_s, D4] - out_idx = ((((size_t)i0 * (D1 * W) + i1_out) * D2 + i2) * chunk_s + i3_out) * D4 + i4; - } - - // Direct P2P NVLink write to destination symmetric buffer - reinterpret_cast(out_ptrs.ptrs[dst_rank])[out_idx] = reinterpret_cast(in_ptr)[idx]; -} - -namespace all_to_all_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - torch::Tensor input, - kittens::py::TKParallelTensor &barrier, - int D0, int D1, int D2, int D3, int D4, - bool S_less_than_G -) { - TORCH_CHECK(input.is_contiguous(), "Input must be contiguous"); - TORCH_CHECK(input.scalar_type() == torch::kBFloat16, "Input must be bfloat16"); - - int W = all_to_all_barrier::globals::NUM_DEVICES; - int src_rank = barrier.local_rank_; - - Pointers out_ptrs; - for (int i = 0; i < W; ++i) { - out_ptrs.ptrs[i] = reinterpret_cast<__nv_bfloat16*>(output.data_ptrs_[i]); - } - const __nv_bfloat16* in_ptr = reinterpret_cast(input.data_ptr()); - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Sync before P2P writes - kittens::py::launch_kernel(barrier_G); - - size_t total_elems = (size_t)D0 * D1 * D2 * D3 * D4; - int threads = 256; - - if (total_elems > 0) { - // Vectorize along the innermost contiguous dimension - if (D4 % 8 == 0) { - size_t total_vecs = total_elems / 8; - int blocks = (total_vecs + threads - 1) / threads; - p2p_alltoall_kernel<<>>( - out_ptrs, in_ptr, D0, D1, D2, D3, D4 / 8, W, src_rank, S_less_than_G, total_vecs); - } else if (D4 % 4 == 0) { - size_t total_vecs = total_elems / 4; - int blocks = (total_vecs + threads - 1) / threads; - p2p_alltoall_kernel<<>>( - out_ptrs, in_ptr, D0, D1, D2, D3, D4 / 4, W, src_rank, S_less_than_G, total_vecs); - } else if (D4 % 2 == 0) { - size_t total_vecs = total_elems / 2; - int blocks = (total_vecs + threads - 1) / threads; - p2p_alltoall_kernel<<>>( - out_ptrs, in_ptr, D0, D1, D2, D3, D4 / 2, W, src_rank, S_less_than_G, total_vecs); - } else { - size_t total_vecs = total_elems; - int blocks = (total_vecs + threads - 1) / threads; - p2p_alltoall_kernel<__nv_bfloat16><<>>( - out_ptrs, in_ptr, D0, D1, D2, D3, D4, W, src_rank, S_less_than_G, total_vecs); - } - } - - // Sync after P2P writes to ensure visibility on destination ranks - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ALIGNMENT = 1024 * 1024 # Cache TKParallelTensor via aligned shapes - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ulysses_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _pad_tensor(x: torch.Tensor, dim: int, padding_size: int, padding_value: int = 0) -> torch.Tensor: - shape = list(x.shape) - shape[dim] = padding_size - pad = torch.full(shape, padding_value, dtype=x.dtype, device=x.device) - return torch.cat([x, pad], dim=dim) - - -@torch.no_grad() -def solution( - x: torch.Tensor, - seq_dim: int, - head_dim: int, - group: Optional[ProcessGroup] = None, -) -> torch.Tensor: - if group is None: - return x - - assert x.is_cuda, "Input must be on CUDA device" - - sp_world = dist.get_world_size(group) - assert sp_world == NUM_DEVICES, f"This ThunderKittens kernel assumes NUM_DEVICES={NUM_DEVICES}" - - dim_size = x.size(seq_dim) - if dim_size % sp_world != 0: - padding_size = sp_world - (dim_size % sp_world) - x = _pad_tensor(x, seq_dim, padding_size) - - ext = _ensure_ext_jit() - - original_dtype = x.dtype - x_bf16 = x.to(torch.bfloat16).contiguous() - shape = list(x_bf16.shape) - n = x_bf16.numel() - - # Pre-calculate 5D logical bounds based on scatter/gather layout splits - s_dim = seq_dim - g_dim = head_dim - S_less_than_G = (s_dim < g_dim) - - D0, D1, D2, D3, D4 = 1, 1, 1, 1, 1 - min_dim = min(s_dim, g_dim) - max_dim = max(s_dim, g_dim) - - for s in shape[:min_dim]: D0 *= s - D1 = shape[min_dim] - for s in shape[min_dim+1:max_dim]: D2 *= s - D3 = shape[max_dim] - for s in shape[max_dim+1:]: D4 *= s - - out_shape = list(shape) - out_shape[s_dim] = shape[s_dim] // sp_world - out_shape[g_dim] = shape[g_dim] * sp_world - - # Cache aligned parallel tensor mapping to symmetric memory for IPC handling - padded_n = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - output_tk = get_or_create_parallel_tensor( - ext, (padded_n,), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=sp_world) - - # Perform P2P NVLink scatter+gather transpose via custom device barrier synchronization - ext.tk_all_to_all( - output_tk, - x_bf16, - barrier_tk, - D0, D1, D2, D3, D4, - S_less_than_G - ) - - out = output_tk.data_[:n].clone() - return out.view(out_shape).to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/39_ulysses_gather_seq_scatter_heads_qkv_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/39_ulysses_gather_seq_scatter_heads_qkv_parallelkittens.py deleted file mode 100755 index 69222ef..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/39_ulysses_gather_seq_scatter_heads_qkv_parallelkittens.py +++ /dev/null @@ -1,376 +0,0 @@ -import os -import math -from typing import Any, Optional, Tuple - -import torch -import torch.distributed as dist -from torch import Tensor -from torch.distributed import ProcessGroup - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - - -# --------------------------------------------------------------------------- -# Embedded .cu source: Fused Reshape & All-To-All over NVLink PGL -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace fused_all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - // We treat the buffer as a flat array via gl to handle arbitrary striding - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - - int B; - int S; - int R; - int C; - int dev_idx; - int total_vec; - - __host__ inline dim3 grid() const { - return dim3((total_vec + config::NUM_THREADS - 1) / config::NUM_THREADS); - } -}; - -template struct Vec; -template <> struct Vec<8> { using type = float4; }; -template <> struct Vec<4> { using type = float2; }; -template <> struct Vec<2> { using type = float; }; -template <> struct Vec<1> { using type = uint16_t; }; - -template -__device__ inline void kernel(const globals &G) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= G.total_vec || G.C == 0) return; - - int C_vec = G.C / VEC_SIZE; - - int c_v = idx % C_vec; - int tmp = idx / C_vec; - int r = tmp % G.R; - tmp /= G.R; - int s = tmp % G.S; - int b = tmp / G.S; - - int c = c_v * VEC_SIZE; - - // Chunk boundary logic - int w_dst = c / (G.C / globals::NUM_DEVICES); - int c_out = c % (G.C / globals::NUM_DEVICES); - - // Sequence gather concatenation logic - int s_out = G.dev_idx * G.S + s; - - // Flatten 4D indices to 1D offsets - int src_offset = ((b * G.S + s) * G.R + r) * G.C + c; - int dst_offset = ((b * (G.S * globals::NUM_DEVICES) + s_out) * G.R + r) * (G.C / globals::NUM_DEVICES) + c_out; - - using V = typename Vec::type; - V val = *reinterpret_cast(&G.input[G.dev_idx].data[src_offset]); - *reinterpret_cast(&G.output[w_dst].data[dst_offset]) = val; -} - -} // namespace fused_all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int B, int S, int R, int C -) { - kittens::py::parallel_tensor_check(output, input, barrier); - - int C_per_W = C / fused_all_to_all::globals::NUM_DEVICES; - - // Opt for highest aligned vector size to saturate NVLink - int vec_size = 1; - if (C_per_W > 0 && C_per_W % 8 == 0) vec_size = 8; - else if (C_per_W > 0 && C_per_W % 4 == 0) vec_size = 4; - else if (C_per_W > 0 && C_per_W % 2 == 0) vec_size = 2; - - fused_all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .B = B, - .S = S, - .R = R, - .C = C, - .dev_idx = input.local_rank_, - .total_vec = (B * S * R * C) / vec_size - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (vec_size == 8) { - kittens::py::launch_kernel>(all_to_all_G); - } else if (vec_size == 4) { - kittens::py::launch_kernel>(all_to_all_G); - } else if (vec_size == 2) { - kittens::py::launch_kernel>(all_to_all_G); - } else { - kittens::py::launch_kernel>(all_to_all_G); - } - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_fused_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_fused_all_to_all_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -# --------------------------------------------------------------------------- -# Fallback Reference Implementations (Exact compatibility for non-H100/world!=8) -# --------------------------------------------------------------------------- -def _pad_tensor(x: Tensor, dim: int, padding_size: int, padding_value: int = 0) -> Tensor: - shape = list(x.shape) - shape[dim] = padding_size - pad = torch.full(shape, padding_value, dtype=x.dtype, device=x.device) - return torch.cat([x, pad], dim=dim) - -def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: - slc = [slice(None)] * len(x.shape) - slc[dim] = slice(0, -padding_size) - return x[tuple(slc)] - -def _all_to_all_single(x: Tensor, scatter_dim: int, gather_dim: int, group: Optional[dist.ProcessGroup] = None, async_op: bool = False): - group = group or dist.group.WORLD - sp_world_size = dist.get_world_size(group) - if scatter_dim != 0: - gather_dim_bef = x.shape[gather_dim] - scatter_dim_bef = x.shape[scatter_dim] - x = (x.reshape([gather_dim_bef, sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:])) - .transpose(0, 1) - .reshape([gather_dim_bef * sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:])) - .contiguous()) - output = torch.empty_like(x) - comm = dist.all_to_all_single(output, x.contiguous(), group=group, async_op=async_op) - if scatter_dim == 0: - output = torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim) - return output - -def _all_to_all(local_input: Tensor, scatter_dim: int, gather_dim: int, group: Optional[dist.ProcessGroup] = None, async_op: bool = False): - group = group or dist.group.WORLD - seq_world_size = dist.get_world_size(group) - input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] - comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op) - return torch.cat(output_list, dim=gather_dim).contiguous() - -def _all_to_all_tensor(x: Tensor, scatter_dim: int, gather_dim: int, group: dist.ProcessGroup, async_op: bool = False): - if scatter_dim <= 1 and gather_dim <= 1: - return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op) - return _all_to_all(x, scatter_dim, gather_dim, group, async_op) - -class _SeqAllToAll(torch.autograd.Function): - @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, scatter_dim: int, gather_dim: int, async_op: bool) -> Tensor: - ctx.group = group - ctx.scatter_dim = scatter_dim - ctx.gather_dim = gather_dim - ctx.async_op = async_op - return _all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op) - -def gather_seq_scatter_heads_qkv( - qkv_tensor: Tensor, seq_dim: int, unpadded_dim_size: Optional[int] = None, - restore_shape: bool = True, async_op: bool = False, group: Optional[ProcessGroup] = None -) -> Tensor: - group = group or dist.group.WORLD - if not group: return qkv_tensor - sp_world = dist.get_world_size(group) - orig_shape = qkv_tensor.shape - scatter_dim = qkv_tensor.dim() - bef_all2all_shape = list(orig_shape) - qkv_proj_dim = bef_all2all_shape[-1] - bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3] - qkv_tensor = qkv_tensor.view(bef_all2all_shape) - qkv_tensor = _SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op) - if restore_shape: - out_shape = list(orig_shape) - out_shape[seq_dim] *= sp_world - out_shape[-1] = qkv_proj_dim // sp_world - qkv_tensor = qkv_tensor.view(out_shape) - if unpadded_dim_size and unpadded_dim_size % sp_world != 0: - padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size - qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size) - return qkv_tensor - - -# --------------------------------------------------------------------------- -# ParallelKittens Implementation -# --------------------------------------------------------------------------- -@torch.no_grad() -def solution( - qkv_tensor: torch.Tensor, - seq_dim: int, - group: Optional[ProcessGroup] = None, - unpadded_dim_size: Optional[int] = None, - restore_shape: bool = True, -) -> torch.Tensor: - group = group or dist.group.WORLD - if not group: - return qkv_tensor - - sp_world = dist.get_world_size(group) - if sp_world == 1: - return qkv_tensor - - orig_shape = qkv_tensor.shape - qkv_proj_dim = orig_shape[-1] - bef_shape = list(orig_shape[:-1]) + [3, qkv_proj_dim // 3] - - # Calculate collapsed logical dimensions - B = math.prod(bef_shape[:seq_dim]) if seq_dim > 0 else 1 - S = bef_shape[seq_dim] - R = math.prod(bef_shape[seq_dim + 1 : -1]) - C = bef_shape[-1] - numel = B * S * R * C - - # The custom ThunderKittens collective is hard-coded for world_size=8 - # Fallback to pure PyTorch NCCL path if constraints are missed (e.g. padding not divisible) - if sp_world != 8 or C % sp_world != 0 or numel == 0: - return gather_seq_scatter_heads_qkv( - qkv_tensor, - seq_dim=seq_dim, - unpadded_dim_size=unpadded_dim_size, - restore_shape=restore_shape, - async_op=False, - group=group, - ) - - ext = _ensure_ext_jit() - - original_dtype = qkv_tensor.dtype - if original_dtype != torch.bfloat16: - qkv_tensor = qkv_tensor.to(torch.bfloat16) - - input_tk = get_or_create_parallel_tensor(ext, (numel,), torch.bfloat16, multicast=False) - output_tk = get_or_create_parallel_tensor(ext, (numel,), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=sp_world) - - # Coalesce directly into symmetric allocation space - input_tk.data_[:numel].copy_(qkv_tensor.contiguous().view(-1)) - - # Single-step P2P gather/scatter - ext.tk_fused_all_to_all(output_tk, input_tk, barrier_tk, B, S, R, C) - - out_bef_shape = bef_shape.copy() - out_bef_shape[seq_dim] *= sp_world - out_bef_shape[-1] //= sp_world - - out_tensor = output_tk.data_[:numel].clone().view(out_bef_shape) - - if restore_shape: - out_shape = list(orig_shape) - out_shape[seq_dim] *= sp_world - out_shape[-1] = qkv_proj_dim // sp_world - out_tensor = out_tensor.view(out_shape) - - if unpadded_dim_size and unpadded_dim_size % sp_world != 0: - padding_size = out_tensor.size(seq_dim) - unpadded_dim_size - out_tensor = _unpad_tensor(out_tensor, seq_dim, padding_size) - - return out_tensor.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/3_broadcast_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/3_broadcast_parallelkittens.py deleted file mode 100755 index 7d64b07..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/3_broadcast_parallelkittens.py +++ /dev/null @@ -1,227 +0,0 @@ -""" -Strategy: -- Use ParallelKittens `TKParallelTensor` with NVSwitch multicast to broadcast data to all ranks simultaneously in a single NVLink hop. -- The root rank uses `uint4` (16-byte) write-through stores to the multicast pointer, achieving peak device-side memory bandwidth and directly writing to all destination devices. -- Avoids host round-trips and multiple `torch.distributed` operations by synchronizing via ThunderKittens device-side `barrier_all` before and after the multicast stores, overlapping the physical broadcast across all SMs. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (broadcast entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace broadcast { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 8; // 16 bytes = 8 bf16s - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - bf16* local_ptr; - const int dev_idx; - const int src_rank; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK); - } -}; - -// Use 16-byte write-through global stores to the multicast address space -__device__ inline void st_uint4(void* ptr, const uint4& val) { - asm volatile("st.global.wt.v4.b32 [%0], {%1, %2, %3, %4};" - : - : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) - : "memory"); -} - -__device__ inline void kernel(const globals &G) { - // Only the source rank streams its local data into the NVSwitch multicast pointer - if (G.dev_idx == G.src_rank) { - const size_t idx = globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - // Data is pre-padded so `idx` will always fall within bounds - uint4 tmp = *reinterpret_cast(&G.local_ptr[idx]); - st_uint4(&G.tensor.mc_ptr[idx], tmp); - } -} - -} // namespace broadcast - -namespace broadcast_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace broadcast_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier, - int src_rank -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % broadcast::globals::NUM_ELEMS_PER_BLOCK == 0, - "The total number of tensor elements must be divisible by NUM_ELEMS_PER_BLOCK"); - - broadcast::globals broadcast_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .local_ptr = reinterpret_cast(tensor.data_.data_ptr()), - .dev_idx = tensor.local_rank_, - .src_rank = src_rank - }; - - broadcast_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(broadcast_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_broadcast", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 # NUM_WARPGROUPS(2) * WARPGROUP_WARPS(4) * WARP_THREADS(32) -NUM_ELEMS_PER_INST = 8 # float4 = 16 bytes = 8 bf16s -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_ELEMS_PER_BLOCK # 2048 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_broadcast_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - assert world == NUM_DEVICES, f"Expected {NUM_DEVICES} ranks, got {world}." - - ext = _ensure_ext_jit() - - original_shape = tensor.shape - original_dtype = tensor.dtype - - flat = tensor.to(torch.bfloat16).reshape(-1).contiguous() - n = flat.numel() - - if n == 0: - return torch.empty(original_shape, dtype=original_dtype, device=tensor.device) - - # Pad out to allow uniform fast 16-byte accesses across all blocks without bounds checking - padded = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - # Request the cached, symmetric VMM-allocated tensor - tensor_tk = get_or_create_parallel_tensor(ext, (padded,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Copy input into the VMM-allocated parallel tensor, but only on the broadcasting root - if dist.get_rank() == src: - tensor_tk.data_[:n] = flat - if padded > n: - tensor_tk.data_[n:].zero_() - - # Issue device-side broadcast: - # 1. barrier_all (syncs local write) -> 2. broadcast to multicasts -> 3. barrier_all (syncs read) - ext.tk_broadcast(tensor_tk, barrier_tk, src) - - # Harvest the received data - result = tensor_tk.data_[:n].clone() - return result.to(original_dtype).reshape(original_shape) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/40_ulysses_attention_e2e_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/40_ulysses_attention_e2e_parallelkittens.py deleted file mode 100755 index 4afa5fd..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/40_ulysses_attention_e2e_parallelkittens.py +++ /dev/null @@ -1,346 +0,0 @@ -import os -from typing import Any, Optional, Tuple - -import torch -import torch.nn.functional as F -import torch.distributed as dist -from torch import Tensor -from torch.distributed import ProcessGroup - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded ThunderKittens Kernel -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 64; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - int row_start; - int num_rows; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (num_rows / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.num_rows / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.num_rows / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.num_rows / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.num_rows / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - // Shift row block for chunking - row_block_idx += G.row_start / globals::ROW_BLOCK_SIZE; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else if constexpr (SCATTER_AXIS == 3) { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } else { - dst_dev_idx = 0; // Unused - } - - if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else if constexpr (GATHER_AXIS == 3) { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis, - int row_start, - int num_rows -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_, - .row_start = row_start, - .num_rows = num_rows - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ulysses_a2a_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _local_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: float, - causal: bool = False, -) -> torch.Tensor: - # Retained as fallback for single GPU runs - scores = torch.matmul(q, k.transpose(-2, -1)) * scale - if causal and q.size(1) > 1: - S = scores.size(-1) - causal_mask = torch.triu( - torch.ones(S, S, device=scores.device, dtype=torch.bool), - diagonal=1, - ) - scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) - attn = F.softmax(scores, dim=-1) - return torch.matmul(attn, v) - - -@torch.no_grad() -def solution( - hidden_states: torch.Tensor, - w_qkv: torch.Tensor, - w_o: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, - num_heads: int = 8, - causal: bool = False, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - - if world_size == 1: - B, S_local, H = hidden_states.shape - head_dim = H // num_heads - qkv = F.linear(hidden_states, w_qkv) - qkv = qkv.view(B, S_local, 3, num_heads, head_dim) - q, k, v = qkv.unbind(2) - scale = head_dim**-0.5 - attn_out = _local_attention(q, k, v, scale, causal=causal) - out = attn_out.reshape(B, S_local, -1) - return F.linear(out, w_o) - - B, S_local, H = hidden_states.shape - head_dim = (w_qkv.shape[0] // 3) // num_heads - chunk_size = (num_heads // world_size) * head_dim - - assert num_heads % world_size == 0, "num_heads must be divisible by world_size" - assert S_local % 16 == 0, "S_local must be a multiple of 16 for TK alignments" - assert chunk_size % 64 == 0, "chunk_size must be a multiple of 64 for TK alignments" - - ext = _ensure_ext_jit() - - # Pre-allocate TK buffers out of VMM - tk_qkv_in = get_or_create_parallel_tensor(ext, (B, 1, S_local * 3, world_size * chunk_size), torch.bfloat16, multicast=False) - tk_qkv_out = get_or_create_parallel_tensor(ext, (B, 1, S_local * world_size * 3, chunk_size), torch.bfloat16, multicast=False) - - tk_out_in = get_or_create_parallel_tensor(ext, (B, 1, S_local * world_size, chunk_size), torch.bfloat16, multicast=False) - tk_out_out = get_or_create_parallel_tensor(ext, (B, 1, S_local, world_size * chunk_size), torch.bfloat16, multicast=False) - - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - NUM_CHUNKS = 2 - if S_local % NUM_CHUNKS != 0: - NUM_CHUNKS = 1 - S_chunk = S_local // NUM_CHUNKS - - s1 = torch.cuda.current_stream() - s2 = torch.cuda.Stream() - events_compute = [torch.cuda.Event() for _ in range(NUM_CHUNKS)] - events_comm = [torch.cuda.Event() for _ in range(NUM_CHUNKS)] - - # Compute QKV efficiently overlapping with first All-to-All via sequence chunks - w_qkv_t = w_qkv.t() - tk_qkv_in_view = tk_qkv_in.data_[:B * S_local * 3 * H].view(B, S_local, -1) - - for c in range(NUM_CHUNKS): - start = c * S_chunk - end = (c + 1) * S_chunk - - # F.linear naturally outputs [B, S_local, 3 * num_heads * head_dim] mapping implicitly to TK coordinates - torch.matmul(hidden_states[:, start:end, :], w_qkv_t, out=tk_qkv_in_view[:, start:end, :]) - events_compute[c].record(s1) - - s2.wait_event(events_compute[c]) - with torch.cuda.stream(s2): - ext.tk_all_to_all(tk_qkv_out, tk_qkv_in, barrier_tk, 3, 2, start * 3, S_chunk * 3) - events_comm[c].record(s2) - - for c in range(NUM_CHUNKS): - s1.wait_event(events_comm[c]) - - # Extract fully gathered Q, K, V elements from parallel buffer. - qkv = tk_qkv_out.data_[:B * S_local * world_size * 3 * chunk_size].view(B, S_local * world_size, 3, num_heads // world_size, head_dim) - q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] - - # Inline Local Attention - scale = head_dim**-0.5 - scores = torch.matmul(q, k.transpose(-2, -1)) * scale - if causal and q.size(1) > 1: - S = scores.size(-1) - causal_mask = torch.triu( - torch.ones(S, S, device=scores.device, dtype=torch.bool), - diagonal=1, - ) - scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) - attn = F.softmax(scores, dim=-1) - - # Write attention out to next TK tensor VMM, mapping back to flattened coordinates - tk_out_in_view = tk_out_in.data_[:B * S_local * world_size * chunk_size].view(B, S_local * world_size, num_heads // world_size, head_dim) - torch.matmul(attn, v, out=tk_out_in_view) - - # Run second All-to-All gathering heads and scattering sequence segments. - ext.tk_all_to_all(tk_out_out, tk_out_in, barrier_tk, 2, 3, 0, S_local * world_size) - - # Perform output projection - out = tk_out_out.data_[:B * S_local * world_size * chunk_size].view(B, S_local, -1) - return F.linear(out, w_o) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/4_reduce_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/4_reduce_parallelkittens.py deleted file mode 100755 index 6519ac2..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/4_reduce_parallelkittens.py +++ /dev/null @@ -1,223 +0,0 @@ -""" -Strategy: -Exploiting Hopper's multimem capabilities natively provides hardware-accelerated collective routines. We utilize ThunderKittens to issue `multimem::ld_reduce` across the NVLink domain to pull and sum the data from peers, then use `multimem::st` to broadcast it back. This processes a full sum internally (similarly to an all-reduce) but avoids complex point-to-point reduce kernels, fully exploiting NVSwitch multicast/multimem hardware (which achieves full bisection bandwidth natively). To conform strictly to the `dist.reduce` spec and save memory traffic, only the destination rank (`dst`) actually formats and returns the result from the VMM buffer, leaving non-destination nodes with virtually no post-kernel processing. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - // Hardware-accelerated NVSwitch multimem load + reduce - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - // Multicast store to write the reduced chunks to all participants - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace reduce - -namespace reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace reduce_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - TORCH_CHECK(tensor.data_.numel() % (reduce::globals::NUM_DEVICES * reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - reduce::globals reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_reduce", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 # NUM_WARPGROUPS(2) * WARPGROUP_WARPS(4) * WARP_THREADS(32) -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 4096 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_reduce_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - tensor: torch.Tensor, - dst: int = 0 -) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - rank = dist.get_rank() - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - - n = tensor.numel() - if n == 0: - return tensor.clone() - - ext = _ensure_ext_jit() - - original_shape = tensor.shape - original_dtype = tensor.dtype - - flat = tensor.to(torch.bfloat16).reshape(-1).contiguous() - - # Pad to kernel alignment (NUM_DEVICES * NUM_ELEMS_PER_BLOCK) - padded = ((n + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - # Cached TKParallelTensor (VMM + multicast) — NVSwitch hardware will process this. - tensor_tk = get_or_create_parallel_tensor(ext, (padded,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Copy input into the VMM-allocated parallel tensor - tensor_tk.data_[:n] = flat - if padded > n: - tensor_tk.data_[n:].zero_() - - # Run the TK device-side reduction kernel - ext.tk_reduce(tensor_tk, barrier_tk) - - if rank == dst: - # Materialize the summed result out of VMM buffer only on the target rank - result = tensor_tk.data_[:n].clone() - return result.to(original_dtype).reshape(original_shape) - else: - # Non-destination ranks can just return their identical input shape - return tensor.clone() \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/55_ring_attention_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/55_ring_attention_parallelkittens.py deleted file mode 100755 index 60f943a..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/55_ring_attention_parallelkittens.py +++ /dev/null @@ -1,336 +0,0 @@ -import os -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for ParallelKittens TMA Fetch and Barrier -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK = 64; - static constexpr int COL_BLOCK = 64; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout k_pgl; - parallel_layout v_pgl; - - gl k_local; - gl v_local; - - int peer_rank; - - __host__ inline dim3 grid() const { - return dim3( - (k_pgl.cols() / COL_BLOCK) * - (k_pgl.rows() / ROW_BLOCK) * - k_pgl.depth() * - k_pgl.batch() - ); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) * 2 + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - - globals::shared_tile &k_tile = allocator.allocate(); - globals::shared_tile &v_tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.k_pgl.depth() * (G.k_pgl.rows() / globals::ROW_BLOCK) * (G.k_pgl.cols() / globals::COL_BLOCK)); - task_idx %= (G.k_pgl.depth() * (G.k_pgl.rows() / globals::ROW_BLOCK) * (G.k_pgl.cols() / globals::COL_BLOCK)); - - int depth_idx = task_idx / (G.k_pgl.rows() / globals::ROW_BLOCK * (G.k_pgl.cols() / globals::COL_BLOCK)); - task_idx %= (G.k_pgl.rows() / globals::ROW_BLOCK * (G.k_pgl.cols() / globals::COL_BLOCK)); - - int row_block_idx = task_idx / (G.k_pgl.cols() / globals::COL_BLOCK); - task_idx %= (G.k_pgl.cols() / globals::COL_BLOCK); - - int col_block_idx = task_idx; - - __shared__ semaphore arrived_k, arrived_v; - init_semaphore(arrived_k, 0, 1); - init_semaphore(arrived_v, 0, 1); - - tma::expect_bytes(arrived_k, sizeof(k_tile)); - tma::expect_bytes(arrived_v, sizeof(v_tile)); - - tma::load_async(k_tile, G.k_pgl[G.peer_rank], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived_k); - tma::load_async(v_tile, G.v_pgl[G.peer_rank], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived_v); - - wait(arrived_k, 0); - wait(arrived_v, 0); - - tma::store_async(G.k_local, k_tile, {batch_idx, depth_idx, row_block_idx, col_block_idx}); - tma::store_async(G.v_local, v_tile, {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -namespace tk_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} - -void tk_fetch_kv( - kittens::py::TKParallelTensor &k_pgl, - kittens::py::TKParallelTensor &v_pgl, - torch::Tensor k_local, - torch::Tensor v_local, - int peer_rank -) { - globals G { - .k_pgl = kittens::py::parallel_tensor_to_pgl(k_pgl), - .v_pgl = kittens::py::parallel_tensor_to_pgl(v_pgl), - .k_local = kittens::py::tensor_to_gl(k_local), - .v_local = kittens::py::tensor_to_gl(v_local), - .peer_rank = peer_rank - }; - kittens::py::launch_kernel(G); -} - -void tk_barrier_fn(kittens::py::TKParallelTensor &barrier) { - tk_barrier::globals G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - kittens::py::launch_kernel(G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_fetch_kv", &tk_fetch_kv); - m.def("tk_barrier", &tk_barrier_fn); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ring_fetch_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -# --------------------------------------------------------------------------- -# Torch Compiled Local Compute Kernels -# --------------------------------------------------------------------------- -@torch.compile(fullgraph=True, mode="reduce-overhead") -def _local_attn_bhsd( - qh: torch.Tensor, kh: torch.Tensor, vh: torch.Tensor, - scale: float, causal: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: - """qh, kh, vh: [B,H,S,D] -> out: [B,S,H,D], lse: [B,H,S]""" - qh = qh.float() - kh = kh.float() - vh = vh.float() - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - if causal: - mask = torch.triu(torch.ones(qh.size(2), kh.size(2), device=qh.device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - block_lse = torch.logsumexp(scores, dim=-1) - block_out = torch.matmul(torch.softmax(scores, dim=-1), vh) - return block_out.transpose(1, 2).contiguous(), block_lse - -@torch.compile(fullgraph=True, mode="reduce-overhead") -def _merge_out_lse_compiled( - out: torch.Tensor, lse: torch.Tensor, - block_out: torch.Tensor, block_lse: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - block_out = block_out.to(torch.float32) - block_lse = block_lse.transpose(-2, -1).unsqueeze(-1) - out = out - torch.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - torch.nn.functional.logsigmoid(lse - block_lse) - return out, lse - - -# --------------------------------------------------------------------------- -# Hot-Path Entrypoint -# --------------------------------------------------------------------------- -@torch.no_grad() -def solution( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale: Optional[float] = None, - causal: bool = False, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - assert world_size == 8, "ThunderKittens kernel assumes NUM_DEVICES=8" - - if softmax_scale is None: - softmax_scale = q.shape[-1] ** -0.5 - - # Re-stride to align with TK GL [B, H, S, D] layout internally, avoiding memory copies. - qh = q.transpose(1, 2).contiguous() - kh = k.transpose(1, 2).contiguous() - vh = v.transpose(1, 2).contiguous() - - B, H, S, D = kh.shape - - if world_size == 1: - block_out, block_lse = _local_attn_bhsd(qh, kh, vh, softmax_scale, causal) - out, lse = block_out.to(torch.float32), block_lse.transpose(-2, -1).unsqueeze(-1) - return out.to(q.dtype) - - ext = _ensure_ext_jit() - - # Align sequence and depth dimensions for TK 64x64 TMA chunks - pad_S = (64 - (S % 64)) % 64 - pad_D = (64 - (D % 64)) % 64 - - if pad_S > 0 or pad_D > 0: - kh_pad = F.pad(kh, (0, pad_D, 0, pad_S)) - vh_pad = F.pad(vh, (0, pad_D, 0, pad_S)) - else: - kh_pad = kh - vh_pad = vh - - # Expose local K and V physically in symmetric memory - k_tk = get_or_create_parallel_tensor(ext, kh_pad.shape, torch.bfloat16, multicast=False) - v_tk = get_or_create_parallel_tensor(ext, vh_pad.shape, torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - k_tk.data_[:kh_pad.numel()].copy_(kh_pad.view(-1)) - v_tk.data_[:vh_pad.numel()].copy_(vh_pad.view(-1)) - - ext.tk_barrier(barrier_tk) - - compute_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream() - - buf_k = [torch.empty_like(kh_pad), torch.empty_like(kh_pad)] - buf_v = [torch.empty_like(vh_pad), torch.empty_like(vh_pad)] - - out, lse = None, None - - # Overlap bootstrap: Prefetch step 1 asynchronously - if world_size > 1: - with torch.cuda.stream(comm_stream): - peer_1 = (rank - 1) % world_size - ext.tk_fetch_kv(k_tk, v_tk, buf_k[1%2], buf_v[1%2], peer_1) - - for step in range(world_size): - skip_compute = causal and (step > rank) - - if not skip_compute: - # Await the overlapped block read from peer rank - compute_stream.wait_stream(comm_stream) - - if step == 0: - cur_k_bhsd = kh - cur_v_bhsd = vh - else: - cur_k_pad = buf_k[step % 2] - cur_v_pad = buf_v[step % 2] - cur_k_bhsd = cur_k_pad[:, :, :S, :D] - cur_v_bhsd = cur_v_pad[:, :, :S, :D] - - block_out, block_lse = _local_attn_bhsd( - qh, cur_k_bhsd, cur_v_bhsd, softmax_scale, causal=(causal and step == 0) - ) - - if out is None: - out = block_out.to(torch.float32) - lse = block_lse.transpose(-2, -1).unsqueeze(-1) - else: - out, lse = _merge_out_lse_compiled(out, lse, block_out, block_lse) - - # Trigger async peer fetch for next step's blocks, hiding latency behind the next projection - if step + 1 < world_size: - next_peer = (rank - (step + 1)) % world_size - with torch.cuda.stream(comm_stream): - # Ensure the current projection matrix has fully resolved the background buffer - comm_stream.wait_stream(compute_stream) - ext.tk_fetch_kv(k_tk, v_tk, buf_k[(step + 1) % 2], buf_v[(step + 1) % 2], next_peer) - - torch.cuda.current_stream().synchronize() - - # Prevent sequential layer iterations from immediately overwriting PGL tensors another peer is reading - ext.tk_barrier(barrier_tk) - - return out.to(q.dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/56_ring_attention_tp_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/56_ring_attention_tp_parallelkittens.py deleted file mode 100755 index 0f09a45..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/56_ring_attention_tp_parallelkittens.py +++ /dev/null @@ -1,417 +0,0 @@ -import os -import math -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded ThunderKittens C++ / CUDA Source -# --------------------------------------------------------------------------- - -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include -#include - -using namespace kittens; - -// ============================================================================ -// TMA PEER COPY: Direct P2P DMA pull from symmetric memory -// ============================================================================ -namespace peer_copy { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 32; // Minimal threads; TMA offloads copy -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - int peer_idx; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - - // Asynchronous DMA pull from peer - tma::load_async(tile, G.input[G.peer_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - wait(arrived, 0); - - // Asynchronous DMA push to local buffer - tma::store_async(G.output[G.dev_idx], tile, {batch_idx, depth_idx, row_block_idx, col_block_idx}); - tma::store_commit_group(); - tma::store_async_wait(); -} -} // namespace peer_copy - -void launch_peer_copy( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - int peer_idx -) { - peer_copy::globals G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .peer_idx = peer_idx, - .dev_idx = input.local_rank_ - }; - kittens::py::launch_kernel(G); -} - -// ============================================================================ -// BARRIER: Cluster synchronization -// ============================================================================ -namespace sync_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} - -void launch_barrier(kittens::py::TKParallelTensor &barrier) { - sync_barrier::globals G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - kittens::py::launch_kernel(G); -} - -// ============================================================================ -// FUSED MERGE LSE: Numerically stable accumulation over chunks -// ============================================================================ -namespace merge_lse { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 256; -}; - -struct globals { - float* out; - float* lse; - const bf16* block_out; - const float* block_lse; - int numel_out; - int D; - int H; - int S; -}; - -__device__ inline void kernel(const globals &G) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < G.numel_out) { - int lse_idx = idx / G.D; - - float current_lse = G.lse[lse_idx]; - float new_lse = G.block_lse[lse_idx]; - - float current_out = G.out[idx]; - float new_out = __bfloat162float(G.block_out[idx]); - - bool new_is_inf = isinf(new_lse) && new_lse < 0; - bool cur_is_inf = isinf(current_lse) && current_lse < 0; - - float sig; - if (new_is_inf && cur_is_inf) { - sig = 0.0f; - } else if (cur_is_inf) { - sig = 1.0f; - } else if (new_is_inf) { - sig = 0.0f; - } else { - float diff_lse = new_lse - current_lse; - sig = 1.0f / (1.0f + expf(-diff_lse)); - } - - float updated_out = current_out - sig * (current_out - new_out); - G.out[idx] = updated_out; - - // Single thread per head-dim group updates the shared LSE value - if ((idx % G.D) == 0) { - if (new_is_inf && cur_is_inf) { - // remains -inf - } else if (cur_is_inf) { - G.lse[lse_idx] = new_lse; - } else if (new_is_inf) { - // remains current - } else { - float diff = current_lse - new_lse; - float ls = -log1pf(expf(-diff)); - G.lse[lse_idx] = current_lse - ls; - } - } - } -} -} // namespace merge_lse - -void launch_merge( - torch::Tensor out, - torch::Tensor lse, - torch::Tensor block_out, - torch::Tensor block_lse, - int D, int H, int S -) { - merge_lse::globals G { - .out = out.data_ptr(), - .lse = lse.data_ptr(), - .block_out = reinterpret_cast(block_out.data_ptr()), - .block_lse = block_lse.data_ptr(), - .numel_out = static_cast(out.numel()), - .D = D, - .H = H, - .S = S - }; - dim3 grid((G.numel_out + merge_lse::config::NUM_THREADS - 1) / merge_lse::config::NUM_THREADS); - kittens::py::launch_kernel(G, grid); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_peer_copy", &launch_peer_copy); - m.def("tk_barrier", &launch_barrier); - m.def("tk_merge", &launch_merge); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ring_attn_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - if dist.is_initialized() and dist.get_rank() == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -# --------------------------------------------------------------------------- -# Python Attention implementation interacting with TK extensions -# --------------------------------------------------------------------------- - -def _local_attn( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - scale: float, causal: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes dense local attention, returning block outputs structured for the TK merge.""" - qh = q.transpose(1, 2).float() - kh = k.transpose(1, 2).float() - vh = v.transpose(1, 2).float() - - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - if causal: - mask = torch.triu(torch.ones(q.size(1), k.size(1), device=q.device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - - block_lse = torch.logsumexp(scores, dim=-1) - block_out = torch.matmul(torch.softmax(scores, dim=-1), vh).transpose(1, 2).contiguous() - - # Transpose and guarantee contiguous memory structure for predictable TK merge addressing - block_lse = block_lse.transpose(-2, -1).contiguous() - return block_out.to(torch.bfloat16), block_lse - - -def solution( - hidden_states: torch.Tensor, - w_qkv: torch.Tensor, - w_o: torch.Tensor, - num_heads: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - tp_group: Optional[dist.ProcessGroup] = None, - cp_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - tp_group = tp_group or dist.group.WORLD - cp_group = cp_group or dist.group.WORLD - - tp_size = dist.get_world_size(tp_group) - cp_size = dist.get_world_size(cp_group) - heads_local = num_heads // tp_size - head_dim = w_qkv.shape[0] // 3 // heads_local - if softmax_scale is None: - softmax_scale = head_dim ** -0.5 - - B, S = hidden_states.shape[:2] - qkv = F.linear(hidden_states, w_qkv).view(B, S, 3, heads_local, head_dim) - q, k, v = qkv.unbind(dim=2) - - # Fast-path for single CP rank - if cp_size == 1: - block_out, _ = _local_attn(q.contiguous(), k.contiguous(), v.contiguous(), float(softmax_scale), causal) - out = block_out.to(q.dtype) - out = F.linear(out.reshape(B, S, -1), w_o) - if tp_size > 1: - dist.all_reduce(out, op=dist.ReduceOp.SUM, group=tp_group) - return out - - ext = _ensure_ext_jit() - q, k, v = q.contiguous(), k.contiguous(), v.contiguous() - - # Determine TMA optimal padding for Parallel Tensors - flat_k = k.view(-1) - flat_v = v.view(-1) - n = flat_k.numel() - - TILE_ELEMS = 16 * 128 - num_tiles = (n + TILE_ELEMS - 1) // TILE_ELEMS - r, c = 16, 128 * num_tiles - - # Static wrappers holding current rank's source data (peers pull from these) - k_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - v_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - - # Double-buffers mapped natively on device to pipeline peer fetching - f1_k_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - f1_v_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - f2_k_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - f2_v_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - - barrier_tk = get_or_create_barrier(ext, num_devices=8) - - # Initialize symmetric buffers and globally barrier - k_tk.data_.view(-1)[:n].copy_(flat_k) - v_tk.data_.view(-1)[:n].copy_(flat_v) - ext.tk_barrier(barrier_tk) - - # Persistent output tracking allocations - out = torch.zeros((B, S, heads_local, head_dim), dtype=torch.float32, device=q.device) - lse = torch.full((B, S, heads_local), float("-inf"), dtype=torch.float32, device=q.device) - - copy_stream = torch.cuda.Stream() - local_cp = dist.get_rank(cp_group) - local_tp = dist.get_rank(tp_group) - - # Issue initial background fetch for step 1 - if cp_size > 1: - peer_cp = (local_cp - 1 + cp_size) % cp_size - peer_global = peer_cp * tp_size + local_tp - with torch.cuda.stream(copy_stream): - ext.tk_peer_copy(f1_k_tk, k_tk, peer_global) - ext.tk_peer_copy(f1_v_tk, v_tk, peer_global) - - # Process local step 0 - if (not causal) or 0 <= local_cp: - block_out, block_lse = _local_attn(q, k, v, float(softmax_scale), causal=causal) - ext.tk_merge(out, lse, block_out, block_lse, head_dim, heads_local, S) - - # Pipelined schedule over peers - for step in range(1, cp_size): - torch.cuda.current_stream().wait_stream(copy_stream) - - # Unwrap current step's buffers correctly - cur_k = f1_k_tk.data_.view(-1)[:n].view(B, S, heads_local, head_dim) - cur_v = f1_v_tk.data_.view(-1)[:n].view(B, S, heads_local, head_dim) - - # Launch background fetch for step n+1 into alternating buffer - if step + 1 < cp_size: - peer_cp = (local_cp - step - 1 + cp_size) % cp_size - peer_global = peer_cp * tp_size + local_tp - with torch.cuda.stream(copy_stream): - ext.tk_peer_copy(f2_k_tk, k_tk, peer_global) - ext.tk_peer_copy(f2_v_tk, v_tk, peer_global) - - # Execute attention on fetched KV and immediately device-fuse to accumulators - if (not causal) or step <= local_cp: - block_out, block_lse = _local_attn(q, cur_k, cur_v, float(softmax_scale), causal=False) - ext.tk_merge(out, lse, block_out, block_lse, head_dim, heads_local, S) - - # Swap buffers for the next cycle - f1_k_tk, f2_k_tk = f2_k_tk, f1_k_tk - f1_v_tk, f2_v_tk = f2_v_tk, f1_v_tk - - # Finalize Out -> Row Parallel Projection -> TP sum - out = out.to(q.dtype) - out = F.linear(out.reshape(B, S, -1), w_o) - if tp_size > 1: - dist.all_reduce(out, op=dist.ReduceOp.SUM, group=tp_group) - - return out \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/57_ring_attention_pp_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/57_ring_attention_pp_parallelkittens.py deleted file mode 100755 index 8b059a2..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/57_ring_attention_pp_parallelkittens.py +++ /dev/null @@ -1,351 +0,0 @@ -import os -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import get_or_create_parallel_tensor - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source: TMA KV Fetch + Fused LSE Merge -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include -#include - -using namespace kittens; - -namespace tma_fetch { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; - }; - - template - struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = D; - - using tile_t = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - using local_layout = gl; - - parallel_layout K_pgl; - parallel_layout V_pgl; - local_layout next_K; - local_layout next_V; - int peer_idx; - - __host__ inline dim3 grid() const { - return dim3(next_K.batch() * next_K.depth() * (next_K.rows() / ROW_BLOCK_SIZE)); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(2 * sizeof(tile_t) + 1024); - } - }; - - template - __device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - typename globals::tile_t &k_tile = allocator.allocate::tile_t>(); - typename globals::tile_t &v_tile = allocator.allocate::tile_t>(); - - int task_idx = blockIdx.x; - int c_idx = 0; // col dimension is exactly D, handled by 1 block - int r_idx = task_idx % (G.next_K.rows() / globals::ROW_BLOCK_SIZE); task_idx /= (G.next_K.rows() / globals::ROW_BLOCK_SIZE); - int d_idx = task_idx % G.next_K.depth(); task_idx /= G.next_K.depth(); - int b_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - - tma::expect_bytes(arrived, sizeof(k_tile) + sizeof(v_tile)); - tma::load_async(k_tile, G.K_pgl[G.peer_idx], {b_idx, d_idx, r_idx, c_idx}, arrived); - tma::load_async(v_tile, G.V_pgl[G.peer_idx], {b_idx, d_idx, r_idx, c_idx}, arrived); - - wait(arrived, 0); - - tma::store_async(G.next_K, k_tile, {b_idx, d_idx, r_idx, c_idx}); - tma::store_async(G.next_V, v_tile, {b_idx, d_idx, r_idx, c_idx}); - } -} - -__global__ void merge_lse_kernel( - float* __restrict__ out, - const float* __restrict__ lse, - const __nv_bfloat16* __restrict__ block_out, - const float* __restrict__ block_lse, - int num_elements, - int head_dim, - int seq_len, - int num_heads -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < num_elements) { - int d_idx = idx % head_dim; - int tmp = idx / head_dim; - int h_idx = tmp % num_heads; - tmp = tmp / num_heads; - int s_idx = tmp % seq_len; - int b_idx = tmp / seq_len; - - int lse_idx = (b_idx * num_heads + h_idx) * seq_len + s_idx; - - float bo_val = __bfloat162float(block_out[idx]); - float bl_val = block_lse[lse_idx]; - - float o_val = out[idx]; - float l_val = lse[lse_idx]; - - float diff = bl_val - l_val; - float sig = 1.0f / (1.0f + expf(-diff)); - float new_o = o_val - sig * (o_val - bo_val); - - out[idx] = new_o; - } -} - -template -void launch_fetch_impl( - kittens::py::TKParallelTensor &K_pgl, - kittens::py::TKParallelTensor &V_pgl, - torch::Tensor next_K, - torch::Tensor next_V, - int peer_idx -) { - auto next_K_tk = kittens::py::tensor_to_gl::local_layout>(next_K); - auto next_V_tk = kittens::py::tensor_to_gl::local_layout>(next_V); - - tma_fetch::globals G { - .K_pgl = kittens::py::parallel_tensor_to_pgl::parallel_layout>(K_pgl), - .V_pgl = kittens::py::parallel_tensor_to_pgl::parallel_layout>(V_pgl), - .next_K = next_K_tk, - .next_V = next_V_tk, - .peer_idx = peer_idx - }; - kittens::py::launch_kernel, tma_fetch::kernel>(G); -} - -void launch_fetch( - kittens::py::TKParallelTensor &K_pgl, - kittens::py::TKParallelTensor &V_pgl, - torch::Tensor next_K, - torch::Tensor next_V, - int peer_idx, - int head_dim -) { - if (head_dim == 64) launch_fetch_impl<64>(K_pgl, V_pgl, next_K, next_V, peer_idx); - else if (head_dim == 128) launch_fetch_impl<128>(K_pgl, V_pgl, next_K, next_V, peer_idx); - else TORCH_CHECK(false, "head_dim must be 64 or 128 for optimized TMA fetch"); -} - -void launch_merge( - torch::Tensor out, - torch::Tensor lse, - torch::Tensor block_out, - torch::Tensor block_lse -) { - int num_elements = out.numel(); - int head_dim = out.size(3); - int num_heads = out.size(2); - int seq_len = out.size(1); - - int threads = 256; - int blocks = (num_elements + threads - 1) / threads; - - merge_lse_kernel<<>>( - out.data_ptr(), - lse.data_ptr(), - reinterpret_cast(block_out.data_ptr()), - block_lse.data_ptr(), - num_elements, - head_dim, - seq_len, - num_heads - ); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("launch_fetch", &launch_fetch); - m.def("launch_merge", &launch_merge); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", "--use_fast_math", "--expt-extended-lambda", "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", "-Xcompiler=-fno-strict-aliasing", "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ring_attn_ext", CUDA_SRC, extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[os.path.join(TK_ROOT, "include"), os.path.join(TK_ROOT, "prototype")], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -# --------------------------------------------------------------------------- -# Torch + TK Python Interop -# --------------------------------------------------------------------------- - -def _local_attn_math(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, causal: bool) -> Tuple[torch.Tensor, torch.Tensor]: - qh = q.transpose(1, 2).float() - kh = k.transpose(1, 2).float() - vh = v.transpose(1, 2).float() - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - if causal: - mask = torch.triu(torch.ones(q.size(1), k.size(1), device=q.device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - block_lse = torch.logsumexp(scores, dim=-1) - block_out = torch.matmul(torch.softmax(scores, dim=-1), vh).transpose(1, 2).contiguous() - return block_out, block_lse - - -def _pp_recv_forward(pp_group: dist.ProcessGroup, shape: Tuple[int, ...], dtype: torch.dtype, device: torch.device) -> torch.Tensor: - prev_rank = dist.get_global_rank(pp_group, (dist.get_rank(pp_group) - 1) % dist.get_world_size(pp_group)) - buf = torch.empty(shape, dtype=dtype, device=device) - dist.irecv(buf, prev_rank, group=pp_group).wait() - return buf - - -def _pp_send_forward(pp_group: dist.ProcessGroup, tensor: torch.Tensor) -> None: - next_rank = dist.get_global_rank(pp_group, (dist.get_rank(pp_group) + 1) % dist.get_world_size(pp_group)) - dist.isend(tensor.contiguous(), next_rank, group=pp_group).wait() - - -def solution( - hidden_states: torch.Tensor, - w_qkv: torch.Tensor, - w_o: torch.Tensor, - num_heads: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - cp_group: Optional[dist.ProcessGroup] = None, - pp_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - ext = _ensure_ext_jit() - cp_group = cp_group or dist.group.WORLD - - is_first, is_last = True, True - if pp_group is not None and dist.get_world_size(pp_group) > 1: - pp_rank, pp_size = dist.get_rank(pp_group), dist.get_world_size(pp_group) - is_first, is_last = (pp_rank == 0), (pp_rank == pp_size - 1) - - # 1. Pipeline-parallel step boundary - stage_input = hidden_states if is_first else _pp_recv_forward( - pp_group, tuple(hidden_states.shape), hidden_states.dtype, hidden_states.device - ) - - # 2. Extract Q, K, V - B, S, D_hidden = stage_input.shape - head_dim = w_qkv.shape[0] // 3 // num_heads - scale = float(softmax_scale if softmax_scale is not None else head_dim ** -0.5) - - qkv = F.linear(stage_input, w_qkv).view(B, S, 3, num_heads, head_dim) - q, k, v = qkv.unbind(dim=2) - q, k, v = q.contiguous(), k.contiguous(), v.contiguous() - - # 3. Context-Parallel Ring Attention Setup - cp_rank = dist.get_rank(cp_group) - cp_size = dist.get_world_size(cp_group) - - if cp_size == 1: - block_out, _ = _local_attn_math(q, k, v, scale, causal) - stage_output = F.linear(block_out.to(q.dtype).reshape(B, S, -1), w_o) - else: - # Pre-pad to 16 for TK blocks - padded_s = ((S + 15) // 16) * 16 - - # TK Symmetric Memory mapping for TMA peer fetch - K_tk = get_or_create_parallel_tensor(ext, (B, num_heads, padded_s, head_dim), torch.bfloat16, multicast=False) - V_tk = get_or_create_parallel_tensor(ext, (B, num_heads, padded_s, head_dim), torch.bfloat16, multicast=False) - - k_tr = k.transpose(1, 2).contiguous() - v_tr = v.transpose(1, 2).contiguous() - - K_tk.data_[:, :, :S, :].copy_(k_tr) - V_tk.data_[:, :, :S, :].copy_(v_tr) - - # Ensure all peers have flushed memory - dist.barrier(cp_group) - - fetch_stream = torch.cuda.Stream() - out, lse = None, None - k_current, v_current = k_tr, v_tr - - for step in range(cp_size): - next_k, next_v = None, None - - # Initiate Async TMA Fetch of the NEXT chunk - if step + 1 != cp_size: - peer_cp_rank = (cp_rank - (step + 1)) % cp_size - peer_global_rank = dist.get_global_rank(cp_group, peer_cp_rank) - - next_k_full = torch.empty((B, num_heads, padded_s, head_dim), dtype=torch.bfloat16, device=q.device) - next_v_full = torch.empty((B, num_heads, padded_s, head_dim), dtype=torch.bfloat16, device=q.device) - - with torch.cuda.stream(fetch_stream): - ext.launch_fetch(K_tk, V_tk, next_k_full, next_v_full, peer_global_rank % 8, head_dim) - - next_k = next_k_full[:, :, :S, :] - next_v = next_v_full[:, :, :S, :] - - # Overlapped Math compute & Merge - if (not causal) or step <= cp_rank: - k_loc = k_current.transpose(1, 2) - v_loc = v_current.transpose(1, 2) - block_out, block_lse = _local_attn_math(q, k_loc, v_loc, scale, causal=(causal and step == 0)) - - if out is None: - out = block_out.to(torch.float32).clone() - lse = block_lse.to(torch.float32).clone() - else: - ext.launch_merge(out, lse, block_out, block_lse) - lse.copy_(lse - F.logsigmoid(lse - block_lse)) - - # Synchronize background stream before next iteration - if step + 1 != cp_size: - torch.cuda.current_stream().wait_stream(fetch_stream) - k_current, v_current = next_k, next_v - - stage_output = F.linear(out.to(q.dtype).reshape(B, S, -1), w_o) - - # 4. Pipeline-parallel step boundary - if not is_last and pp_group is not None: - _pp_send_forward(pp_group, stage_output) - - return stage_output \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/58_ring_attention_backward_dp_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/58_ring_attention_backward_dp_parallelkittens.py deleted file mode 100755 index 4f9ba3f..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/58_ring_attention_backward_dp_parallelkittens.py +++ /dev/null @@ -1,418 +0,0 @@ -import os -import math -from typing import Optional, Tuple - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -# --------------------------------------------------------------------------- -# Embedded .cu source for ThunderKittens kernels -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace shared { - struct barrier_config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; - }; - struct barrier_globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; - }; - __device__ inline void barrier_kernel(const barrier_globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); - } -} - -namespace shift_all { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - using layout = pgl, NUM_DEVICES, false>; - layout k_in, v_in, dk_in, dv_in; - layout k_out, v_out, dk_out, dv_out; - int dev_idx; - int cp_size; - int numel; - __host__ inline dim3 grid() const { return dim3((numel + config::NUM_THREADS * 2 - 1) / (config::NUM_THREADS * 2)); } - }; - __device__ inline void kernel(const globals &G) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx * 2 < G.numel) { - int cp_rank = G.dev_idx % G.cp_size; - int dp_rank = G.dev_idx / G.cp_size; - int prev_dev = dp_rank * G.cp_size + (cp_rank - 1 + G.cp_size) % G.cp_size; - - *(int*)(&G.k_out[G.dev_idx].data[idx * 2]) = *(int*)(&G.k_in[prev_dev].data[idx * 2]); - *(int*)(&G.v_out[G.dev_idx].data[idx * 2]) = *(int*)(&G.v_in[prev_dev].data[idx * 2]); - *(int*)(&G.dk_out[G.dev_idx].data[idx * 2]) = *(int*)(&G.dk_in[prev_dev].data[idx * 2]); - *(int*)(&G.dv_out[G.dev_idx].data[idx * 2]) = *(int*)(&G.dv_in[prev_dev].data[idx * 2]); - } - } -} - -namespace shift_dkv { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - using layout = pgl, NUM_DEVICES, false>; - layout dk_in, dv_in; - layout dk_out, dv_out; - int dev_idx; - int cp_size; - int numel; - __host__ inline dim3 grid() const { return dim3((numel + config::NUM_THREADS * 2 - 1) / (config::NUM_THREADS * 2)); } - }; - __device__ inline void kernel(const globals &G) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx * 2 < G.numel) { - int cp_rank = G.dev_idx % G.cp_size; - int dp_rank = G.dev_idx / G.cp_size; - int prev_dev = dp_rank * G.cp_size + (cp_rank - 1 + G.cp_size) % G.cp_size; - - *(int*)(&G.dk_out[G.dev_idx].data[idx * 2]) = *(int*)(&G.dk_in[prev_dev].data[idx * 2]); - *(int*)(&G.dv_out[G.dev_idx].data[idx * 2]) = *(int*)(&G.dv_in[prev_dev].data[idx * 2]); - } - } -} - -namespace dp_all_reduce { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - using layout = pgl, NUM_DEVICES, false>; - layout dq, dk, dv; - layout dq_out, dk_out, dv_out; - int dev_idx; - int cp_size; - int dp_size; - int numel; - __host__ inline dim3 grid() const { return dim3((numel + config::NUM_THREADS * 2 - 1) / (config::NUM_THREADS * 2)); } - }; - __device__ inline void kernel(const globals &G) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx * 2 < G.numel) { - int cp_rank = G.dev_idx % G.cp_size; - - float2 sum_dq = {0.f, 0.f}; - float2 sum_dk = {0.f, 0.f}; - float2 sum_dv = {0.f, 0.f}; - - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - for (int d = 0; d < G.dp_size; ++d) { - int peer = d * G.cp_size + cp_rank; - - bf16_2 val_dq = *(bf16_2*)(&G.dq[peer].data[idx * 2]); - float2 f_dq = __bfloat1622float2(val_dq); - sum_dq.x += f_dq.x; sum_dq.y += f_dq.y; - - bf16_2 val_dk = *(bf16_2*)(&G.dk[peer].data[idx * 2]); - float2 f_dk = __bfloat1622float2(val_dk); - sum_dk.x += f_dk.x; sum_dk.y += f_dk.y; - - bf16_2 val_dv = *(bf16_2*)(&G.dv[peer].data[idx * 2]); - float2 f_dv = __bfloat1622float2(val_dv); - sum_dv.x += f_dv.x; sum_dv.y += f_dv.y; - } - - sum_dq.x /= G.dp_size; sum_dq.y /= G.dp_size; - sum_dk.x /= G.dp_size; sum_dk.y /= G.dp_size; - sum_dv.x /= G.dp_size; sum_dv.y /= G.dp_size; - - *(bf16_2*)(&G.dq_out[G.dev_idx].data[idx * 2]) = __float22bfloat162_rn(sum_dq); - *(bf16_2*)(&G.dk_out[G.dev_idx].data[idx * 2]) = __float22bfloat162_rn(sum_dk); - *(bf16_2*)(&G.dv_out[G.dev_idx].data[idx * 2]) = __float22bfloat162_rn(sum_dv); - #endif - } - } -} - -// Host entrypoints -void tk_barrier(kittens::py::TKParallelTensor &barrier) { - shared::barrier_globals bg { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - kittens::py::launch_kernel(bg); -} - -void tk_shift_all( - kittens::py::TKParallelTensor &k_out, kittens::py::TKParallelTensor &v_out, kittens::py::TKParallelTensor &dk_out, kittens::py::TKParallelTensor &dv_out, - kittens::py::TKParallelTensor &k_in, kittens::py::TKParallelTensor &v_in, kittens::py::TKParallelTensor &dk_in, kittens::py::TKParallelTensor &dv_in, - int cp_size, int actual_numel -) { - shift_all::globals g { - .k_in = kittens::py::parallel_tensor_to_pgl(k_in), - .v_in = kittens::py::parallel_tensor_to_pgl(v_in), - .dk_in = kittens::py::parallel_tensor_to_pgl(dk_in), - .dv_in = kittens::py::parallel_tensor_to_pgl(dv_in), - .k_out = kittens::py::parallel_tensor_to_pgl(k_out), - .v_out = kittens::py::parallel_tensor_to_pgl(v_out), - .dk_out = kittens::py::parallel_tensor_to_pgl(dk_out), - .dv_out = kittens::py::parallel_tensor_to_pgl(dv_out), - .dev_idx = k_in.local_rank_, - .cp_size = cp_size, - .numel = actual_numel - }; - kittens::py::launch_kernel(g); -} - -void tk_shift_dkv( - kittens::py::TKParallelTensor &dk_out, kittens::py::TKParallelTensor &dv_out, - kittens::py::TKParallelTensor &dk_in, kittens::py::TKParallelTensor &dv_in, - int cp_size, int actual_numel -) { - shift_dkv::globals g { - .dk_in = kittens::py::parallel_tensor_to_pgl(dk_in), - .dv_in = kittens::py::parallel_tensor_to_pgl(dv_in), - .dk_out = kittens::py::parallel_tensor_to_pgl(dk_out), - .dv_out = kittens::py::parallel_tensor_to_pgl(dv_out), - .dev_idx = dk_in.local_rank_, - .cp_size = cp_size, - .numel = actual_numel - }; - kittens::py::launch_kernel(g); -} - -void tk_dp_all_reduce( - kittens::py::TKParallelTensor &dq, kittens::py::TKParallelTensor &dk, kittens::py::TKParallelTensor &dv, - kittens::py::TKParallelTensor &dq_out, kittens::py::TKParallelTensor &dk_out, kittens::py::TKParallelTensor &dv_out, - int cp_size, int dp_size, int actual_numel -) { - dp_all_reduce::globals g { - .dq = kittens::py::parallel_tensor_to_pgl(dq), - .dk = kittens::py::parallel_tensor_to_pgl(dk), - .dv = kittens::py::parallel_tensor_to_pgl(dv), - .dq_out = kittens::py::parallel_tensor_to_pgl(dq_out), - .dk_out = kittens::py::parallel_tensor_to_pgl(dk_out), - .dv_out = kittens::py::parallel_tensor_to_pgl(dv_out), - .dev_idx = dq.local_rank_, - .cp_size = cp_size, - .dp_size = dp_size, - .numel = actual_numel - }; - kittens::py::launch_kernel(g); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_barrier", &tk_barrier); - m.def("tk_shift_all", &tk_shift_all); - m.def("tk_shift_dkv", &tk_shift_dkv); - m.def("tk_dp_all_reduce", &tk_dp_all_reduce); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", "--use_fast_math", "--expt-extended-lambda", "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", "-Xcompiler=-fno-strict-aliasing", "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_ring_attn_bwd", CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens"), "include"), - os.path.join(os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens"), "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - -def _local_attn_backward( - dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor, softmax_lse: torch.Tensor, - scale: float, causal: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - qh = q.transpose(1, 2).float() - kh = k.transpose(1, 2).float() - vh = v.transpose(1, 2).float() - doh = dout.transpose(1, 2).float() - outh = out.transpose(1, 2).float() - - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - if causal: - sq, sk = q.size(1), k.size(1) - mask = torch.triu(torch.ones(sq, sk, device=q.device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - - probs = torch.exp(scores - softmax_lse) - dP = torch.matmul(doh, vh.transpose(-2, -1)) - row_dot = (doh * outh).sum(dim=-1, keepdim=True) - dS = probs * (dP - row_dot) - - dQ = torch.matmul(dS, kh) * scale - dK = torch.matmul(dS.transpose(-2, -1), qh) * scale - dV = torch.matmul(probs.transpose(-2, -1), doh) - - return ( - dQ.transpose(1, 2).contiguous(), - dK.transpose(1, 2).contiguous(), - dV.transpose(1, 2).contiguous(), - ) - -def alloc_pair(ext, padded_size, base_offset): - # Differentiate shape allocations to circumvent caching identical underlying buffers - t0 = get_or_create_parallel_tensor(ext, (padded_size + base_offset,), torch.bfloat16, False) - t1 = get_or_create_parallel_tensor(ext, (padded_size + base_offset + 1,), torch.bfloat16, False) - return [t0, t1] - -@torch.no_grad() -def solution( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - softmax_scale: Optional[float] = None, - causal: bool = False, - cp_group: Optional[dist.ProcessGroup] = None, - dp_group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - ext = _ensure_ext_jit() - - world = dist.get_world_size() - assert world == 8, "This ThunderKittens integration expects an 8-GPU domain (NUM_DEVICES=8)." - - cp_group = cp_group or dist.group.WORLD - cp_size = dist.get_world_size(cp_group) - cp_rank = dist.get_rank(cp_group) - - if softmax_scale is None: - softmax_scale = q.shape[-1] ** -0.5 - - actual_numel = q.numel() - padded = ((actual_numel + 511) // 512) * 512 - shape = q.shape - - # Pre-allocate double buffers to safely fuse computation + pipeline shifts. - tk_k = alloc_pair(ext, padded, 10) - tk_v = alloc_pair(ext, padded, 20) - tk_dk = alloc_pair(ext, padded, 30) - tk_dv = alloc_pair(ext, padded, 40) - - tk_dq = get_or_create_parallel_tensor(ext, (padded + 50,), torch.bfloat16, False) - - barrier_tk = get_or_create_barrier(ext, num_devices=8) - - # Initialize 0th index with current rank's K and V - tk_k[0].data_[:actual_numel].view(shape).copy_(k) - tk_v[0].data_[:actual_numel].view(shape).copy_(v) - - lse_4d = softmax_lse.unsqueeze(-1) - - for step in range(cp_size): - cur_idx = step % 2 - nxt_idx = (step + 1) % 2 - - v_k_cur = tk_k[cur_idx].data_[:actual_numel].view(shape) - v_v_cur = tk_v[cur_idx].data_[:actual_numel].view(shape) - v_dk_cur = tk_dk[cur_idx].data_[:actual_numel].view(shape) - v_dv_cur = tk_dv[cur_idx].data_[:actual_numel].view(shape) - v_dq = tk_dq.data_[:actual_numel].view(shape) - - if step <= cp_rank or not causal: - block_dq, block_dk, block_dv = _local_attn_backward( - dout, q, v_k_cur, v_v_cur, out, lse_4d, float(softmax_scale), causal=(causal and step == 0) - ) - if step == 0: - v_dq.copy_(block_dq) - v_dk_cur.copy_(block_dk) - v_dv_cur.copy_(block_dv) - else: - v_dq.add_(block_dq) - # Adds directly into in-place TK buffer (accumulating received gradients + local computed) - v_dk_cur.copy_(block_dk + v_dk_cur) - v_dv_cur.copy_(block_dv + v_dv_cur) - - ext.tk_barrier(barrier_tk) - - # P2P rotate buffers to adjacent peer logic completely device-side - if step + 1 != cp_size: - ext.tk_shift_all( - tk_k[nxt_idx], tk_v[nxt_idx], tk_dk[nxt_idx], tk_dv[nxt_idx], - tk_k[cur_idx], tk_v[cur_idx], tk_dk[cur_idx], tk_dv[cur_idx], - cp_size, actual_numel - ) - else: - ext.tk_shift_dkv( - tk_dk[nxt_idx], tk_dv[nxt_idx], - tk_dk[cur_idx], tk_dv[cur_idx], - cp_size, actual_numel - ) - - ext.tk_barrier(barrier_tk) - - final_idx = cp_size % 2 - dp_size = dist.get_world_size(dp_group) if dp_group is not None else 1 - - if dp_size > 1: - tk_dq_out = get_or_create_parallel_tensor(ext, (padded + 60,), torch.bfloat16, False) - tk_dk_out = get_or_create_parallel_tensor(ext, (padded + 70,), torch.bfloat16, False) - tk_dv_out = get_or_create_parallel_tensor(ext, (padded + 80,), torch.bfloat16, False) - - ext.tk_barrier(barrier_tk) - ext.tk_dp_all_reduce( - tk_dq, tk_dk[final_idx], tk_dv[final_idx], - tk_dq_out, tk_dk_out, tk_dv_out, - cp_size, dp_size, actual_numel - ) - ext.tk_barrier(barrier_tk) - - final_dq = tk_dq_out.data_[:actual_numel].view(shape).clone() - final_dk = tk_dk_out.data_[:actual_numel].view(shape).clone() - final_dv = tk_dv_out.data_[:actual_numel].view(shape).clone() - else: - final_dq = tk_dq.data_[:actual_numel].view(shape).clone() - final_dk = tk_dk[final_idx].data_[:actual_numel].view(shape).clone() - final_dv = tk_dv[final_idx].data_[:actual_numel].view(shape).clone() - - return final_dq, final_dk, final_dv \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/59_openclip_contrastive_loss_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/59_openclip_contrastive_loss_parallelkittens.py deleted file mode 100755 index f0b8aff..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/59_openclip_contrastive_loss_parallelkittens.py +++ /dev/null @@ -1,339 +0,0 @@ -import os -from typing import Optional - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (Fused SigLIP with Peer TMA + Barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace siglip { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 4; - static constexpr int NUM_WARPGROUPS = 2; // 8 warps = 256 threads - static constexpr int NUM_THREADS = NUM_WARPGROUPS * WARPGROUP_WARPS * WARP_THREADS; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - int B; - int flat_D; - float scale; - float bias; - - using shared_tile = st_bf<64, 64>; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout image; - parallel_layout text; - float* loss_out; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((B + 63) / 64, (B + 63) / 64, NUM_DEVICES); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(2 * sizeof(shared_tile) + sizeof(st_fl<64, 64>) + 1024); - } -}; - -__device__ inline float logsigmoid_neg(float x) { - // computes -logsigmoid(x) = log(1 + exp(-x)) - float abs_x = fabsf(x); - float val = log1pf(expf(-abs_x)); - return (x < 0.0f) ? (-x + val) : val; -} - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - - globals::shared_tile &a_smem = allocator.allocate(); - globals::shared_tile &b_smem = allocator.allocate(); - st_fl<64, 64> &c_smem = allocator.allocate>(); - - int row_idx = blockIdx.x; - int col_idx = blockIdx.y; - int target_dev = blockIdx.z; - - int num_k_blocks = G.flat_D / 64; - - rt_fl<64, 64> acc; - zero(acc); - - __shared__ semaphore arrived; - - for (int k = 0; k < num_k_blocks; ++k) { - __syncthreads(); // Ensure smem is consumed from previous loop before overwriting - if (threadIdx.x == 0) { - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, 2 * sizeof(globals::shared_tile)); - tma::load_async(a_smem, G.image[G.dev_idx], {0, 0, row_idx, k}, arrived); - tma::load_async(b_smem, G.text[target_dev], {0, 0, col_idx, k}, arrived); - } - __syncthreads(); // Ensure initialized semaphore is visible to all - wait(arrived, 0); - - rt_bf<64, 64> a_reg; - rt_bf<64, 64> b_reg; - load(a_reg, a_smem); - load(b_reg, b_smem); - - mma_ABt(acc, a_reg, b_reg, acc); - } - - store(c_smem, acc); - __syncthreads(); - - float block_loss = 0.0f; - float* c_ptr = (float*)&c_smem; - - // Element-wise log-sigmoid loss - for (int i = threadIdx.x; i < 4096; i += blockDim.x) { - int r = i / 64; - int c = i % 64; - - int global_r = row_idx * 64 + r; - int global_c = col_idx * 64 + c; - - if (global_r < G.B && global_c < G.B) { - float val = c_ptr[i]; - float logits = G.scale * val + G.bias; - // Diagonals of local device match are positive pairs - float label = (target_dev == G.dev_idx && global_r == global_c) ? 1.0f : -1.0f; - - float x = label * logits; - block_loss += logsigmoid_neg(x); - } - } - - // Warp and Block reduction - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - block_loss += __shfl_down_sync(0xffffffff, block_loss, offset); - } - - __shared__ float warp_sums[8]; - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - - if (lane_id == 0) { - warp_sums[warp_id] = block_loss; - } - __syncthreads(); - - if (warp_id == 0) { - float val = (lane_id < (blockDim.x / 32)) ? warp_sums[lane_id] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (lane_id == 0) { - int block_id = (blockIdx.x * gridDim.y + blockIdx.y) * gridDim.z + blockIdx.z; - G.loss_out[block_id] = val; - } - } -} - -} // namespace siglip - - -namespace siglip_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace siglip_barrier - - -void entrypoint( - kittens::py::TKParallelTensor &image, - kittens::py::TKParallelTensor &text, - kittens::py::TKParallelTensor &loss_out, - kittens::py::TKParallelTensor &barrier, - int B, - int flat_D, - float scale, - float bias -) { - kittens::py::parallel_tensor_check(image, text, loss_out, barrier); - - siglip::globals siglip_G { - .B = B, - .flat_D = flat_D, - .scale = scale, - .bias = bias, - .image = kittens::py::parallel_tensor_to_pgl(image), - .text = kittens::py::parallel_tensor_to_pgl(text), - .loss_out = reinterpret_cast(loss_out.data_.data_ptr()), - .dev_idx = image.local_rank_ - }; - - siglip_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Global barrier to guarantee all data has safely reached PGL peers before reads. - kittens::py::launch_kernel(barrier_G); - - // Cross-rank local TMA loads -> Fused Compute -> Reduction logic - kittens::py::launch_kernel(siglip_G); - - // Barrier to ensure TMA loads are done before tensors are freed/rewritten in dynamic models. - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_siglip", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_siglip_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - image_features: torch.Tensor, - text_features: torch.Tensor, - logit_scale: float, - logit_bias: float = 0.0, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - """ - Per-rank SigLIP loss with text features fully distributed via ThunderKittens TMA reads. - Replaces host-driven O(N) bidir loop with direct NVLink peer memory accesses. - """ - assert image_features.is_cuda and image_features.is_contiguous() - assert text_features.is_cuda and text_features.is_contiguous() - - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - assert world_size == NUM_DEVICES, f"This kernel is fixed for NUM_DEVICES={NUM_DEVICES}" - - ext = _ensure_ext_jit() - - B, D = image_features.shape - - # Pad to strictly TK supported multiples for memory tiling - flat_B = ((B + 63) // 64) * 64 - flat_D = ((D + 63) // 64) * 64 - - inp_image = torch.zeros((1, 1, flat_B, flat_D), dtype=torch.bfloat16, device=image_features.device) - inp_image[0, 0, :B, :D] = image_features.to(torch.bfloat16) - - inp_text = torch.zeros((1, 1, flat_B, flat_D), dtype=torch.bfloat16, device=text_features.device) - inp_text[0, 0, :B, :D] = text_features.to(torch.bfloat16) - - n = inp_image.numel() - - # Preallocate TK tensors matching the topology PGL requires - image_tk = get_or_create_parallel_tensor(ext, (1, 1, flat_B, flat_D), torch.bfloat16, multicast=False) - text_tk = get_or_create_parallel_tensor(ext, (1, 1, flat_B, flat_D), torch.bfloat16, multicast=False) - - image_tk.data_.reshape(-1)[:n].copy_(inp_image.reshape(-1)) - text_tk.data_.reshape(-1)[:n].copy_(inp_text.reshape(-1)) - - # Grid output to prevent atomics precision loss across millions of summed logits - num_row_blocks = flat_B // 64 - num_col_blocks = flat_B // 64 - grid_size = num_row_blocks * num_col_blocks * NUM_DEVICES - - loss_tk = get_or_create_parallel_tensor(ext, (grid_size,), torch.float32, multicast=False) - loss_tk.data_.zero_() - - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - scale_val = float(logit_scale.item() if isinstance(logit_scale, torch.Tensor) else logit_scale) - bias_val = float(logit_bias.item() if isinstance(logit_bias, torch.Tensor) else logit_bias) - - # Launch cross-rank block fused kernel - ext.tk_siglip(image_tk, text_tk, loss_tk, barrier_tk, B, flat_D, scale_val, bias_val) - - # Local reduce all blocks and divide by logical local batch - total_loss = loss_tk.data_[:grid_size].sum() / B - - return total_loss \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/5_scatter_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/5_scatter_parallelkittens.py deleted file mode 100755 index 3ff1db1..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/5_scatter_parallelkittens.py +++ /dev/null @@ -1,267 +0,0 @@ -""" -ThunderKittens scatter via pull-based TMA over peer-to-peer memory. - -Replaces NCCL scatter with a symmetric PGL kernel. The source rank loads its -full multi-chunk tensor into an IPC-mapped parallel tensor. All 8 GPUs then -execute a pull via TMA load directly from the source rank's memory into their -own local output buffers. Hardware copy engines handle the cross-device data -movement. - -Requires: ThunderKittens headers at $THUNDERKITTENS_ROOT/include. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (scatter entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace tk_scatter { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - int src_rank; - int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((output.cols() / globals::COL_BLOCK_SIZE) * - (output.rows() / globals::ROW_BLOCK_SIZE) * - output.depth() * output.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.output.depth() * (G.output.rows() / globals::ROW_BLOCK_SIZE) * (G.output.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.output.depth() * (G.output.rows() / globals::ROW_BLOCK_SIZE) * (G.output.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE * (G.output.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE * (G.output.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - - // Each rank pulls its designated chunk from the src_rank - // The chunk index in the src_rank tensor corresponds to G.dev_idx - tma::load_async(tile, G.input[G.src_rank], {G.dev_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - wait(arrived, 0); - // Write out locally to its own output tensor - tma::store_async(G.output[G.dev_idx], tile, {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace tk_scatter - -namespace tk_scatter_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace tk_scatter_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int src_rank -) { - kittens::py::parallel_tensor_check(output, input); - - tk_scatter::globals scatter_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .src_rank = src_rank, - .dev_idx = input.local_rank_ - }; - - tk_scatter_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Synchronize to ensure src_rank's buffer is populated before pulls start - kittens::py::launch_kernel(barrier_G); - - kittens::py::launch_kernel(scatter_G); - - // Synchronize to ensure all ranks finish pulling before subsequent calls can overwrite src_rank's buffer - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_scatter", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_scatter_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - """Return (R, C, padded_rest) with R=16, C multiple of 128, R*C >= rest_elems.""" - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -@torch.no_grad() -def solution(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - - rank = dist.get_rank() - ext = _ensure_ext_jit() - - original_dtype = tensor.dtype - - if rank == src: - assert tensor.shape[0] == world, ( - f"First dimension ({tensor.shape[0]}) must equal world_size ({world})" - ) - chunk_shape = tensor.shape[1:] - chunk_elems = tensor[0].numel() - else: - chunk_shape = tensor.shape - chunk_elems = tensor.numel() - - r, c, padded_rest = _padded_row_col(chunk_elems) - - # Input on src has shape [world_size, 1, R, C] to accommodate all chunks - input_tk = get_or_create_parallel_tensor( - ext, (world, 1, r, c), torch.bfloat16, multicast=False - ) - # Output on every rank has shape [1, 1, R, C] - output_tk = get_or_create_parallel_tensor( - ext, (1, 1, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - if rank == src: - # Format the batched tensor for the TK buffer - flat = tensor.to(torch.bfloat16).reshape(world, -1).contiguous() - padded = torch.zeros(world, padded_rest, dtype=torch.bfloat16, device=tensor.device) - padded[:, :chunk_elems] = flat - inp_4 = padded.view(world, 1, r, c) - n = inp_4.numel() - # Copy to the IPC-mapped parallel tensor layout - input_tk.data_.reshape(-1)[:n].copy_(inp_4.reshape(-1)) - - # All ranks launch the kernel (pull from src_rank's memory) - ext.tk_scatter(output_tk, input_tk, barrier_tk, src) - - # Extract the resulting chunk logically populated on every rank - out_flat = output_tk.data_.reshape(-1)[:padded_rest] - return out_flat[:chunk_elems].contiguous().reshape(chunk_shape).to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/60_physicsnemo_distributed_rfft_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/60_physicsnemo_distributed_rfft_parallelkittens.py deleted file mode 100755 index 85c2164..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/60_physicsnemo_distributed_rfft_parallelkittens.py +++ /dev/null @@ -1,356 +0,0 @@ -""" -Strategy: -1. Perform standard `torch.fft.fft` along the replicated spatial dimension (yielding a full complex spectrum). -2. Use `torch.view_as_real` and cast to bfloat16 to prepare the spectrum for device-side communication. -3. Overlap compute and data movement via a ThunderKittens TMA-based personalized all-to-all transpose. We shape the payload so `scatter=0` maps to the destination rank, and `gather=1` gathers chunks from source ranks directly into pre-aligned buffers in symmetric memory. -4. Exploit PyTorch's `movedim` and `reshape` to locally construct the contiguous block representing `torch.cat(recv_chunks, dim=dim1)` after the communication step. -5. Cast the rearranged payload back to `float32` -> `complex64`, and execute the second `torch.fft.fft` along the now-local spatial dimension. -6. Truncate to keep the half-spectrum, reproducing the exact PhysicsNeMo 2D real FFT semantics while substituting the heavy NCCL intermediate layout with symmetric TMA buffers. -""" - -import os -from typing import Optional, Sequence - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - -def _truncate(tensor: torch.Tensor, dim: int, size: int) -> torch.Tensor: - slices = [slice(None)] * tensor.ndim - slices[dim % tensor.ndim] = slice(0, size) - return tensor[tuple(slices)].contiguous() - -def all_to_all_transpose_cat( - tensor: torch.Tensor, dim0: int, dim1: int, world_size: int, ext -) -> torch.Tensor: - # tensor holds an extra dimension representing real vs imag parts of the complex value - shape = list(tensor.shape) - D0 = shape[dim0] - chunk0 = D0 // world_size - - new_shape = shape.copy() - new_shape[dim0] = world_size - new_shape.insert(dim0 + 1, chunk0) - - # Isolate split_dim blocks locally and push the World index to the very front for flattened scatter - t = tensor.reshape(new_shape).movedim(dim0, 0).contiguous() - - W = world_size - rest = t.numel() // W - r, c, padded_rest = _padded_row_col(rest) - - padded = torch.zeros(W, padded_rest, dtype=torch.bfloat16, device=tensor.device) - flat_t = t.view(W, rest) - padded[:, :rest] = flat_t - inp_4 = padded.view(W, 1, r, c) - - input_tk = get_or_create_parallel_tensor( - ext, (W, 1, r, c), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, W, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=W) - - n = inp_4.numel() - input_tk.data_.reshape(-1)[:n].copy_(inp_4.reshape(-1)) - - # Device-side multi-cast/transpose exploiting symmetric layouts - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - # Strip padding - out_flat = output_tk.data_.reshape(-1)[:n].view(1, W, r, c)[0].reshape(W, padded_rest)[:, :rest].contiguous() - - send_shape = shape.copy() - send_shape[dim0] = chunk0 - - # Layout perfectly matches concatenated recv elements across the rank group - out = out_flat.view(W, *send_shape) - out = out.movedim(0, dim1) - - final_shape = send_shape.copy() - final_shape[dim1] = final_shape[dim1] * W - out = out.reshape(final_shape) - - return out - -@torch.no_grad() -def solution( - x: torch.Tensor, - s: Sequence[int], - dim: Sequence[int], - norm: str = "ortho", - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world = dist.get_world_size(group) - - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - - # Normalize dimensions exactly to positional bounds for slicing - dim0, dim1 = int(dim[0]) % x.ndim, int(dim[1]) % x.ndim - - ext = _ensure_ext_jit() - - # 1. Transform the replicated spatial dimension -> produces full complex spectrum. - x1 = torch.fft.fft(x, n=int(s[0]), dim=dim0, norm=norm) - - # Convert complex domain to half-float tensor matching the TK expected footprint bounds - x1_real = torch.view_as_real(x1).to(torch.bfloat16) - - # 2. Transpose (ParallelKittens all-to-all switching domains from dim1 -> dim0 chunks) - x1_tran_real_bf16 = all_to_all_transpose_cat(x1_real, dim0, dim1, world, ext) - - # Return elements to precision expected for the secondary complex operation map - x1_tran = torch.view_as_complex(x1_tran_real_bf16.to(torch.float32)) - - # 3. Perform second transformation over the newly localized dimensional payload. - x2 = torch.fft.fft(x1_tran, n=int(s[1]), dim=dim1, norm=norm) - - # 4. Truncate returning real-input constraints over the newly calculated subset. - return _truncate(x2, dim1, x2.shape[dim1] // 2 + 1) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/61_physicsnemo_distributed_irfft_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/61_physicsnemo_distributed_irfft_parallelkittens.py deleted file mode 100755 index 68d4eab..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/61_physicsnemo_distributed_irfft_parallelkittens.py +++ /dev/null @@ -1,473 +0,0 @@ -""" -Strategy: -1. Communication Volume Reduction: Modified `_conj_pad_2d` to extract, all-gather, flip, and scatter *only* the padded region, slicing the required communication bandwidth in half compared to the reference's full-tensor NCCL collective. -2. Device-side Comm Patterns: Swapped `torch.distributed` collectives (`all_to_all` and `all_gather`) for custom ThunderKittens Hopper TMA kernels. These kernels pull/push tiles across W=8 peers directly over NVLink using explicit PGL addressing and symmetric memory, eliminating CPU host overhead. -3. Lossless BF16 Transfer: Since the specified kernels are compiled for `bf16` arrays but the FFT uses `complex64`, we bitcast the data structures via contiguous `torch.view` down to `bfloat16` prior to P2P transport. This attains max NVLink throughput while recovering exact IEEE identical floats on arrival. -""" - -import os -from typing import Optional, Sequence - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded ThunderKittens CUDA Source: All-to-all and All-gather via TMA -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -// ============================================================================ -// ALL-TO-ALL -// ============================================================================ -namespace all_to_all { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} -} // namespace all_to_all - - -// ============================================================================ -// ALL-GATHER -// ============================================================================ -namespace all_gather { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; // [W, 1, R, C] - parallel_layout input; // [1, 1, R, C] - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int depth_idx = 0; // Fixed because input depth is 1 - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx % (G.input.cols() / globals::COL_BLOCK_SIZE); - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - - // Pull from every device and store to the batch index matching the source - for (int src = 0; src < globals::NUM_DEVICES; src++) { - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[src], {0, depth_idx, row_block_idx, col_block_idx}, arrived); - wait(arrived, src); - - tma::store_async(G.output[G.dev_idx], tile, {src, depth_idx, row_block_idx, col_block_idx}); - tma::store_commit_group(); - tma::store_async_wait(); - } -} -} // namespace all_gather - - -// ============================================================================ -// BARRIER -// ============================================================================ -namespace ext_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace ext_barrier - - -void entrypoint_all_to_all( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - ext_barrier::globals b_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(b_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(G); - else - TORCH_CHECK(false, "Unsupported axes"); - - kittens::py::launch_kernel(b_G); -} - -void entrypoint_all_gather( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(output, input); - - all_gather::globals G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - ext_barrier::globals b_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(b_G); - kittens::py::launch_kernel(G); - kittens::py::launch_kernel(b_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint_all_to_all); - m.def("tk_all_gather", &entrypoint_all_gather); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_fft_comms_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -def tk_all_to_all_call(tensor: torch.Tensor, ext, barrier_tk, scatter_axis=0, gather_axis=1) -> torch.Tensor: - """Wrapper mapping arbitrary tensors perfectly to TK WxWxRxC bfloat16 TMA layouts via bitcasting.""" - w = dist.get_world_size() - - # Exact bitcast to bfloat16 view; avoids losing fp32/complex64 precision required for FFT correctness - if tensor.is_complex(): - t_bits = torch.view_as_real(tensor).flatten().contiguous().view(torch.bfloat16) - else: - t_bits = tensor.flatten().contiguous().view(torch.bfloat16) - - rest = t_bits.numel() // w - r, c, padded_rest = _padded_row_col(rest) - - padded = torch.zeros(w, padded_rest, dtype=torch.bfloat16, device=tensor.device) - padded[:, :rest] = t_bits.view(w, rest) - inp_4 = padded.view(w, 1, r, c) - - input_tk = get_or_create_parallel_tensor(ext, (w, 1, r, c), torch.bfloat16, multicast=False) - output_tk = get_or_create_parallel_tensor(ext, (1, w, r, c), torch.bfloat16, multicast=False) - - input_tk.data_.reshape(-1)[:inp_4.numel()].copy_(inp_4.reshape(-1)) - - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, scatter_axis, gather_axis) - - out_flat = output_tk.data_.reshape(1, w, r, c)[0].reshape(w, padded_rest)[:, :rest].contiguous() - - if tensor.is_complex(): - out_f32 = out_flat.view(torch.float32).view(w, *tensor.shape[1:], 2) - return torch.view_as_complex(out_f32) - else: - return out_flat.view(tensor.dtype).view(w, *tensor.shape[1:]) - - -def tk_all_gather_call(tensor: torch.Tensor, ext, barrier_tk) -> torch.Tensor: - """TMA all-gather gathering 1/W blocks directly over NVLink with bitcasted BF16 layout.""" - w = dist.get_world_size() - - if tensor.is_complex(): - t_bits = torch.view_as_real(tensor).flatten().contiguous().view(torch.bfloat16) - else: - t_bits = tensor.flatten().contiguous().view(torch.bfloat16) - - rest = t_bits.numel() - r, c, padded_rest = _padded_row_col(rest) - - padded = torch.zeros(1, padded_rest, dtype=torch.bfloat16, device=tensor.device) - padded[0, :rest] = t_bits - inp_4 = padded.view(1, 1, r, c) - - input_tk = get_or_create_parallel_tensor(ext, (1, 1, r, c), torch.bfloat16, multicast=False) - output_tk = get_or_create_parallel_tensor(ext, (w, 1, r, c), torch.bfloat16, multicast=False) - - input_tk.data_.reshape(-1)[:inp_4.numel()].copy_(inp_4.reshape(-1)) - - ext.tk_all_gather(output_tk, input_tk, barrier_tk) - - out_flat = output_tk.data_.reshape(w, padded_rest)[:, :rest].contiguous() - - if tensor.is_complex(): - out_f32 = out_flat.view(torch.float32).view(w, *tensor.shape, 2) - return torch.view_as_complex(out_f32) - else: - return out_flat.view(tensor.dtype).view(w, *tensor.shape) - - -def _pad_zero(tensor: torch.Tensor, dim: int, size: int) -> torch.Tensor: - """Zero-pad tensor along dim to size.""" - dim = dim % tensor.ndim - pad = [0] * (2 * (tensor.ndim - dim)) - pad[1] = size - tensor.shape[dim] - return F.pad(tensor, pad, mode="constant", value=0.0) - - -def _conj_pad_2d_tk( - tensor: torch.Tensor, - pad_dim: int, - other_dim: int, - size: int, - ext, - barrier_tk, -) -> torch.Tensor: - """Pad the RFFT half spectrum natively exchanging *only* the padded region via TMA.""" - pad_dim = pad_dim % tensor.ndim - other_dim = other_dim % tensor.ndim - orig_size = tensor.shape[pad_dim] - - # 1. Pad zeroes natively - tensor_pad = _pad_zero(tensor, pad_dim, size) - - # 2. Local filling via complex conjugate map for the resident chunk - lhs_slice = [slice(0, s) for s in tensor.shape] - lhs_slice[pad_dim] = slice(orig_size, size) - rhs_slice = [slice(0, s) for s in tensor.shape] - rhs_slice[pad_dim] = slice(1, size - orig_size + 1) - tensor_pad[tuple(lhs_slice)] = torch.flip(torch.conj(tensor_pad[tuple(rhs_slice)]), dims=[pad_dim]) - - # 3. Only gather the small padded portion (~half communication footprint vs reference) - local_pad_region = tensor_pad[tuple(lhs_slice)].contiguous() - gathered_pad = tk_all_gather_call(local_pad_region, ext, barrier_tk) - full_pad_region = torch.cat(list(gathered_pad), dim=other_dim) - - # 4. Flip the full pad chunk dimension symmetrically across ranks - full_pad_dim_size = full_pad_region.shape[other_dim] - flip_slice = [slice(0, s) for s in full_pad_region.shape] - flip_slice[other_dim] = slice(1, full_pad_dim_size) - full_pad_region[tuple(flip_slice)] = torch.flip(full_pad_region[tuple(flip_slice)], dims=[other_dim]) - - # 5. Extract my chunk exclusively - rank = dist.get_rank() - my_chunk_size = full_pad_dim_size // dist.get_world_size() - my_pad_slice = [slice(0, s) for s in full_pad_region.shape] - my_pad_slice[other_dim] = slice(rank * my_chunk_size, (rank + 1) * my_chunk_size) - - tensor_pad[tuple(lhs_slice)] = full_pad_region[tuple(my_pad_slice)] - return tensor_pad - - -@torch.no_grad() -def solution( - x: torch.Tensor, - s: Optional[Sequence[int]], - dim: Sequence[int], - norm: str = "ortho", - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - """ - Optimized PhysicsNeMo-style distributed 2D inverse real FFT. - """ - ext = _ensure_ext_jit() - world_size = dist.get_world_size(group) if group is not None else dist.get_world_size() - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - dim0, dim1 = int(dim[0]), int(dim[1]) - if s is not None: - first_dim_size = int(s[0]) - last_dim_size = int(s[1]) - else: - first_dim_size = int(x.shape[dim0]) - last_dim_size = int(2 * (x.shape[dim1] - 1)) - - # 1. Half-bandwidth conjugate rebuild - x_pad = _conj_pad_2d_tk(x, pad_dim=dim1, other_dim=dim0, size=last_dim_size, ext=ext, barrier_tk=barrier_tk) - - # 2. Transform the replicated second dimension - x1 = torch.fft.ifft(x_pad, n=last_dim_size, dim=dim1, norm=norm) - - # 3. Fast device-side TMA NVLink Transpose - chunk_size = x1.shape[dim1] // world_size - send_chunks = torch.stack(list(torch.split(x1, chunk_size, dim=dim1)), dim=0).contiguous() - recv_chunks = tk_all_to_all_call(send_chunks, ext, barrier_tk) - x1_tran = torch.cat(list(recv_chunks), dim=dim0) - - # 4. Final transform mapping correctly reconstructed spatial sharding - x2 = torch.fft.ifft(x1_tran, n=first_dim_size, dim=dim0, norm=norm) - return torch.real(x2).contiguous() \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/62_gsplat_3d_gaussian_splatting_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/62_gsplat_3d_gaussian_splatting_parallelkittens.py deleted file mode 100755 index 42cbeed..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/62_gsplat_3d_gaussian_splatting_parallelkittens.py +++ /dev/null @@ -1,460 +0,0 @@ -import os -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded ThunderKittens C++ / CUDA Source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; -}; - -using pgl_i32 = pgl, 8, false>; - -struct globals { - const float* means; - const float* quats; - const float* scales; - const float* opacities; - const float* colors; - - const float* viewmats; - const float* Ks; - - const int* c_world_offsets; - const int* n_world_offsets; - int* local_send_counts; - - int N; - int C_total; - int D; - int MAX_CAPACITY; - int WORDS; - - int image_width; - int image_height; - float eps2d; - float near_plane; - float far_plane; - int dev_idx; - int num_devices; - - pgl_i32 p2p_buffer; - - __host__ inline dim3 grid() const { - return dim3((N + config::NUM_THREADS - 1) / config::NUM_THREADS > 0 ? - (N + config::NUM_THREADS - 1) / config::NUM_THREADS : 1); - } -}; - -__device__ inline void kernel(const globals &G) { - int n = blockIdx.x * blockDim.x + threadIdx.x; - if (n >= G.N) return; - - float mean[3] = { G.means[n*3], G.means[n*3+1], G.means[n*3+2] }; - float quat[4] = { G.quats[n*4], G.quats[n*4+1], G.quats[n*4+2], G.quats[n*4+3] }; - float scale[3] = { G.scales[n*3], G.scales[n*3+1], G.scales[n*3+2] }; - float opacity = G.opacities[n]; - - float q_len = sqrtf(quat[0]*quat[0] + quat[1]*quat[1] + quat[2]*quat[2] + quat[3]*quat[3]); - float w = quat[0]/q_len, x = quat[1]/q_len, y = quat[2]/q_len, z = quat[3]/q_len; - - float R[3][3] = { - {1.0f - 2.0f*(y*y + z*z), 2.0f*(x*y - w*z), 2.0f*(x*z + w*y)}, - {2.0f*(x*y + w*z), 1.0f - 2.0f*(x*x + z*z), 2.0f*(y*z - w*x)}, - {2.0f*(x*z - w*y), 2.0f*(y*z + w*x), 1.0f - 2.0f*(x*x + y*y)} - }; - - float M[3][3]; - for(int i=0; i<3; ++i) - for(int j=0; j<3; ++j) - M[i][j] = R[i][j] * scale[j]; - - float cov3d[3][3] = {0}; - for(int i=0; i<3; ++i) - for(int j=0; j<3; ++j) - for(int k=0; k<3; ++k) - cov3d[i][k] += M[i][j] * M[k][j]; - - for(int c=0; c= G.far_plane) continue; - - float cov_c[3][3] = {0}; - for(int i=0; i<3; ++i) - for(int j=0; j<3; ++j) - for(int k=0; k<3; ++k) - for(int l=0; l<3; ++l) - cov_c[i][j] += view[i][k] * cov3d[k][l] * view[j][l]; - - float fx = G.Ks[c*9 + 0]; - float cx = G.Ks[c*9 + 2]; - float fy = G.Ks[c*9 + 4]; - float cy = G.Ks[c*9 + 5]; - - float tx = mean_c[0], ty = mean_c[1], tz = mean_c[2]; - float tz2 = tz * tz; - - float tan_fovx = 0.5f * G.image_width / fx; - float tan_fovy = 0.5f * G.image_height / fy; - float lim_x_pos = (G.image_width - cx) / fx + 0.3f * tan_fovx; - float lim_x_neg = cx / fx + 0.3f * tan_fovx; - float lim_y_pos = (G.image_height - cy) / fy + 0.3f * tan_fovy; - float lim_y_neg = cy / fy + 0.3f * tan_fovy; - - float cl_tx = tz * fmaxf(-lim_x_neg, fminf(tx/tz, lim_x_pos)); - float cl_ty = tz * fmaxf(-lim_y_neg, fminf(ty/tz, lim_y_pos)); - - float J[2][3] = { - {fx / tz, 0.0f, -fx * cl_tx / tz2}, - {0.0f, fy / tz, -fy * cl_ty / tz2} - }; - - float cov2d[2][2] = {0}; - for(int i=0; i<2; ++i) - for(int j=0; j<2; ++j) - for(int k=0; k<3; ++k) - for(int l=0; l<3; ++l) - cov2d[i][j] += J[i][k] * cov_c[k][l] * J[j][l]; - - cov2d[0][0] += G.eps2d; - cov2d[1][1] += G.eps2d; - - float det = cov2d[0][0] * cov2d[1][1] - cov2d[0][1] * cov2d[1][0]; - if (det < 1e-10f) det = 1e-10f; - - float conic_x = cov2d[1][1] / det; - float conic_y = -(cov2d[0][1] + cov2d[1][0]) / 2.0f / det; - float conic_z = cov2d[0][0] / det; - - float radius_x = ceilf(3.33f * sqrtf(cov2d[0][0])); - float radius_y = ceilf(3.33f * sqrtf(cov2d[1][1])); - - float mean2d_x = fx * tx / tz + cx; - float mean2d_y = fy * ty / tz + cy; - - bool inside = (mean2d_x + radius_x > 0.0f) && (mean2d_x - radius_x < G.image_width) && - (mean2d_y + radius_y > 0.0f) && (mean2d_y - radius_y < G.image_height); - - if (inside && radius_x > 0.0f && radius_y > 0.0f) { - int dest_rank = 0; - while (dest_rank < G.num_devices - 1 && c >= G.c_world_offsets[dest_rank + 1]) { - dest_rank++; - } - - int idx = atomicAdd(&G.local_send_counts[dest_rank], 1); - if (idx < G.MAX_CAPACITY) { - int offset = (G.dev_idx * G.MAX_CAPACITY + idx) * G.WORDS; - // Direct NVLink Peer-To-Peer Write inside TK's symmetric memory layout - int32_t* base_ptr = (int32_t*)&G.p2p_buffer[dest_rank]({0,0,0,0}); - int32_t* ptr = base_ptr + offset; - - ptr[0] = c - G.c_world_offsets[dest_rank]; // Local camera id for destination - ptr[1] = n + G.n_world_offsets[G.dev_idx]; // Global gaussian id - ptr[2] = (int32_t)radius_x; - ptr[3] = (int32_t)radius_y; - - nv_bfloat16* fptr = (nv_bfloat16*)(ptr + 4); - fptr[0] = __float2bfloat16(mean2d_x); - fptr[1] = __float2bfloat16(mean2d_y); - fptr[2] = __float2bfloat16(depth); - fptr[3] = __float2bfloat16(conic_x); - fptr[4] = __float2bfloat16(conic_y); - fptr[5] = __float2bfloat16(conic_z); - fptr[6] = __float2bfloat16(opacity); - for(int d=0; d bar; - int dev_idx; - }; - __device__ inline void kernel(const globals &G) { - barrier_all(G.bar, {0}, G.dev_idx); - } -} - -void entrypoint( - torch::Tensor means, torch::Tensor quats, torch::Tensor scales, - torch::Tensor opacities, torch::Tensor colors, - torch::Tensor viewmats, torch::Tensor Ks, - torch::Tensor c_world_offsets, torch::Tensor n_world_offsets, - torch::Tensor local_send_counts, - int N, int C_total, int D, int MAX_CAPACITY, int WORDS, - int image_width, int image_height, float eps2d, float near_plane, float far_plane, - kittens::py::TKParallelTensor &p2p_buffer, - kittens::py::TKParallelTensor &barrier -) { - globals G; - G.means = means.data_ptr(); - G.quats = quats.data_ptr(); - G.scales = scales.data_ptr(); - G.opacities = opacities.data_ptr(); - G.colors = colors.data_ptr(); - - G.viewmats = viewmats.data_ptr(); - G.Ks = Ks.data_ptr(); - - G.c_world_offsets = c_world_offsets.data_ptr(); - G.n_world_offsets = n_world_offsets.data_ptr(); - G.local_send_counts = local_send_counts.data_ptr(); - - G.N = N; - G.C_total = C_total; - G.D = D; - G.MAX_CAPACITY = MAX_CAPACITY; - G.WORDS = WORDS; - - G.image_width = image_width; - G.image_height = image_height; - G.eps2d = eps2d; - G.near_plane = near_plane; - G.far_plane = far_plane; - - G.dev_idx = p2p_buffer.local_rank_; - G.num_devices = 8; - - G.p2p_buffer = kittens::py::parallel_tensor_to_pgl(p2p_buffer); - - barrier_ns::globals bG { - .bar = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Barrier setup - kittens::py::launch_kernel(bG); - - int num_blocks = (N + config::NUM_THREADS - 1) / config::NUM_THREADS; - if (num_blocks > 0) { - kittens::py::launch_kernel(G, {num_blocks, 1, 1}); - } - - // Barrier synchronization across stream completions - kittens::py::launch_kernel(bG); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_fused_proj", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", "--use_fast_math", "--expt-extended-lambda", "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", "-Xcompiler=-fno-strict-aliasing", "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gsplat_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - means: Tensor, quats: Tensor, scales: Tensor, opacities: Tensor, colors: Tensor, - viewmats: Tensor, Ks: Tensor, - image_width: int, image_height: int, eps2d: float = 0.3, - near_plane: float = 0.01, far_plane: float = 1e10, camera_model: str = "pinhole" -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - - assert dist.is_initialized() - assert camera_model == "pinhole" - - ext = _ensure_ext_jit() - world_size = dist.get_world_size() - world_rank = dist.get_rank() - device = means.device - original_dtype = means.dtype - - # Assert exactly 8 peer setup - assert world_size == 8, f"ThunderKittens layout expects NUM_DEVICES=8, got {world_size}" - - # Enforce input memory layouts for CUDA kernel - means = means.contiguous().to(torch.float32) - quats = quats.contiguous().to(torch.float32) - scales = scales.contiguous().to(torch.float32) - opacities = opacities.contiguous().to(torch.float32) - colors = colors.contiguous().to(torch.float32) - - N_local = means.shape[0] - C_local = viewmats.shape[0] - D = colors.shape[1] - - # Gather dataset scale - N_world_tensor = torch.zeros(world_size, dtype=torch.int32, device=device) - dist.all_gather_into_tensor(N_world_tensor, torch.tensor([N_local], dtype=torch.int32, device=device)) - N_world = N_world_tensor.tolist() - C_world = [C_local] * world_size - - # Only gather camera metrics implicitly (small memory footprint, completely overlap dense projection math) - viewmats_gather = [torch.empty_like(viewmats) for _ in range(world_size)] - dist.all_gather(viewmats_gather, viewmats) - viewmats_all = torch.cat(viewmats_gather, dim=0).contiguous().to(torch.float32) - - Ks_gather = [torch.empty_like(Ks) for _ in range(world_size)] - dist.all_gather(Ks_gather, Ks) - Ks_all = torch.cat(Ks_gather, dim=0).contiguous().to(torch.float32) - - c_offsets = [0] + torch.cumsum(torch.tensor(C_world, dtype=torch.int32), dim=0).tolist() - n_offsets = [0] + torch.cumsum(torch.tensor(N_world, dtype=torch.int32), dim=0).tolist() - - # Pre-calculated structural layout for fused bytes: `[4 ints (IDs/Radii)] + [(7 + D)/2 ints (bfloat16 math params)]` - WORDS_PER_PROJ = 4 + (7 + D + 1) // 2 - MAX_CAPACITY = min(sum(C_world) * max(N_world) // 4 + 10000, 2000000) - shape = (1, 1, 1, world_size * MAX_CAPACITY * WORDS_PER_PROJ) - - # Establish contiguous buffer blocks mapped inside ThunderKittens symmetry - p2p_buffer = get_or_create_parallel_tensor(ext, shape, torch.int32, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - local_send_counts = torch.zeros(world_size, dtype=torch.int32, device=device) - - # Fire Fused ThunderKittens Projection & Peer Write kernel - ext.tk_fused_proj( - means, quats, scales, opacities, colors, - viewmats_all, Ks_all, - torch.tensor(c_offsets, dtype=torch.int32, device=device), - torch.tensor(n_offsets, dtype=torch.int32, device=device), - local_send_counts, - N_local, sum(C_world), D, MAX_CAPACITY, WORDS_PER_PROJ, - image_width, image_height, eps2d, near_plane, far_plane, - p2p_buffer, barrier_tk - ) - - all_send_counts = torch.zeros(world_size, world_size, dtype=torch.int32, device=device) - dist.all_gather_into_tensor(all_send_counts, local_send_counts) - my_recv_counts = all_send_counts[:, world_rank] - - # Process written Peer structures inside our slots - out_cam_ids, out_gauss_ids, out_radii, out_means2d = [], [], [], [] - out_depths, out_conics, out_opacities, out_colors = [], [], [], [] - - p2p_data = p2p_buffer.data_.view(world_size, MAX_CAPACITY, WORDS_PER_PROJ) - - for i in range(world_size): - count = my_recv_counts[i].item() - if count == 0: continue - - chunk = p2p_data[i, :count] - - cam_ids = chunk[:, 0].contiguous() - gauss_ids = chunk[:, 1].contiguous() - radii = torch.stack([chunk[:, 2], chunk[:, 3]], dim=-1) - - bf16_data = chunk[:, 4:].contiguous().view(torch.bfloat16) - means2d = bf16_data[:, 0:2].to(original_dtype) - depths = bf16_data[:, 2].to(original_dtype) - conics = bf16_data[:, 3:6].to(original_dtype) - opacities = bf16_data[:, 6].to(original_dtype) - out_col = bf16_data[:, 7:7+D].to(original_dtype) - - # Enforce exactly the identical sort schema generated by `.where` then `.cat` from PyTorch - sort_keys = cam_ids.long() * n_offsets[-1] + gauss_ids.long() - sort_idx = torch.argsort(sort_keys) - - out_cam_ids.append(cam_ids[sort_idx]) - out_gauss_ids.append(gauss_ids[sort_idx]) - out_radii.append(radii[sort_idx]) - out_means2d.append(means2d[sort_idx]) - out_depths.append(depths[sort_idx]) - out_conics.append(conics[sort_idx]) - out_opacities.append(opacities[sort_idx]) - out_colors.append(out_col[sort_idx]) - - if len(out_cam_ids) == 0: - return ( - torch.empty(0, dtype=torch.int32, device=device), - torch.empty(0, dtype=torch.int32, device=device), - torch.empty((0, 2), dtype=torch.int32, device=device), - torch.empty((0, 2), dtype=original_dtype, device=device), - torch.empty(0, dtype=original_dtype, device=device), - torch.empty((0, 3), dtype=original_dtype, device=device), - torch.empty(0, dtype=original_dtype, device=device), - torch.empty((0, D), dtype=original_dtype, device=device), - ) - - return ( - torch.cat(out_cam_ids), - torch.cat(out_gauss_ids), - torch.cat(out_radii), - torch.cat(out_means2d), - torch.cat(out_depths), - torch.cat(out_conics), - torch.cat(out_opacities), - torch.cat(out_colors) - ) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/63_torchharmonics_spherical_convolution_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/63_torchharmonics_spherical_convolution_parallelkittens.py deleted file mode 100755 index 1325138..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/63_torchharmonics_spherical_convolution_parallelkittens.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -Distributed DISCO spherical convolution forward. - -Optimized with ParallelKittens / ThunderKittens CUDA integrations and -overlapping reduction schedules. -""" - -import os -from typing import List, Optional - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source: ThunderKittens Subgroup All-to-All via TMA -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace all_to_all_subgroup { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = NUM_DEVICES_PLACEHOLDER; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - int group_ranks[NUM_DEVICES]; - int my_subgroup_idx; - int subgroup_size; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / COL_BLOCK_SIZE) * - (input.rows() / ROW_BLOCK_SIZE) * - input.depth() * subgroup_size); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int col_blocks = G.input.cols() / globals::COL_BLOCK_SIZE; - int row_blocks = G.input.rows() / globals::ROW_BLOCK_SIZE; - int depth = G.input.depth(); - - int scatter_idx = task_idx / (depth * row_blocks * col_blocks); - task_idx %= (depth * row_blocks * col_blocks); - - int depth_idx = task_idx / (row_blocks * col_blocks); - task_idx %= (row_blocks * col_blocks); - - int row_block_idx = task_idx / col_blocks; - int col_block_idx = task_idx % col_blocks; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - - // Load from my physical memory, representing the scatter_idx chunk intended for peer - tma::load_async(tile, G.input[G.group_ranks[G.my_subgroup_idx]], {scatter_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx = G.group_ranks[scatter_idx]; - int gather_idx = G.my_subgroup_idx; - - wait(arrived, 0); - // Write TMA scatter to peer's layout index allocated for me - tma::store_async(G.output[dst_dev_idx], tile, {gather_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all_subgroup - -namespace all_to_all_barrier { -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; -struct globals { - static constexpr int NUM_DEVICES = NUM_DEVICES_PLACEHOLDER; - barrier_t barrier; - const int dev_idx; -}; -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - std::vector group_ranks, - int my_subgroup_idx -) { - kittens::py::parallel_tensor_check(output, input); - - all_to_all_subgroup::globals G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .my_subgroup_idx = my_subgroup_idx, - .subgroup_size = static_cast(group_ranks.size()) - }; - for(size_t i = 0; i < group_ranks.size(); ++i) { - G.group_ranks[i] = group_ranks[i]; - } - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_subgroup_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - world_size = dist.get_world_size() if dist.is_initialized() else 8 - src = CUDA_SRC.replace("NUM_DEVICES_PLACEHOLDER", str(world_size)) - - _ext = compile_cuda_extension( - "tk_disco_s2_ext", - src, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _tk_all_to_all_subgroup( - tensor: torch.Tensor, - group: dist.ProcessGroup, - ext, -) -> torch.Tensor: - """ - Subgroup scatter (dim=0 after permute) and gather (dim=-1) orchestrating TK TMA kernels. - Handles padding transparently. - """ - group_ranks = dist.get_process_group_ranks(group) - N = len(group_ranks) - my_rank = dist.get_rank() - my_subgroup_idx = group_ranks.index(my_rank) - world_size = dist.get_world_size() - - B, C, R, C_inner = tensor.shape - C_chunk = C // N - - # Extract to uniform shape [N, Depth, Rows, Cols] mapping - x = tensor.view(B, N, C_chunk, R, C_inner).permute(1, 0, 2, 3, 4).reshape(N, B * C_chunk, R, C_inner).contiguous() - - pad_R = (16 - (R % 16)) % 16 - pad_C = (128 - (C_inner % 128)) % 128 - if pad_R > 0 or pad_C > 0: - x_pad = torch.nn.functional.pad(x, (0, pad_C, 0, pad_R)) - else: - x_pad = x - - R_pad = R + pad_R - C_pad = C_inner + pad_C - - input_tk = get_or_create_parallel_tensor(ext, (N, B * C_chunk, R_pad, C_pad), torch.bfloat16, multicast=False) - output_tk = get_or_create_parallel_tensor(ext, (N, B * C_chunk, R_pad, C_pad), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - input_tk.data_.view(-1)[:x_pad.numel()].copy_(x_pad.flatten()) - - ext.tk_subgroup_all_to_all(output_tk, input_tk, barrier_tk, group_ranks, my_subgroup_idx) - - y_pad = output_tk.data_.view(N, B * C_chunk, R_pad, C_pad) - if pad_R > 0 or pad_C > 0: - y = y_pad[:, :, :R, :C_inner] - else: - y = y_pad - - # Layout un-permute back from [N, B, C_chunk, R, C_inner] to [B, C_chunk, R, N * C_inner] - y = y.view(N, B, C_chunk, R, C_inner).permute(1, 2, 3, 0, 4).reshape(B, C_chunk, R, N * C_inner).contiguous() - return y - - -@torch.no_grad() -def solution( - x: torch.Tensor, - psi: torch.Tensor, - weight: torch.Tensor, - groups: int, - nlon_out: int, - nlon_in: int, - azimuth_group: Optional[dist.ProcessGroup] = None, - polar_group: Optional[dist.ProcessGroup] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - azimuth_group = azimuth_group or dist.group.WORLD - polar_group = polar_group or dist.group.WORLD - azimuth_size = dist.get_world_size(group=azimuth_group) - polar_size = dist.get_world_size(group=polar_group) - - ext = _ensure_ext_jit() - - # 1. Forward Transpose (Azimuth). Maps cleanly to the TK TMA scattered logic. - if azimuth_size > 1: - x = _tk_all_to_all_subgroup(x.to(torch.bfloat16), azimuth_group, ext) - - # 2 & 3. Ovelapped DISCO Math with Polar Reduce-Scatter Collectives - batch_size, n_chans, nlat_in, nlon_in_cur = x.shape - kernel_size, nlat_out, _ = psi.shape - pscale = nlon_in_cur // nlon_out - - B_C = batch_size * n_chans - x_flat = x.view(B_C, nlat_in, nlon_in_cur).permute(1, 2, 0).to(torch.bfloat16) - - nlat_out_local = nlat_out // polar_size if polar_size > 1 else nlat_out - - # Pre-allocate full loop space to host asynchronous reduction blocks - y_local = torch.empty( - nlon_out, kernel_size, nlat_out_local, B_C, - device=x.device, dtype=torch.float32 - ) - - reqs = [] - psi_bf = psi.to(torch.bfloat16) - - for pout in range(nlon_out): - # Implicitly expand memory without copying allocations - x_exp = x_flat.reshape(1, nlat_in * nlon_in_cur, B_C).expand(kernel_size, -1, -1) - curr_y = torch.bmm(psi_bf, x_exp) - - if pout < nlon_out - 1: - x_flat = torch.roll(x_flat, -pscale, dims=1) - - # Hide the communication: chunk reduce_scatter immediately on computation boundary - if polar_size > 1: - curr_y_float = curr_y.float().contiguous() - curr_y_chunks = list(torch.split(curr_y_float, nlat_out_local, dim=1)) - req = dist.reduce_scatter( - y_local[pout], curr_y_chunks, group=polar_group, async_op=True - ) - reqs.append(req) - else: - y_local[pout].copy_(curr_y) - - if polar_size > 1: - for req in reqs: - req.wait() - - x = y_local.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out_local, nlon_out).to(torch.bfloat16) - - # 5. Backward Transpose (Azimuth). Recycles the exact same TK bidirectional layout permutation kernel. - if azimuth_size > 1: - K = x.shape[2] - x_perm = x.view(batch_size, n_chans, K * nlat_out_local, nlon_out).permute(0, 3, 2, 1).contiguous() - y_perm = _tk_all_to_all_subgroup(x_perm, azimuth_group, ext) - x = y_perm.permute(0, 3, 2, 1).view(batch_size, azimuth_size * n_chans, K, nlat_out_local, nlon_out // azimuth_size).contiguous() - - # 6. Grouped channel mixing explicit bmm (instead of unoptimizable nested tensor einsums). - B, C, K, H, W = x.shape - C_out = weight.shape[0] - groupsize = C // groups - C_out_group = C_out // groups - - w_bmm = weight.view(groups, C_out_group, groupsize * K).to(torch.bfloat16) - x_bmm = x.view(B, groups, groupsize * K, H * W).permute(1, 2, 0, 3).reshape(groups, groupsize * K, B * H * W).to(torch.bfloat16) - - out = torch.bmm(w_bmm, x_bmm) - out = out.view(groups, C_out_group, B, H, W).permute(2, 0, 1, 3, 4).reshape(B, C_out, H, W) - - # 7. Bias - if bias is not None: - out = out + bias.view(1, -1, 1, 1).to(torch.bfloat16) - - return out \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/64_deepmd_kalman_filter_optimizer_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/64_deepmd_kalman_filter_optimizer_parallelkittens.py deleted file mode 100755 index f7cbe2d..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/64_deepmd_kalman_filter_optimizer_parallelkittens.py +++ /dev/null @@ -1,352 +0,0 @@ -""" -Optimized DeepMD blockwise local Kalman-filter optimizer update using ThunderKittens. - -Uses ThunderKittens PGL and NVSwitch multimem multicast arrays to perform ultra-low -latency device-side All-Reduce for the Kalman gain scalar and All-Gather for the -updated parameter weights. Replaces dense allocations with in-place rank-1 updates. -""" - -import os -from typing import List, Tuple - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (TK Reduce + TK Gather + Barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -// Common barrier kernel for synchronization -namespace all_reduce_barrier { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; - }; - __device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); - } -} - -// Scalar All-Reduce (SUM) using ld_reduce -namespace all_reduce { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } - }; - - __device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); - } -} - -// Padded array All-Gather using multimem.st (broadcast) -namespace all_gather { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - parallel_layout tensor; - const int dev_idx; - const int chunk_size; - - __host__ inline dim3 grid() const { - return dim3((chunk_size + NUM_ELEMS_PER_BLOCK - 1) / NUM_ELEMS_PER_BLOCK); - } - }; - - __device__ inline void kernel(const globals &G) { - const size_t idx = globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + globals::NUM_ELEMS_PER_INST * threadIdx.x; - if (idx < G.chunk_size) { - size_t offset = G.dev_idx * G.chunk_size + idx; - // Load local rank's segment - bf16_2 tmp = *reinterpret_cast(&G.tensor.data_[offset]); - // Multicast to all ranks - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[offset]), tmp); - } - } -} - -void entrypoint_reduce( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - all_reduce::globals reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(reduce_G); - kittens::py::launch_kernel(barrier_G); -} - -void entrypoint_gather( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier, - int chunk_size -) { - kittens::py::parallel_tensor_check(tensor, barrier); - TORCH_CHECK(chunk_size % all_gather::globals::NUM_ELEMS_PER_INST == 0, "chunk_size must be even"); - - all_gather::globals gather_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_, - .chunk_size = chunk_size - }; - - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(gather_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_reduce", &entrypoint_reduce); - m.def("tk_all_gather", &entrypoint_gather); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -# Topology cache to elide all_gather_object during weights reconstruction. -_cache = {} - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_deepmd_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - if dist.is_initialized(): - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - H: List[torch.Tensor], - error: torch.Tensor, - weights: List[torch.Tensor], - P: List[torch.Tensor], - kalman_lambda: float, - kalman_nue: float = 0.9987, -) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: - - device = weights[0].device - dtype = weights[0].dtype - weights_num = len(weights) - - lam = torch.tensor(kalman_lambda, dtype=dtype, device=device) - err = error.to(device=device, dtype=dtype) - - # 1. Fast blockwise precomputations (avoids memory allocation where possible). - K_list = [] - hk_sum = torch.zeros((), dtype=dtype, device=device) - - for i in range(weights_num): - k_i = torch.mm(P[i], H[i]) - K_list.append(k_i) - # Vector dot is natively fast, avoids full matrix multiplication paths. - hk_sum += torch.vdot(H[i].squeeze(1), k_i.squeeze(1)) - - tmp_local = lam * weights_num + hk_sum - - if dist.is_initialized(): - world_size = dist.get_world_size() - assert world_size == 8, "This ThunderKittens kernel is built for NUM_DEVICES=8" - - ext = _ensure_ext_jit() - - # All-Reduce: sum denominator scalar via TK NVSwitch PGL - ALIGNMENT = 8 * 512 - reduce_tk = get_or_create_parallel_tensor(ext, (ALIGNMENT,), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - reduce_tk.data_[0] = tmp_local.view(-1)[0].to(torch.bfloat16) - if ALIGNMENT > 1: - reduce_tk.data_[1:].zero_() - - ext.tk_all_reduce(reduce_tk, barrier_tk) - tmp_global = reduce_tk.data_[0].to(dtype) - else: - tmp_global = tmp_local - - A = 1.0 / tmp_global - A_item = A.item() - A_err_item = (A * err).item() - inv_lam_item = (1.0 / lam).item() - - # 2. Local updates: utilize in-place torch.addr_ instead of allocating K @ K.T intermediates. - for i in range(weights_num): - K = K_list[i] - K_vec = K.squeeze(1) - - weights[i].add_(K, alpha=A_err_item) - P[i].addr_(K_vec, K_vec, beta=1.0, alpha=-A_item).mul_(inv_lam_item) - - # 3. Distributed Weights Gathering via Multicast broadcast - if dist.is_initialized(): - local_shape = [w.shape[0] for w in weights] - shape_tuple = tuple(local_shape) - - # Populate shape cache once per process to bypass repeated all_gather_object. - if shape_tuple not in _cache: - shape_list_t = [ - torch.zeros(len(local_shape), dtype=torch.int64, device=device) - for _ in range(world_size) - ] - local_shape_tensor = torch.tensor(local_shape, dtype=torch.int64, device=device) - dist.all_gather(shape_list_t, local_shape_tensor) - - shape_list = [t.tolist() for t in shape_list_t] - sizes = [sum(s) for s in shape_list] - max_size = max(sizes) - - CHUNK_ALIGN = 512 - padded_chunk = ((max_size + CHUNK_ALIGN - 1) // CHUNK_ALIGN) * CHUNK_ALIGN - - gather_tk = get_or_create_parallel_tensor( - ext, (world_size, padded_chunk), torch.bfloat16, multicast=True - ) - - _cache[shape_tuple] = (shape_list, sizes, padded_chunk, gather_tk) - - shape_list, sizes, padded_chunk, gather_tk = _cache[shape_tuple] - - rank = dist.get_rank() - local_size = sizes[rank] - - flat_weights = torch.cat([w.reshape(-1) for w in weights], dim=0).to(torch.bfloat16) - - # Scatter local segment linearly; zero out padding block. - gather_tk.data_[rank, :local_size] = flat_weights - if local_size < padded_chunk: - gather_tk.data_[rank, local_size:].zero_() - - ext.tk_all_gather(gather_tk, barrier_tk, padded_chunk) - - # Rematerialize split weights list using cached topology shapes. - result = [] - for r in range(world_size): - r_size = sizes[r] - r_data = gather_tk.data_[r, :r_size].to(dtype) - - r_shapes = shape_list[r] - splits = torch.split(r_data, r_shapes) - for s in splits: - result.append(s.reshape(-1, 1)) - - weights = result - - # 4. Decay Kalman factor using scalar promotion logic seamlessly matching original type logic. - kalman_lambda_next = kalman_nue * lam + 1.0 - kalman_nue - - return weights, P, kalman_lambda_next \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/65_gnn_neighbor_sampling_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/65_gnn_neighbor_sampling_parallelkittens.py deleted file mode 100755 index 0c6cb72..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/65_gnn_neighbor_sampling_parallelkittens.py +++ /dev/null @@ -1,472 +0,0 @@ -""" -Distributed homogeneous GNN neighbor sampling optimized with ParallelKittens. - -Features: -- Device-side metadata exchange (counts) using TKParallelTensor PGL layouts and barriers. -- Fully fused native CUDA sampling (tk_sample_one_hop) eliminating Python loops. -- Fully vectorized O(1)-launch node routing, deduplication, and reply reassembly. -""" - -import os -from typing import List, Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for TK counts exchange & fused one-hop sampling -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include -#include -#include -#include -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace metadata_exchange { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 32; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - using parallel_layout = pgl, NUM_DEVICES, false>; - parallel_layout tensor; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G, const int64_t* send_counts) { - int j = threadIdx.x; - if (j < globals::NUM_DEVICES) { - // Write our send_count to the destination rank's tensor array at our dev_idx offset - G.tensor[j][G.dev_idx] = send_counts[j]; - } - __syncthreads(); - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace metadata_exchange - -// TK entry point for rapid peer-to-peer count exchange -void tk_exchange_counts( - kittens::py::TKParallelTensor &recv_tensor, // output [NUM_DEVICES] - kittens::py::TKParallelTensor &barrier, - torch::Tensor send_counts // input [NUM_DEVICES] -) { - metadata_exchange::globals G { - .tensor = kittens::py::parallel_tensor_to_pgl(recv_tensor), - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = recv_tensor.local_rank_ - }; - - kittens::py::launch_kernel( - G, send_counts.data_ptr() - ); -} - -// --------------------------------------------------------------------------- -// Fused device-side neighbor sampling kernels -// --------------------------------------------------------------------------- - -__global__ void degree_kernel( - const int64_t* input_nodes, int n, int k, - const int64_t* colptr, int64_t* counts -) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - int64_t v = input_nodes[i]; - int64_t start = colptr[v]; - int64_t end = colptr[v + 1]; - int64_t deg = end - start; - int64_t take = (k >= 0 && k < deg) ? k : deg; - counts[i] = take; - } -} - -__global__ void sample_kernel( - const int64_t* input_nodes, int n, int k, - const int64_t* colptr, const int64_t* row, - bool replace, const int64_t* cumsum, - int64_t* sampled_nodes, int64_t* sampled_edges, - unsigned long long seed -) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - int64_t v = input_nodes[i]; - int64_t start = colptr[v]; - int64_t deg = colptr[v + 1] - start; - int64_t take = (k >= 0 && k < deg) ? k : deg; - int64_t offset = cumsum[i]; - - if (take > 0) { - curandState state; - curand_init(seed, i, 0, &state); - - if (take == deg) { - // Keep all neighbors directly - for (int64_t j = 0; j < take; j++) { - sampled_nodes[offset + j] = row[start + j]; - sampled_edges[offset + j] = start + j; - } - } else if (replace) { - // Sample with replacement - for (int64_t j = 0; j < take; j++) { - int r = curand(&state) % deg; - sampled_nodes[offset + j] = row[start + r]; - sampled_edges[offset + j] = start + r; - } - } else { - // Sample without replacement (Floyd/rejection for small k, reservoir fallback) - if (take <= 128) { - int64_t selected[128]; - for (int64_t j = 0; j < take; j++) { - bool duplicate; - int64_t r; - do { - duplicate = false; - r = curand(&state) % deg; - for (int64_t m = 0; m < j; m++) { - if (selected[m] == r) { duplicate = true; break; } - } - } while (duplicate); - selected[j] = r; - sampled_nodes[offset + j] = row[start + r]; - sampled_edges[offset + j] = start + r; - } - } else { - // Reservoir fallback - for (int64_t j = 0; j < take; j++) { - sampled_nodes[offset + j] = row[start + j]; - sampled_edges[offset + j] = start + j; - } - for (int64_t j = take; j < deg; j++) { - int64_t r = curand(&state) % (j + 1); - if (r < take) { - sampled_nodes[offset + r] = row[start + j]; - sampled_edges[offset + r] = start + j; - } - } - } - } - } - } -} - -std::tuple tk_sample_one_hop( - torch::Tensor input_nodes, int k, - torch::Tensor colptr, torch::Tensor row, bool replace -) { - int n = input_nodes.numel(); - auto options = input_nodes.options(); - auto counts = torch::empty({n}, options); - - int threads = 256; - int blocks = (n + threads - 1) / threads; - - if (n > 0) { - degree_kernel<<>>( - input_nodes.data_ptr(), n, k, - colptr.data_ptr(), counts.data_ptr() - ); - } - - auto counts_cumsum = counts.cumsum(0); - int64_t total_samples = n > 0 ? counts_cumsum[-1].item() : 0; - - auto offsets = torch::empty({n}, options); - if (n > 0) { - offsets[0] = 0; - if (n > 1) offsets.slice(0, 1, n) = counts_cumsum.slice(0, 0, n - 1); - } - - auto nbr_tensor = torch::empty({total_samples}, options); - auto eid_tensor = torch::empty({total_samples}, options); - - if (n > 0 && total_samples > 0) { - // Simple seed, normally would be randomized - unsigned long long seed = 12345; - sample_kernel<<>>( - input_nodes.data_ptr(), n, k, - colptr.data_ptr(), row.data_ptr(), - replace, offsets.data_ptr(), - nbr_tensor.data_ptr(), eid_tensor.data_ptr(), - seed - ); - } - - return std::make_tuple(nbr_tensor, eid_tensor, counts); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_exchange_counts", &tk_exchange_counts); - m.def("tk_sample_one_hop", &tk_sample_one_hop); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gnn_sampling_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _remove_duplicates_gpu(out_node: torch.Tensor, node: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Fully device-resident deduplication replacing np.unique CPU trip.""" - num_nodes = node.numel() - if out_node.numel() == 0: - return out_node, node - - node_combined = torch.cat([node, out_node]) - unique_vals, inverse = torch.unique(node_combined, return_inverse=True) - - first_indices = torch.empty_like(unique_vals).fill_(node_combined.numel() + 1) - first_indices.scatter_reduce_( - 0, inverse, torch.arange(node_combined.numel(), device=node_combined.device), - reduce="amin", include_self=False - ) - - first_indices = first_indices.sort().values - node_new = node_combined[first_indices] - src_new = node_new[num_nodes:] - return src_new, node_new - - -def _relabel_neighborhood_gpu( - node: torch.Tensor, - dst_with_dupl: torch.Tensor, - node_with_dupl: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - if node_with_dupl.numel() == 0: - return node.new_empty(0), node.new_empty(0) - - assoc = torch.full((int(node.max().item()) + 1,), -1, dtype=torch.long, device=node.device) - assoc[node] = torch.arange(node.numel(), device=node.device) - row = assoc[node_with_dupl] - col = assoc[dst_with_dupl] - return row, col - - -def _exchange_nodes_opt( - send_nodes: torch.Tensor, - send_counts: torch.Tensor, - ext, tensor_tk, barrier_tk, - group: dist.ProcessGroup, -) -> Tuple[torch.Tensor, torch.Tensor]: - world_size = dist.get_world_size(group) - device = send_nodes.device - - ext.tk_exchange_counts(tensor_tk, barrier_tk, send_counts) - recv_counts = tensor_tk.data_.clone()[:world_size] - - recv_nodes = torch.empty(int(recv_counts.sum().item()), dtype=torch.long, device=device) - - if send_nodes.numel() == 0 and recv_nodes.numel() == 0: - return recv_nodes, recv_counts - - dist.all_to_all_single( - recv_nodes, - send_nodes, - input_split_sizes=send_counts.cpu().tolist(), - output_split_sizes=recv_counts.cpu().tolist(), - group=group, - ) - return recv_nodes, recv_counts - - -def _exchange_replies_opt( - sampled_nodes: torch.Tensor, - sampled_edges: torch.Tensor, - sampled_counts: torch.Tensor, - recv_counts: torch.Tensor, - ext, tensor_tk, barrier_tk, - group: dist.ProcessGroup, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - world_size = dist.get_world_size(group) - device = sampled_nodes.device - recv_splits = recv_counts.cpu().tolist() - - send_node_counts = torch.empty(world_size, dtype=torch.long, device=device) - offset = 0 - for r, count in enumerate(recv_splits): - send_node_counts[r] = sampled_counts[offset : offset + count].sum() - offset += count - - # ThunderKittens peer-to-peer metric exchange - ext.tk_exchange_counts(tensor_tk, barrier_tk, send_node_counts) - reply_node_counts = tensor_tk.data_.clone()[:world_size] - - ext.tk_exchange_counts(tensor_tk, barrier_tk, recv_counts) - reply_count_counts = tensor_tk.data_.clone()[:world_size] - - reply_nodes = torch.empty(int(reply_node_counts.sum().item()), dtype=torch.long, device=device) - reply_edges = torch.empty_like(reply_nodes) - reply_counts = torch.empty(int(reply_count_counts.sum().item()), dtype=torch.long, device=device) - - dist.all_to_all_single( - reply_nodes, sampled_nodes, - input_split_sizes=send_node_counts.cpu().tolist(), - output_split_sizes=reply_node_counts.cpu().tolist(), group=group - ) - dist.all_to_all_single( - reply_edges, sampled_edges, - input_split_sizes=send_node_counts.cpu().tolist(), - output_split_sizes=reply_node_counts.cpu().tolist(), group=group - ) - dist.all_to_all_single( - reply_counts, sampled_counts, - input_split_sizes=recv_splits, - output_split_sizes=reply_count_counts.cpu().tolist(), group=group - ) - return reply_nodes, reply_edges, reply_counts - - -@torch.no_grad() -def solution( - seed_nodes: torch.Tensor, - fanouts: List[int], - local_adj_row_ptr: torch.Tensor, - local_adj_col: torch.Tensor, - node_to_rank: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, - replace: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - device = seed_nodes.device - - ext = _ensure_ext_jit() - # Cache TKParallelTensor for PGL symmetric memory (used to bypass count exchange overhead) - tensor_tk = get_or_create_parallel_tensor(ext, (world_size,), torch.int64, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - seed = seed_nodes.to(dtype=torch.long, device=device) - src = seed.clone() - node = src.clone() - node_with_dupl = [seed.new_empty(0)] - dst_with_dupl = [seed.new_empty(0)] - edge = [seed.new_empty(0)] - - for fanout in fanouts: - if src.numel() == 0: - break - - # Fast O(1) routing avoiding Python list appends - partition_ids = node_to_rank[src].to(torch.long) - sorted_part_ids, sorted_indices = torch.sort(partition_ids, stable=True) - send_nodes = src[sorted_indices] - send_counts = torch.bincount(partition_ids, minlength=world_size) - - recv_nodes, recv_counts = _exchange_nodes_opt( - send_nodes, send_counts, ext, tensor_tk, barrier_tk, group - ) - - # Fused one-hop kernel using custom PyBind11 backend call - sampled_nodes, edge_out, sampled_counts = ext.tk_sample_one_hop( - recv_nodes, int(fanout), local_adj_row_ptr, local_adj_col, replace - ) - - reply_nodes, reply_edges, reply_counts = _exchange_replies_opt( - sampled_nodes, edge_out, sampled_counts, recv_counts, ext, tensor_tk, barrier_tk, group - ) - - # Invert sorted_indices to reconstruct mapping map via stable vector permutations - grouped_index = torch.empty_like(sorted_indices) - grouped_index[sorted_indices] = torch.arange(sorted_indices.numel(), device=device) - - if reply_nodes.numel() > 0: - orig_counts = reply_counts[grouped_index] - orig_offsets = torch.empty_like(orig_counts) - orig_offsets[0] = 0 - if orig_offsets.numel() > 1: - orig_offsets[1:] = torch.cumsum(orig_counts[:-1], 0) - - reply_offsets = torch.empty_like(reply_counts) - reply_offsets[0] = 0 - if reply_offsets.numel() > 1: - reply_offsets[1:] = torch.cumsum(reply_counts[:-1], 0) - - chunk_indices = torch.repeat_interleave(torch.arange(reply_counts.numel(), device=device), reply_counts) - offset_within_chunk = torch.arange(reply_nodes.numel(), device=device) - reply_offsets[chunk_indices] - target_index = orig_offsets[sorted_indices[chunk_indices]] + offset_within_chunk - - out_node = torch.empty_like(reply_nodes) - out_node[target_index] = reply_nodes - - out_edge = torch.empty_like(reply_edges) - out_edge[target_index] = reply_edges - - out_dst = torch.repeat_interleave(src, orig_counts) - else: - out_node = seed.new_empty(0) - out_edge = seed.new_empty(0) - out_dst = seed.new_empty(0) - - if out_node.numel() == 0: - break - - src, node = _remove_duplicates_gpu(out_node, node) - node_with_dupl.append(out_node) - dst_with_dupl.append(out_dst) - edge.append(out_edge) - - node_dupl = torch.cat(node_with_dupl) - dst_dupl = torch.cat(dst_with_dupl) - row, col = _relabel_neighborhood_gpu(node, dst_dupl, node_dupl) - - return node, row, col, torch.cat(edge) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/66_gnn_feature_exchange_all2all_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/66_gnn_feature_exchange_all2all_parallelkittens.py deleted file mode 100755 index eb345f4..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/66_gnn_feature_exchange_all2all_parallelkittens.py +++ /dev/null @@ -1,352 +0,0 @@ -""" -Strategy: -1. **Compute-Communication Overlap**: We identify the maximum per-rank chunk size via a lightweight asynchronous integer `all_reduce` on the host. This completely overlaps with the device-side local feature gathering (`local_features[seed_inverse_ids]`). -2. **Layout Caching**: We round up the maximum message size to the nearest power of 2 (min 512). This ensures `ParallelKittens` TMA layouts and IPC handles remain statically cached across iterations, preventing costly reallocation syncs. -3. **TMA Device-Side Collectives**: Rather than multiple `all_to_all_single` calls on variable lengths, we pad the data and execute a single fixed-shape personalized all-to-all transpose using ThunderKittens' Tensor Memory Accelerator (TMA) routines. This moves data natively point-to-point without PyTorch dispatch overhead. -4. **Zero-Copy Routing**: We reconstruct the GraphBolt expected rotated rank ordering locally by scattering and gathering from the single padded buffer directly, matching the reference's `_shift` math exactly but keeping the hot path entirely in our custom extension layout. -""" - -import os -from typing import List, Optional - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_to_all entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - """Return (R, C, padded_rest) with R=16, C multiple of 128, R*C >= rest_elems.""" - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -def _next_power_of_2(x: int) -> int: - return 1 if x == 0 else 2**(x - 1).bit_length() - - -@torch.no_grad() -def solution( - local_features: torch.Tensor, - seed_inverse_ids: torch.Tensor, - counts_sent: List[int], - counts_received: List[int], - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - w = dist.get_world_size(group) - rank = dist.get_rank(group) - device = local_features.device - original_dtype = local_features.dtype - - assert w == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={w}" - ) - - # Overlap global max chunk calculation with local data gathering - max_c = max(max(counts_received, default=0), max(counts_sent, default=0)) - local_max = torch.tensor([max_c], dtype=torch.int32, device=device) - handle = dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group, async_op=True) - - # 1. Gather local features mapped by seed_inverse_ids and split into destination chunks - gathered = local_features[seed_inverse_ids] - send_chunks = list(torch.split(gathered, counts_received)) - H = gathered.shape[1] - - # Wait for the all_reduce - handle.wait() - - # Pad to nearest power of 2 (min 512) to lock cache entries in ParallelKittens memory allocator - global_max_count = local_max.item() - padded_max_count = _next_power_of_2(max(global_max_count, 512)) - - # Construct the tensor containing pieces to be sent, padded universally - # input_tensor[dst_rank] = chunk to send to dst_rank - input_tensor = torch.zeros(w, padded_max_count, H, dtype=torch.bfloat16, device=device) - for i, chunk in enumerate(send_chunks): - dst_rank = (rank + i) % w - if chunk.shape[0] > 0: - input_tensor[dst_rank, :chunk.shape[0], :] = chunk.to(torch.bfloat16) - - # Layout shaping for ThunderKittens parallel matrix transpose (scatter_axis=0, gather_axis=1) - rest = padded_max_count * H - r, c, padded_rest = _padded_row_col(rest) - - padded = torch.zeros(w, padded_rest, dtype=torch.bfloat16, device=device) - padded[:, :rest] = input_tensor.view(w, rest) - inp_4 = padded.view(w, 1, r, c) - - ext = _ensure_ext_jit() - - input_tk = get_or_create_parallel_tensor( - ext, (w, 1, r, c), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, w, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=w) - - n = inp_4.numel() - input_tk.data_.reshape(-1)[:n].copy_(inp_4.reshape(-1)) - - # Scatter by batch (destination index) and gather by depth (source rank index) - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - out_flat = ( - output_tk.data_.reshape(-1)[:n].view(1, w, r, c)[0].reshape(w, padded_rest)[:, :rest] - ) - out_tensor = out_flat.view(w, padded_max_count, H) - - # 4. Un-rotate and slice received data into expected GraphBolt flat concatenated sequence - recv_chunks = [] - for i in range(w): - src_rank = (rank + i) % w - valid_rows = counts_sent[i] - recv_chunks.append(out_tensor[src_rank, :valid_rows, :]) - - final_out = torch.cat(recv_chunks, dim=0) if recv_chunks else torch.empty((0, H), device=device) - - return final_out.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/67_gnn_feature_exchange_all2all_backward_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/67_gnn_feature_exchange_all2all_backward_parallelkittens.py deleted file mode 100755 index 9389d93..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/67_gnn_feature_exchange_all2all_backward_parallelkittens.py +++ /dev/null @@ -1,380 +0,0 @@ -import os -from typing import List, Optional - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_to_all TMA + scatter_add reduction) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void tk_all_to_all_entry( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -__global__ void scatter_add_kernel( - const __nv_bfloat16* __restrict__ out, - const int64_t* __restrict__ seed_inverse_ids, - __nv_bfloat16* __restrict__ grad_input, - int N_recv, - int H -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < N_recv * H) { - int row = idx / H; - int col = idx % H; - int64_t dst_row = seed_inverse_ids[row]; - // Native Hopper hardware bfloat16 atomic add - atomicAdd(&grad_input[dst_row * H + col], out[idx]); - } -} - -void tk_scatter_add_entry( - torch::Tensor out, - torch::Tensor seed_inverse_ids, - torch::Tensor grad_input -) { - int N_recv = out.size(0); - int H = out.size(1); - int total_elems = N_recv * H; - if (total_elems == 0) return; - - int threads = 256; - int blocks = (total_elems + threads - 1) / threads; - - scatter_add_kernel<<>>( - reinterpret_cast(out.data_ptr()), - seed_inverse_ids.data_ptr(), - reinterpret_cast<__nv_bfloat16*>(grad_input.data_ptr()), - N_recv, - H - ); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &tk_all_to_all_entry); - m.def("tk_scatter_add", &tk_scatter_add_entry); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gnn_all2all_bw_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _shift(chunks: List[torch.Tensor], group: dist.ProcessGroup) -> List[torch.Tensor]: - cutoff = len(chunks) - dist.get_rank(group) - return chunks[cutoff:] + chunks[:cutoff] - -def _unshift(chunks: List[torch.Tensor], group: dist.ProcessGroup) -> List[torch.Tensor]: - cutoff = dist.get_rank(group) - return chunks[cutoff:] + chunks[:cutoff] - - -@torch.no_grad() -def solution( - grad_output: torch.Tensor, - seed_inverse_ids: torch.Tensor, - seed_size: int, - counts_sent: List[int], - counts_received: List[int], - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - group = group or dist.group.WORLD - W = dist.get_world_size(group) - assert W == NUM_DEVICES, f"Expected world size of {NUM_DEVICES} for this extension" - - original_dtype = grad_output.dtype - H = grad_output.shape[1] if grad_output.dim() > 1 else 1 - ext = _ensure_ext_jit() - - # Determine global max to allow padding for ThunderKittens TMA symmetric-size requirements - local_max_elems = 0 - for c in counts_sent + counts_received: - local_max_elems = max(local_max_elems, c * H) - - local_max_tensor = torch.tensor([local_max_elems], dtype=torch.long, device=grad_output.device) - dist.all_reduce(local_max_tensor, op=dist.ReduceOp.MAX, group=group) - global_max_elems = local_max_tensor.item() - - # Map padded tensor block to TK TMA compatible sizing parameters - TILE_ELEMS = 16 * 128 - num_tiles = (global_max_elems + TILE_ELEMS - 1) // TILE_ELEMS - if num_tiles == 0: - num_tiles = 1 - r_dim = 16 - c_dim = 128 * num_tiles - padded_chunk_size = r_dim * c_dim - - padded_send = torch.zeros((W, padded_chunk_size), dtype=torch.bfloat16, device=grad_output.device) - - inputs = list(torch.split(grad_output, counts_sent)) - shifted_inputs = _shift(inputs, group) - - # Copy shifted input fragments into padded send blocks - for i, inp in enumerate(shifted_inputs): - n = inp.numel() - if n > 0: - padded_send[i, :n] = inp.to(torch.bfloat16).view(-1) - - input_tk = get_or_create_parallel_tensor( - ext, (W, 1, r_dim, c_dim), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, W, r_dim, c_dim), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=W) - - n_total = W * padded_chunk_size - input_tk.data_.view(-1)[:n_total].copy_(padded_send.view(-1)) - - # 1. Device-side asymmetric TMA all-to-all - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - padded_recv = output_tk.data_.view(-1)[:n_total].view(W, padded_chunk_size) - shifted_counts_received = _shift(counts_received, group) - - shifted_outputs = [] - for i, count in enumerate(shifted_counts_received): - n = count * H - if n > 0: - shifted_outputs.append(padded_recv[i, :n].view(count, H)) - else: - shifted_outputs.append(torch.empty((0, H), dtype=torch.bfloat16, device=grad_output.device)) - - # Unshift back to natural rotation logic of GraphBolt outputs - outputs = _unshift(shifted_outputs, group) - if sum(counts_received) > 0: - flat_out = torch.cat(outputs) - else: - flat_out = torch.empty((0, H), dtype=torch.bfloat16, device=grad_output.device) - - grad_input = torch.zeros((seed_size, H), dtype=torch.bfloat16, device=grad_output.device) - - # 2. Custom fast scatter-add exploiting native bfloat16 atomic reductions (avoids sparse tensor overhead) - ext.tk_scatter_add(flat_out, seed_inverse_ids, grad_input) - - return grad_input.to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/68_gnn_sparse_embedding_all2all_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/68_gnn_sparse_embedding_all2all_parallelkittens.py deleted file mode 100755 index 9fba31e..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/68_gnn_sparse_embedding_all2all_parallelkittens.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -Strategy: -1. Overlap split exchange with local compute: We compute the histogram of destination ranks (`send_splits`), and asynchronously `all_gather` these splits across all ranks while independently sorting the local data (`owner`, `idx`, `value`) to group it by destination rank. -2. Global deterministic offsets: Since the all-gather gives every rank the complete communication matrix, each rank can deterministically compute the exact read/write offsets for every peer without any further coordinate exchange. -3. Direct Device-to-Device (P2P) Push: We allocate a symmetric buffer (`TKParallelTensor`) large enough for the maximum received size, padded to a power of 2 for VMM reuse. A custom ThunderKittens CUDA kernel directly pushes each peer's chunk into its exact final offset in the peer's destination buffer using P2P NVLink stores. -4. Barrier & Slice: A single device-side barrier ensures all P2P writes have landed. Each rank simply slices the exact number of elements it received, avoiding all variable-length padding artifacts and host synchronization overheads associated with NCCL's variable-length `all_to_all_single`. -""" - -import os -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for Direct P2P Push -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -// Barrier namespace -namespace tk_barrier { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; - }; - - struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; - }; - - __device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); - } -} - -// Device-to-Device (P2P) Push Kernel -// Each block handles copying the packed chunk for a specific peer directly into their destination memory. -template -__global__ void p2p_push_kernel( - T* dst_0, T* dst_1, T* dst_2, T* dst_3, - T* dst_4, T* dst_5, T* dst_6, T* dst_7, - const T* src_data, - const int* src_offsets, - const int* dst_offsets, - const int* counts, - int D -) { - T* dst_ptrs[8] = {dst_0, dst_1, dst_2, dst_3, dst_4, dst_5, dst_6, dst_7}; - - int dst_rank = blockIdx.y; - int count = counts[dst_rank]; - if (count == 0) return; - - int total_elems = count * D; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - - T* dst = dst_ptrs[dst_rank]; - int src_off = src_offsets[dst_rank] * D; - int dst_off = dst_offsets[dst_rank] * D; - - // Fast path: use vectorized uint4 copies if aligned - if (reinterpret_cast(dst) % 16 == 0 && - reinterpret_cast(src_data) % 16 == 0 && - src_off % (16 / sizeof(T)) == 0 && - dst_off % (16 / sizeof(T)) == 0 && - total_elems % (16 / sizeof(T)) == 0) { - - int4* dst_vec = reinterpret_cast(dst + dst_off); - const int4* src_vec = reinterpret_cast(src_data + src_off); - int vec_elems = total_elems / (16 / sizeof(T)); - - for (int i = tid; i < vec_elems; i += stride) { - dst_vec[i] = src_vec[i]; - } - } else { - // Fallback flat copy - for (int i = tid; i < total_elems; i += stride) { - dst[dst_off + i] = src_data[src_off + i]; - } - } -} - -void entrypoint( - kittens::py::TKParallelTensor &tk_dst_idx, - kittens::py::TKParallelTensor &tk_dst_val, - torch::Tensor src_idx, - torch::Tensor src_val, - torch::Tensor src_offsets, - torch::Tensor dst_offsets, - torch::Tensor counts, - kittens::py::TKParallelTensor &barrier, - int D_idx, - int D_val -) { - int dev_idx = tk_dst_idx.local_rank_; - - tk_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = dev_idx - }; - - // Extract raw pointers for index (int64_t) - int64_t* dst_idx_ptrs[8]; - for(int i=0; i<8; ++i) { - dst_idx_ptrs[i] = reinterpret_cast(tk_dst_idx.data_ptrs_[i]); - } - - // Extract raw pointers for values (__nv_bfloat16) - __nv_bfloat16* dst_val_ptrs[8]; - for(int i=0; i<8; ++i) { - dst_val_ptrs[i] = reinterpret_cast<__nv_bfloat16*>(tk_dst_val.data_ptrs_[i]); - } - - // 1. Barrier before write to ensure destination VMM buffers are ready and untouched by previous rounds - kittens::py::launch_kernel(barrier_G); - - // 2. Launch NVLink direct P2P writes - dim3 grid(32, 8); // 32 blocks per peer, 8 peers (y-dimension) - dim3 block(256); - - if (src_idx.numel() > 0) { - p2p_push_kernel<<>>( - dst_idx_ptrs[0], dst_idx_ptrs[1], dst_idx_ptrs[2], dst_idx_ptrs[3], - dst_idx_ptrs[4], dst_idx_ptrs[5], dst_idx_ptrs[6], dst_idx_ptrs[7], - src_idx.data_ptr(), - src_offsets.data_ptr(), - dst_offsets.data_ptr(), - counts.data_ptr(), - D_idx - ); - } - - if (src_val.numel() > 0) { - p2p_push_kernel<__nv_bfloat16><<>>( - dst_val_ptrs[0], dst_val_ptrs[1], dst_val_ptrs[2], dst_val_ptrs[3], - dst_val_ptrs[4], dst_val_ptrs[5], dst_val_ptrs[6], dst_val_ptrs[7], - reinterpret_cast(src_val.data_ptr()), - src_offsets.data_ptr(), - dst_offsets.data_ptr(), - counts.data_ptr(), - D_val - ); - } - - // 3. Barrier after write to ensure all data has landed before the host attempts to slice - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_p2p_push", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_p2p_push_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _get_ext() - if dist.is_initialized(): - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - idx: torch.Tensor, - value: torch.Tensor, - num_nodes: int, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - if world_size == 1: - return idx, value - - assert world_size == 8, "This ThunderKittens integration expects exactly 8 ranks per node." - me = dist.get_rank(group) - - idx_orig_dtype = idx.dtype - if idx.dtype != torch.int64: - idx = idx.to(torch.int64) - - # 1. Bucket local updates (launch stream: default) - owner = (idx % world_size).long() - send_splits = torch.bincount(owner, minlength=world_size) - - # 2. Async All-Gather of splits. Exchanges only 8 elements per rank. - all_send_splits = torch.empty(world_size, world_size, dtype=torch.long, device=idx.device) - gather_work = dist.all_gather_into_tensor(all_send_splits, send_splits, group=group, async_op=True) - - # 3. Overlap compute while Gather is in-flight: Sort & Pack based on destination rank. - perm = torch.argsort(owner, stable=True) - send_idx_packed = idx[perm].contiguous() - send_value_packed = value[perm] - - # 4. Wait for split sizes. Since all ranks now have the FULL exchange matrix, offsets are deterministic. - gather_work.wait() - - # Derived read/write coordinates entirely via local math: - src_offsets = torch.zeros(world_size, dtype=torch.int32, device=idx.device) - if world_size > 1: - src_offsets[1:] = torch.cumsum(all_send_splits[me, :-1], dim=0).to(torch.int32) - - dst_offsets = all_send_splits[:me, :].sum(dim=0).to(torch.int32).contiguous() - counts = all_send_splits[me, :].to(torch.int32).contiguous() - - my_recv_count = int(all_send_splits[:, me].sum().item()) - global_max_recv = int(all_send_splits.sum(dim=0).max().item()) - - if global_max_recv == 0: - return ( - torch.empty((0,), dtype=idx_orig_dtype, device=idx.device), - torch.empty((0, *value.shape[1:]), dtype=value.dtype, device=value.device) - ) - - # 5. Get TK symmetric memory. Pad to next power of 2 for stable caching of Virtual Memory allocations. - ext = _ensure_ext_jit() - padded_max_recv = max(256, 1 << (global_max_recv - 1).bit_length()) - - tk_dst_idx = get_or_create_parallel_tensor(ext, (padded_max_recv,), torch.int64, multicast=False) - - D = value.shape[1:].numel() if value.dim() > 1 else 1 - val_dtype = torch.bfloat16 - send_value_packed_bf16 = send_value_packed.to(val_dtype).reshape(-1, D).contiguous() - tk_dst_val = get_or_create_parallel_tensor(ext, (padded_max_recv, D), val_dtype, multicast=False) - - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # 6. Execute direct device-to-device push (bypassing PyTorch host looping and variable-size collectives) - ext.tk_p2p_push( - tk_dst_idx, - tk_dst_val, - send_idx_packed, - send_value_packed_bf16, - src_offsets, - dst_offsets, - counts, - barrier_tk, - 1, - D - ) - - # 7. Zero-overhead slice to exact boundary based on deterministically computed sizes - recv_idx = tk_dst_idx.data_[:my_recv_count].clone().to(idx_orig_dtype) - recv_value = tk_dst_val.data_[:my_recv_count].reshape(my_recv_count, *value.shape[1:]).clone().to(value.dtype) - - return recv_idx, recv_value \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/69_gnn_sparse_feature_fetch_projection_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/69_gnn_sparse_feature_fetch_projection_parallelkittens.py deleted file mode 100755 index 0b717a3..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/69_gnn_sparse_feature_fetch_projection_parallelkittens.py +++ /dev/null @@ -1,287 +0,0 @@ -""" -Strategy: -1. **Device-Side P2P Communication**: We eliminate three PyTorch `all_to_all` passes and two `argsort` operations by exposing the local embedding shards via ThunderKittens symmetric `TKParallelTensor` (`pgl`). A custom Hopper CUDA kernel executes direct peer-to-peer (P2P) loads over NVLink, writing directly to the requested output index. This natively preserves the query order without any host-side coordination. -2. **Compute-Communication Overlap**: The sparse gather is highly memory-bandwidth bound over NVLink, while the embedding projection is a dense math-bound Tensor Core GEMM. We split the queries into chunks and use CUDA stream double-buffering: while Chunk $i$ executes its dense projection, Chunk $i+1$ performs the NVLink P2P gather concurrently, entirely hiding the communication latency. -""" - -import os -from typing import Optional - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source: P2P Gather + Barries using ThunderKittens -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -// ============================================================================ -// Barrier Module -// ============================================================================ -namespace barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace barrier - -void tk_barrier(kittens::py::TKParallelTensor &barrier) { - barrier::globals G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - kittens::py::launch_kernel(G); -} - - -// ============================================================================ -// NVLink P2P Gather Kernel -// ============================================================================ -namespace p2p_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_WARPGROUPS = 1; - static constexpr int NUM_WARPS = 4; - static constexpr int NUM_THREADS = NUM_WARPS * 32; // 128 threads - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - // Flat PGL layout: Multicast is false, we just use this to cleanly resolve symmetric memory pointers - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout embeddings; - const int64_t* input_node_ids; - bf16* output_gathered; - - int num_queries; - int shard_size; - int embed_dim; - - __host__ inline dim3 grid() const { - return dim3((num_queries + config::NUM_WARPS - 1) / config::NUM_WARPS); - } -}; - -__device__ inline void kernel(const globals &G) { - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - - int q_idx = blockIdx.x * config::NUM_WARPS + warp_id; - if (q_idx >= G.num_queries) return; - - int64_t node_id = G.input_node_ids[q_idx]; - int owner = node_id / G.shard_size; - int local_id = node_id % G.shard_size; - - // Clamp to valid range (should be guaranteed by inputs) - if (owner >= globals::NUM_DEVICES) owner = globals::NUM_DEVICES - 1; - - // Use the pointer resolved by the ParallelKittens PGL broker - bf16* src_ptr = (bf16*)G.embeddings[owner].data; - bf16* dst_ptr = G.output_gathered + q_idx * G.embed_dim; - - // Warp-level vectorized read from peer GPU memory - for (int d = lane_id; d < G.embed_dim; d += 32) { - dst_ptr[d] = src_ptr[local_id * G.embed_dim + d]; - } -} - -} // namespace p2p_gather - -__global__ void __launch_bounds__(p2p_gather::config::NUM_THREADS, 1) -gather_kernel_wrapper(p2p_gather::globals G) { - p2p_gather::kernel(G); -} - -void tk_p2p_gather( - kittens::py::TKParallelTensor &embeddings, - torch::Tensor input_node_ids, - torch::Tensor output_gathered, - int num_queries, - int shard_size, - int embed_dim -) { - p2p_gather::globals G { - .embeddings = kittens::py::parallel_tensor_to_pgl(embeddings), - .input_node_ids = input_node_ids.data_ptr(), - .output_gathered = reinterpret_cast(output_gathered.data_ptr()), - .num_queries = num_queries, - .shard_size = shard_size, - .embed_dim = embed_dim - }; - - // Grab the current PyTorch stream to execute asynchronously and enable overlapping - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dim3 grid = G.grid(); - dim3 block(p2p_gather::config::NUM_THREADS); - - gather_kernel_wrapper<<>>(G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_p2p_gather", &tk_p2p_gather); - m.def("tk_barrier", &tk_barrier); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gnn_p2p_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - local_embedding_shard: torch.Tensor, - input_node_ids: torch.Tensor, - proj_matrix: torch.Tensor, - num_total_nodes: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - assert world_size == 8, f"This kernel is built for 8 GPUs, got world_size={world_size}" - - ext = _ensure_ext_jit() - - shard_size = (num_total_nodes + world_size - 1) // world_size - embed_dim = local_embedding_shard.shape[1] - - # 1. Provide symmetric access to embedding shards using TKParallelTensor cache - embeddings_tk = get_or_create_parallel_tensor( - ext, (shard_size, embed_dim), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - # Flat copy over (resolving size if the rank owns slightly fewer than `shard_size` items) - n = local_embedding_shard.numel() - embeddings_tk.data_.reshape(-1)[:n].copy_(local_embedding_shard.reshape(-1)) - - # Force synchronization so all peers can safely pull from this array - ext.tk_barrier(barrier_tk) - - num_queries = input_node_ids.shape[0] - out_dim = proj_matrix.shape[1] - - # 2. Compute-Communication Overlap via Double-Buffered Streams - num_chunks = 2 - chunk_size = (num_queries + num_chunks - 1) // num_chunks - streams = [torch.cuda.Stream() for _ in range(num_chunks)] - out_chunks = [] - - for i in range(num_chunks): - start = i * chunk_size - end = min(start + chunk_size, num_queries) - if start >= end: - break - - with torch.cuda.stream(streams[i]): - queries_chunk = input_node_ids[start:end] - gathered_chunk = torch.empty( - (end - start, embed_dim), - dtype=torch.bfloat16, - device=input_node_ids.device - ) - - # Asynchronous custom P2P read kernel - ext.tk_p2p_gather( - embeddings_tk, - queries_chunk, - gathered_chunk, - end - start, - shard_size, - embed_dim - ) - - # Immediately perform dense projection while the next gather runs - out_chunk = torch.matmul(gathered_chunk, proj_matrix) - out_chunks.append(out_chunk) - - # Resolve pipelines - for s in streams: - torch.cuda.current_stream().wait_stream(s) - - out = torch.cat(out_chunks, dim=0) if out_chunks else torch.empty( - (0, out_dim), dtype=proj_matrix.dtype, device=proj_matrix.device - ) - - # Protect embeddings_tk buffer from being recycled or overwritten before all remote peers finish reads - ext.tk_barrier(barrier_tk) - - return out \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/6_gather_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/6_gather_parallelkittens.py deleted file mode 100755 index 3ff2198..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/6_gather_parallelkittens.py +++ /dev/null @@ -1,257 +0,0 @@ -""" -ThunderKittens Gather via direct TMA push to destination. - -Optimized for 8x H100 (Hopper) connected via NVLink. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - int dst_dev_idx; - int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE)); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -__device__ inline void kernel(const globals &G) { - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx % (G.input.cols() / globals::COL_BLOCK_SIZE); - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - - // Load local input block - tma::load_async(tile, G.input[G.dev_idx], {0, 0, row_block_idx, col_block_idx}, arrived); - wait(arrived, 0); - - // Push block directly to the destination rank - tma::store_async(G.output[G.dst_dev_idx], tile, - {G.dev_idx, 0, row_block_idx, col_block_idx}); -} - -} // namespace gather - -namespace gather_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace gather_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int dst -) { - TORCH_CHECK(0 <= dst && dst < gather::globals::NUM_DEVICES, "dst rank must be valid"); - kittens::py::parallel_tensor_check(output, input); - - gather::globals gather_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dst_dev_idx = dst, - .dev_idx = input.local_rank_ - }; - - gather_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Synchronize to ensure all ranks are ready to TMA push - kittens::py::launch_kernel(barrier_G); - - // Execute gather (direct TMA store to dst) - kittens::py::launch_kernel(gather_G); - - // Synchronize to ensure all TMA stores have completed system-wide before host access - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gather_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - """Return (R, C, padded_rest) with R=16, C multiple of 128, R*C >= rest_elems.""" - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -@torch.no_grad() -def solution( - tensor: torch.Tensor, - dst: int = 0, -) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - - ext = _ensure_ext_jit() - - original_shape = tensor.shape - original_dtype = tensor.dtype - - flat = tensor.to(torch.bfloat16).reshape(-1).contiguous() - rest = flat.numel() - - # Calculate padding to map data correctly into TMA hardware tiles (16 x 128) - r, c, padded_rest = _padded_row_col(rest) - - padded = torch.zeros(padded_rest, dtype=torch.bfloat16, device=tensor.device) - padded[:rest] = flat - inp_4 = padded.view(1, 1, r, c) - - # Establish PGL VMM layouts (shape symmetric across all NVLink ranks) - input_tk = get_or_create_parallel_tensor( - ext, (1, 1, r, c), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (world, 1, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Fill current rank's chunk - n = inp_4.numel() - input_tk.data_.reshape(-1)[:n].copy_(inp_4.reshape(-1)) - - # Kernel execution: load local chunk, push via TMA to output_tk on rank `dst` - ext.tk_gather(output_tk, input_tk, barrier_tk, dst) - - rank = dist.get_rank() - if rank == dst: - # Destination rank: Extract stacked slices avoiding alignment padding - out_n = output_tk.data_.reshape(-1)[:world * n] - out_flat = out_n.view(world, padded_rest)[:, :rest].contiguous() - return out_flat.reshape(world, *original_shape).to(original_dtype) - else: - # Non-destination ranks: return input tensor unchanged - return tensor \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/70_gnn_negative_scoring_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/70_gnn_negative_scoring_parallelkittens.py deleted file mode 100755 index 91d0cf2..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/70_gnn_negative_scoring_parallelkittens.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -Distributed link-prediction ranking over positive and negative scores. -Optimized with ThunderKittens device-side compute and NVSwitch Multicast. - -Strategy: -- Mathematical equivalence: Local independent ranking followed by a 1D AllGather is equivalent - to a 2D AllGather followed by global ranking. This slashes communication from O(P * K) to O(P). -- Device-side fusion: A ThunderKittens custom kernel uses a warp-per-row reduction to efficiently - compute the rankings directly from the local BFloat16 tensors. -- Overlap & Multicast: The kernel natively writes the final ranking using inline NVSwitch - multicast (`st.global.nc.mcast`), overlapping the local compute completely with the collective - distribution of the results. -""" - -from typing import Optional -import os - -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace compute_gather { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 256; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const __nv_bfloat16* pos_scores; - const __nv_bfloat16* neg_scores; - int P; - int max_P; - int K; - const int dev_idx; - - __host__ inline dim3 grid() const { - int rows_per_block = config::NUM_THREADS / 32; - return dim3((max_P + rows_per_block - 1) / rows_per_block); - } -}; - -__device__ inline void kernel(const globals &G) { - int warp_id = threadIdx.x / 32; - int lane = threadIdx.x % 32; - int row_idx = blockIdx.x * (blockDim.x / 32) + warp_id; - - if (row_idx < G.max_P) { - float rank_val = 0.0f; - // Compute ranking locally using mathematical equivalence - // PyTorch's descending stable sort implies pos_score stays ahead of equal elements. - // Therefore, position is just 1 + (number of strictly greater negatives). - if (row_idx < G.P) { - float pos = __bfloat162float(G.pos_scores[row_idx]); - int count = 0; - // Coalesced warp access over the negative samples - for(int k = lane; k < G.K; k += 32) { - float neg = __bfloat162float(G.neg_scores[row_idx * G.K + k]); - if (neg > pos) { - count++; - } - } - - // Warp reduction - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - count += __shfl_down_sync(0xffffffff, count, offset); - } - - if (lane == 0) { - rank_val = (float)(count + 1); - } - } - - // Directly multicast the result to all GPUs, fusing communication with compute - if (lane == 0) { - int my_offset = G.dev_idx * G.max_P + row_idx; - asm volatile("st.global.nc.mcast.f32 [%0], %1;\n" - :: "l"(&G.tensor.mc_ptr[my_offset]), "f"(rank_val) - : "memory"); - } - } -} - -} // namespace compute_gather - -namespace compute_gather_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace compute_gather_barrier - -void entrypoint( - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier, - uintptr_t pos_scores_ptr, - uintptr_t neg_scores_ptr, - int P, - int max_P, - int K -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - compute_gather::globals G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .pos_scores = reinterpret_cast(pos_scores_ptr), - .neg_scores = reinterpret_cast(neg_scores_ptr), - .P = P, - .max_P = max_P, - .K = K, - .dev_idx = tensor.local_rank_ - }; - - compute_gather_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_compute_gather", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_gnn_ranking_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - if dist.is_initialized(): - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - local_pos_scores: torch.Tensor, - local_neg_scores: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - if world_size == 1: - scores = torch.cat([local_pos_scores.view(-1, 1), local_neg_scores], dim=1) - _, indices = torch.sort(torch.sigmoid(scores), dim=1, descending=True) - return torch.nonzero(indices == 0)[:, 1].view(-1).detach() + 1 - - assert world_size == 8, f"ThunderKittens kernel hardcoded for 8 devices, got {world_size}" - - assert local_pos_scores.dtype == torch.bfloat16, "Expected bfloat16 pos_scores" - assert local_neg_scores.dtype == torch.bfloat16, "Expected bfloat16 neg_scores" - - local_pos_scores = local_pos_scores.contiguous() - local_neg_scores = local_neg_scores.contiguous() - - ext = _ensure_ext_jit() - - P = local_pos_scores.shape[0] - K = local_neg_scores.shape[1] - - # Quick sync to determine total padded shape for the communication collective - sizes = torch.zeros(world_size, dtype=torch.long, device=local_pos_scores.device) - sizes[rank] = P - dist.all_reduce(sizes, op=dist.ReduceOp.SUM, group=group) - - max_P = sizes.max().item() - # Pad aggressively to avoid constant VMM reallocation when graphs fluctuate slightly - max_P_padded = ((max_P + 4095) // 4096) * 4096 - - tensor_tk = get_or_create_parallel_tensor( - ext, (world_size * max_P_padded,), torch.float32, multicast=True - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world_size) - - pos_ptr = local_pos_scores.data_ptr() - neg_ptr = local_neg_scores.data_ptr() - - # Launch fusion kernel: locally computes independent rankings & NVSwitch multicasts immediately - ext.tk_compute_gather(tensor_tk, barrier_tk, pos_ptr, neg_ptr, P, max_P_padded, K) - - # Reconstruct exact GraphStorm format: concatenate valid sub-slices from each rank's multicast payload - data_view = tensor_tk.data_.view(world_size, max_P_padded) - sizes_list = sizes.tolist() - - res = [data_view[i, :sizes_list[i]] for i in range(world_size)] - return torch.cat(res, dim=0).to(torch.long) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/71_torchrec_kjt_all2all_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/71_torchrec_kjt_all2all_parallelkittens.py deleted file mode 100755 index 207b35a..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/71_torchrec_kjt_all2all_parallelkittens.py +++ /dev/null @@ -1,478 +0,0 @@ -""" -ThunderKittens KJT All-To-All Pull-Permute. - -Replaces the NCCL all-to-all and PyTorch segment-gather permutation loops with -a fused PGL NVLink kernel. The kernel directly pulls variable-sized jagged segments -from peer workspaces into their final permuted output locations. -""" - -import os -from typing import Dict, List, Optional - -import torch -import torch.distributed as dist - -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source for Pull-Permute Kernels -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace barrier_ns { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 1; - }; - struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; - }; - __device__ void kernel(const globals& G) { - barrier_all(G.barrier, {0}, G.dev_idx); - } -} - -template -struct pull_permute_globals { - static constexpr int NUM_DEVICES = 8; - // PGL array of rank memory blocks - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout send_workspace; - T* out; - - const int* dst_offsets; - const int* peer_ids; - const int* src_offsets; - - int num_segments; - int total_elems; - int alloc_size; - int rank; - - __host__ inline dim3 grid() const { - return dim3((total_elems + 255) / 256); - } -}; - -namespace pull_permute { - struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_THREADS = 256; - }; -} - -template -__device__ inline void pull_permute_kernel(const pull_permute_globals &G) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= G.total_elems) return; - - // Binary search to find which segment this thread's element belongs to - int low = 0, high = G.num_segments - 1; - while (low < high) { - int mid = (low + high + 1) / 2; - if (G.dst_offsets[mid] <= tid) { - low = mid; - } else { - high = mid - 1; - } - } - int seg = low; - - int peer_id = G.peer_ids[seg]; - int offset_in_seg = tid - G.dst_offsets[seg]; - int src_offset = G.src_offsets[seg] + offset_in_seg; - - // Direct NVLink read from the peer's workspace - const T* src_base = reinterpret_cast(G.send_workspace[peer_id].data); - - // The workspace on the peer is structured as 8 blocks of size `alloc_size` - G.out[tid] = src_base[G.rank * G.alloc_size + src_offset]; -} - -void tk_barrier_entry(kittens::py::TKParallelTensor &barrier) { - barrier_ns::globals G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - kittens::py::launch_kernel(G); -} - -void entrypoint_bf16( - kittens::py::TKParallelTensor &workspace, - torch::Tensor &out, - torch::Tensor &dst_offsets, - torch::Tensor &peer_ids, - torch::Tensor &src_offsets, - int num_segments, - int total_elems, - int alloc_size, - int rank -) { - pull_permute_globals G { - .send_workspace = kittens::py::parallel_tensor_to_pgl::parallel_layout>(workspace), - .out = reinterpret_cast(out.data_ptr()), - .dst_offsets = dst_offsets.data_ptr(), - .peer_ids = peer_ids.data_ptr(), - .src_offsets = src_offsets.data_ptr(), - .num_segments = num_segments, - .total_elems = total_elems, - .alloc_size = alloc_size, - .rank = rank - }; - kittens::py::launch_kernel, pull_permute_kernel>(G); -} - -void entrypoint_int32( - kittens::py::TKParallelTensor &workspace, - torch::Tensor &out, - torch::Tensor &dst_offsets, - torch::Tensor &peer_ids, - torch::Tensor &src_offsets, - int num_segments, - int total_elems, - int alloc_size, - int rank -) { - pull_permute_globals G { - .send_workspace = kittens::py::parallel_tensor_to_pgl::parallel_layout>(workspace), - .out = out.data_ptr(), - .dst_offsets = dst_offsets.data_ptr(), - .peer_ids = peer_ids.data_ptr(), - .src_offsets = src_offsets.data_ptr(), - .num_segments = num_segments, - .total_elems = total_elems, - .alloc_size = alloc_size, - .rank = rank - }; - kittens::py::launch_kernel, pull_permute_kernel>(G); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_barrier", &tk_barrier_entry); - m.def("tk_pull_permute_bf16", &entrypoint_bf16); - m.def("tk_pull_permute_int32", &entrypoint_int32); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", "--use_fast_math", "--expt-extended-lambda", "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", "-Xcompiler=-fno-strict-aliasing", "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False -_tk_workspaces = {} - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_kjt_all2all", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _ensure_workspace(name: str, dtype: torch.dtype, req_size: int, ext): - global _tk_workspaces - if name not in _tk_workspaces or _tk_workspaces[name].data_.shape[-1] < req_size: - alloc_size = max(req_size, 4 * 1024 * 1024) - _tk_workspaces[name] = get_or_create_parallel_tensor( - ext, (8, alloc_size), dtype, multicast=False - ) - return _tk_workspaces[name] - - -def _get_recat( - local_split: int, - num_splits: int, - stagger: int = 1, - device: Optional[torch.device] = None, - batch_size_per_rank: Optional[List[int]] = None, -) -> Optional[torch.Tensor]: - if local_split == 0: - return None - feature_order = [ - x + num_splits // stagger * y - for x in range(num_splits // stagger) - for y in range(stagger) - ] - if batch_size_per_rank is None: - recat = [ - feature_idx + rank_idx * local_split - for feature_idx in range(local_split) - for rank_idx in feature_order - ] - else: - rank_offsets = [0] - for batch_size in batch_size_per_rank[:-1]: - rank_offsets.append(rank_offsets[-1] + local_split * batch_size) - recat = [ - rank_offsets[rank_idx] + feature_idx * batch_size_per_rank[rank_idx] + b - for feature_idx in range(local_split) - for rank_idx in feature_order - for b in range(batch_size_per_rank[rank_idx]) - ] - return torch.tensor(recat, device=device, dtype=torch.long) - - -def _sum_by_splits(values: List[int], splits: List[int]) -> List[int]: - out: List[int] = [] - offset = 0 - for split in splits: - out.append(sum(values[offset : offset + split])) - offset += split - return out - - -def _lengths_per_key(lengths: torch.Tensor, stride_per_key: List[int]) -> List[int]: - out: List[int] = [] - offset = 0 - for stride in stride_per_key: - out.append(int(lengths[offset : offset + stride].sum().item())) - offset += stride - return out - - -@torch.no_grad() -def solution( - lengths: torch.Tensor, - values: torch.Tensor, - key_splits: List[int], - batch_size: int, - pg: Optional[dist.ProcessGroup] = None, - weights: Optional[torch.Tensor] = None, - stride_per_key: Optional[List[int]] = None, - stagger: int = 1, -) -> Dict[str, torch.Tensor]: - pg = pg or dist.group.WORLD - world_size = dist.get_world_size(pg) - rank = dist.get_rank(pg) - device = lengths.device - ext = _ensure_ext_jit() - - num_features = sum(key_splits) - variable_stride = stride_per_key is not None - if stride_per_key is None: - stride_per_key = [batch_size] * num_features - - length_per_key = _lengths_per_key(lengths, stride_per_key) - length_splits = _sum_by_splits(stride_per_key, key_splits) - value_splits = _sum_by_splits(length_per_key, key_splits) - - input_splits = [length_splits, value_splits] - if variable_stride: - input_splits.append(key_splits) - - split_tensors = [ - torch.tensor(splits, dtype=torch.long, device=device) for splits in input_splits - ] - if not variable_stride: - split_tensors.append( - torch.full((world_size,), batch_size, dtype=torch.long, device=device) - ) - if variable_stride: - stride_per_key_tensor = torch.tensor(stride_per_key, dtype=torch.long, device=device) - - meta_input = torch.stack(split_tensors, dim=1).flatten() - meta_output = torch.empty_like(meta_input) - dist.all_to_all_single(meta_output, meta_input, group=pg) - - meta_rows = [ - [int(item) for item in row] - for row in meta_output.view(world_size, -1).T.tolist() - ] - - if variable_stride: - output_splits = meta_rows - stride_per_rank = None - # output_splits[2] is received stride_per_key arrays from each rank - recv_strides_t = torch.tensor(output_splits[2], device=device, dtype=torch.long) - else: - output_splits = meta_rows[:-1] - stride_per_rank = meta_rows[-1] - - # Pre-allocate distributed TK parallel tensors for P2P routing - max_l_req = max(length_splits) - max_v_req = max(value_splits) - - max_reqs = torch.tensor([max_l_req, max_v_req], device=device, dtype=torch.long) - dist.all_reduce(max_reqs, op=dist.ReduceOp.MAX) - global_req_l, global_req_v = max_reqs.tolist() - - ws_l = _ensure_workspace("lengths", torch.int32, global_req_l, ext) - ws_v = _ensure_workspace("values", torch.bfloat16, global_req_v, ext) - ws_w = _ensure_workspace("weights", torch.bfloat16, global_req_v, ext) if weights is not None else None - ws_barrier = get_or_create_barrier(ext, num_devices=world_size) - - # 1. Scatter payloads locally into the outbound P2P registered windows - l_start, v_start = 0, 0 - lengths_i32 = lengths.to(torch.int32) - for j in range(world_size): - l_len, v_len = length_splits[j], value_splits[j] - ws_l.data_[j, :l_len].copy_(lengths_i32[l_start : l_start + l_len]) - ws_v.data_[j, :v_len].copy_(values[v_start : v_start + v_len]) - if weights is not None: - ws_w.data_[j, :v_len].copy_(weights[v_start : v_start + v_len]) - l_start += l_len - v_start += v_len - - local_split = key_splits[rank] - - # 2. Determine generalized Length Segments layout - if variable_stride: - recat_lengths = _get_recat(local_split, world_size, stagger, device=device) - if recat_lengths is None: - recat_lengths = torch.arange(world_size * local_split, device=device) - lengths_segment_sizes = recv_strides_t.flatten() - S_lengths = world_size * local_split - l_peer_block_sizes = [local_split] * world_size - else: - single_batch_per_rank = all(stride == stride_per_rank[0] for stride in stride_per_rank) - if single_batch_per_rank: - recat_lengths = _get_recat(local_split, world_size, stagger, device=device) - if recat_lengths is None: - recat_lengths = torch.arange(world_size * local_split, device=device) - lengths_segment_sizes = torch.full((world_size * local_split,), stride_per_rank[0], device=device, dtype=torch.long) - S_lengths = world_size * local_split - l_peer_block_sizes = [local_split] * world_size - else: - recat_lengths = _get_recat(local_split, world_size, stagger, device=device, batch_size_per_rank=stride_per_rank) - if recat_lengths is None: - total_len = sum(local_split * b for b in stride_per_rank) - recat_lengths = torch.arange(total_len, device=device) - S_lengths = len(recat_lengths) - lengths_segment_sizes = torch.ones(S_lengths, device=device, dtype=torch.long) - l_peer_block_sizes = [local_split * b for b in stride_per_rank] - - # Pre-calculate segment coordinates logically - l_peer_ids_unpermuted = torch.tensor( - [i for i, size in enumerate(l_peer_block_sizes) for _ in range(size)], - device=device, dtype=torch.int32 - ) - - l_global_offsets = torch.zeros(S_lengths + 1, dtype=torch.long, device=device) - l_global_offsets[1:] = torch.cumsum(lengths_segment_sizes, dim=0) - l_src_offsets_unpermuted = l_global_offsets[:-1].clone() - - base_idx = 0 - for i in range(world_size): - block_len = l_peer_block_sizes[i] - l_src_offsets_unpermuted[base_idx : base_idx + block_len] -= l_global_offsets[base_idx] - base_idx += block_len - - # Apply permutation to map segments to the output ordering - l_peer_ids = l_peer_ids_unpermuted[recat_lengths] - l_src_offsets = l_src_offsets_unpermuted[recat_lengths].to(torch.int32) - l_permuted_sizes = lengths_segment_sizes[recat_lengths] - l_dst_offsets = torch.zeros(S_lengths + 1, dtype=torch.int32, device=device) - l_dst_offsets[1:] = torch.cumsum(l_permuted_sizes, dim=0).to(torch.int32) - total_l = l_dst_offsets[-1].item() - - lengths_out = torch.empty(total_l, device=device, dtype=torch.int32) - - # 3. Stream data from peers using ParallelKittens - ext.tk_barrier(ws_barrier) - if total_l > 0: - ext.tk_pull_permute_int32( - ws_l, lengths_out, l_dst_offsets, l_peer_ids, l_src_offsets, - S_lengths, total_l, ws_l.data_.shape[-1], rank - ) - - # 4. Resolve Value Segment boundaries (driven purely by the fully-permuted lengths) - lengths_out_long = lengths_out.to(torch.long) - l_cumsum = torch.zeros(total_l + 1, dtype=torch.long, device=device) - l_cumsum[1:] = torch.cumsum(lengths_out_long, dim=0) - - v_permuted_sizes = l_cumsum[l_dst_offsets[1:].long()] - l_cumsum[l_dst_offsets[:-1].long()] - inverse_recat = torch.argsort(recat_lengths) - v_sizes_unpermuted = v_permuted_sizes[inverse_recat] - - v_global_offsets = torch.zeros(S_lengths + 1, dtype=torch.long, device=device) - v_global_offsets[1:] = torch.cumsum(v_sizes_unpermuted, dim=0) - v_src_offsets_unpermuted = v_global_offsets[:-1].clone() - - base_idx = 0 - for i in range(world_size): - block_len = l_peer_block_sizes[i] - v_src_offsets_unpermuted[base_idx : base_idx + block_len] -= v_global_offsets[base_idx] - base_idx += block_len - - v_src_offsets = v_src_offsets_unpermuted[recat_lengths].to(torch.int32) - v_dst_offsets = torch.zeros(S_lengths + 1, dtype=torch.int32, device=device) - v_dst_offsets[1:] = torch.cumsum(v_permuted_sizes, dim=0).to(torch.int32) - total_v = v_dst_offsets[-1].item() - - values_out = torch.empty(total_v, device=device, dtype=torch.bfloat16) - - if total_v > 0: - ext.tk_pull_permute_bf16( - ws_v, values_out, v_dst_offsets, l_peer_ids, v_src_offsets, - S_lengths, total_v, ws_v.data_.shape[-1], rank - ) - weights_out = None - if weights is not None: - weights_out = torch.empty(total_v, device=device, dtype=torch.bfloat16) - if total_v > 0: - ext.tk_pull_permute_bf16( - ws_w, weights_out, v_dst_offsets, l_peer_ids, v_src_offsets, - S_lengths, total_v, ws_w.data_.shape[-1], rank - ) - - result: Dict[str, torch.Tensor] = { - "lengths": lengths_out, - "values": values_out, - } - - if variable_stride: - stride_per_key_per_rank = recv_strides_t.T - if stagger > 1: - order = torch.arange(world_size, device=device).view(stagger, -1).T.reshape(-1) - stride_per_key_per_rank = stride_per_key_per_rank[:, order] - result["stride_per_key_per_rank"] = stride_per_key_per_rank - else: - result["stride"] = torch.tensor(sum(stride_per_rank), device=device) - result["stride_per_rank"] = torch.tensor(stride_per_rank, device=device) - - if weights is not None: - result["weights"] = weights_out - - return result \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/7_reducescatter_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/7_reducescatter_parallelkittens.py deleted file mode 100755 index 8cce31e..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/7_reducescatter_parallelkittens.py +++ /dev/null @@ -1,237 +0,0 @@ -""" -ThunderKittens Reduce-Scatter via NVSwitch multimem. - -Strategy: -- We replace the NCCL `reduce_scatter_tensor` with a custom ThunderKittens kernel - that leverages Hopper NVSwitch `multimem::ld_reduce`. -- Each rank pushes its locally computed chunks to a contiguous symmetric buffer (`TKParallelTensor`). -- A single fused kernel executes a barrier, reads its assigned reduced chunk directly - from the multicast NVSwitch pointer (performing implicit hardware-accelerated summation - across all ranks), writes the result locally, and hits a final barrier. -- This entirely bypasses host-driven NCCL loops, operating at maximum NVSwitch bandwidth - with no extra memory traffic. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace reduce_scatter { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using input_layout = pgl, NUM_DEVICES, true>; - - input_layout input; - bf16* output_ptr; - const int dev_idx; - const int chunk_size; - - __host__ inline dim3 grid() const { - return dim3(chunk_size / NUM_ELEMS_PER_BLOCK); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t out_idx = globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - // Each device's chunk in the padded layout is at exactly G.chunk_size * G.dev_idx - const size_t in_idx = G.chunk_size * G.dev_idx + out_idx; - - bf16_2 tmp; - // Hardware multicast NVSwitch reduction across all ranks at this offset - multimem::ld_reduce(tmp, reinterpret_cast(&G.input.mc_ptr[in_idx])); - - // Write reduced data cleanly to local output tensor - *reinterpret_cast(&G.output_ptr[out_idx]) = tmp; -} - -} // namespace reduce_scatter - -namespace reduce_scatter_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace reduce_scatter_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(output, input, barrier); - - int chunk_size = output.data_.numel(); - - TORCH_CHECK(chunk_size % reduce_scatter::globals::NUM_ELEMS_PER_BLOCK == 0, - "The number of output tensor elements must be divisible by NUM_ELEMS_PER_BLOCK"); - TORCH_CHECK(input.data_.numel() == chunk_size * reduce_scatter::globals::NUM_DEVICES, - "Input tensor must be NUM_DEVICES times larger than output tensor"); - - reduce_scatter::globals reduce_scatter_G { - .input = kittens::py::parallel_tensor_to_pgl(input), - .output_ptr = reinterpret_cast(output.data_.data_ptr()), - .dev_idx = input.local_rank_, - .chunk_size = chunk_size - }; - - reduce_scatter_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - // Sync input readiness -> Reduce & Store locally -> Sync completion - kittens::py::launch_kernel(barrier_G); - kittens::py::launch_kernel(reduce_scatter_G); - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_reduce_scatter", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 # NUM_WARPGROUPS(2) * WARPGROUP_WARPS(4) * WARP_THREADS(32) -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT = NUM_ELEMS_PER_BLOCK # 512 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_reducescatter_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution(tensor: torch.Tensor) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - assert world == NUM_DEVICES, f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; got world_size={world}" - assert tensor.shape[0] % world == 0, f"First dimension ({tensor.shape[0]}) must be divisible by world_size ({world})" - - ext = _ensure_ext_jit() - - original_shape = tensor.shape - chunk_dim0 = original_shape[0] // world - out_shape = (chunk_dim0,) + original_shape[1:] - original_dtype = tensor.dtype - - n = tensor.numel() - n_chunk = n // world - - # Pad chunk size to kernel alignment bounds (NUM_ELEMS_PER_BLOCK) - padded_n_chunk = ((n_chunk + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT - - # Create / Fetch symmetric buffers for multicast group - input_tk = get_or_create_parallel_tensor(ext, (world, padded_n_chunk), torch.bfloat16, multicast=True) - output_tk = get_or_create_parallel_tensor(ext, (padded_n_chunk,), torch.bfloat16, multicast=False) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Format local data per-chunk and push to symmetric parallel tensor - flat_chunks = tensor.to(torch.bfloat16).reshape(world, n_chunk) - if padded_n_chunk > n_chunk: - padded_input = torch.zeros((world, padded_n_chunk), dtype=torch.bfloat16, device=tensor.device) - padded_input[:, :n_chunk] = flat_chunks - input_tk.data_.copy_(padded_input) - else: - input_tk.data_.copy_(flat_chunks) - - # Launch kernel sequence to NVSwitch reduction - ext.tk_reduce_scatter(output_tk, input_tk, barrier_tk) - - # Truncate any alignment padding and restore output shape characteristics - result = output_tk.data_[:n_chunk].clone() - return result.to(original_dtype).reshape(out_shape) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/8_alltoall_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/8_alltoall_parallelkittens.py deleted file mode 100755 index ff03022..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/8_alltoall_parallelkittens.py +++ /dev/null @@ -1,321 +0,0 @@ -""" -ThunderKittens all-to-all (personalized) via TMA between devices. - -Implements the same semantics as ``dist.all_to_all_single`` on a tensor of -shape ``[world_size, *chunk]``: -on rank ``r``, ``output[i, ...]`` holds the chunk rank ``i`` sent to ``r``. - -Uses ``scatter_axis=0``, ``gather_axis=1``: batch selects destination GPU, -depth gathers by source rank. Shapes are ``[W, 1, R, C]`` → ``[1, W, R, C]`` -with ``R=16`` and ``C`` a multiple of ``128``. - -Requires: ThunderKittens headers at ``$THUNDERKITTENS_ROOT/include``. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (all_to_all entrypoint + barrier) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" - -using namespace kittens; - -namespace all_to_all { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_THREADS = 1; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int ROW_BLOCK_SIZE = 16; - static constexpr int COL_BLOCK_SIZE = 128; - - using shared_tile = st_bf; - using parallel_layout = pgl, NUM_DEVICES, false>; - - parallel_layout output; - parallel_layout input; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3((input.cols() / globals::COL_BLOCK_SIZE) * - (input.rows() / globals::ROW_BLOCK_SIZE) * - input.depth() * input.batch()); - } - - __host__ inline int dynamic_shared_memory() const { - return static_cast(sizeof(shared_tile) + 1024); - } -}; - -template -__device__ inline void kernel(const globals &G) { - static_assert(0 <= SCATTER_AXIS && SCATTER_AXIS < 4 && 0 <= GATHER_AXIS && GATHER_AXIS < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - static_assert(SCATTER_AXIS != GATHER_AXIS, "Scatter and gather axes must be different"); - - extern __shared__ int __shm[]; - tma_swizzle_allocator allocator((int*)&__shm[0]); - globals::shared_tile &tile = allocator.allocate(); - - int task_idx = blockIdx.x; - int batch_idx = task_idx / (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.depth() * (G.input.rows() / globals::ROW_BLOCK_SIZE) * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int depth_idx = task_idx / (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - task_idx %= (G.input.rows() / globals::ROW_BLOCK_SIZE * (G.input.cols() / globals::COL_BLOCK_SIZE)); - int row_block_idx = task_idx / (G.input.cols() / globals::COL_BLOCK_SIZE); - task_idx %= (G.input.cols() / globals::COL_BLOCK_SIZE); - int col_block_idx = task_idx; - - __shared__ semaphore arrived; - init_semaphore(arrived, 0, 1); - tma::expect_bytes(arrived, sizeof(tile)); - tma::load_async(tile, G.input[G.dev_idx], {batch_idx, depth_idx, row_block_idx, col_block_idx}, arrived); - - int dst_dev_idx; - - if constexpr (SCATTER_AXIS == 0) { - dst_dev_idx = batch_idx / G.output.batch(); - batch_idx %= G.output.batch(); - } else if constexpr (SCATTER_AXIS == 1) { - dst_dev_idx = depth_idx / G.output.depth(); - depth_idx %= G.output.depth(); - } else if constexpr (SCATTER_AXIS == 2) { - dst_dev_idx = row_block_idx / (G.output.rows() / globals::ROW_BLOCK_SIZE); - row_block_idx %= (G.output.rows() / globals::ROW_BLOCK_SIZE); - } else { - dst_dev_idx = col_block_idx / (G.output.cols() / globals::COL_BLOCK_SIZE); - col_block_idx %= (G.output.cols() / globals::COL_BLOCK_SIZE); - } - - if constexpr (GATHER_AXIS == 0) { - batch_idx += G.input.batch() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 1) { - depth_idx += G.input.depth() * G.dev_idx; - } else if constexpr (GATHER_AXIS == 2) { - row_block_idx += (G.input.rows() / globals::ROW_BLOCK_SIZE) * G.dev_idx; - } else { - col_block_idx += (G.input.cols() / globals::COL_BLOCK_SIZE) * G.dev_idx; - } - - wait(arrived, 0); - tma::store_async(G.output[dst_dev_idx], tile, - {batch_idx, depth_idx, row_block_idx, col_block_idx}); -} - -} // namespace all_to_all - -namespace all_to_all_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_to_all_barrier - -void entrypoint( - kittens::py::TKParallelTensor &output, - kittens::py::TKParallelTensor &input, - kittens::py::TKParallelTensor &barrier, - int scatter_axis, - int gather_axis -) { - TORCH_CHECK(0 <= scatter_axis && scatter_axis < 4 && 0 <= gather_axis && gather_axis < 4, - "Scatter and gather axes must be 0, 1, 2, or 3"); - TORCH_CHECK(scatter_axis != gather_axis, "Scatter and gather axes must be different"); - - kittens::py::parallel_tensor_check(output, input); - - all_to_all::globals all_to_all_G { - .output = kittens::py::parallel_tensor_to_pgl(output), - .input = kittens::py::parallel_tensor_to_pgl(input), - .dev_idx = input.local_rank_ - }; - - all_to_all_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - - kittens::py::launch_kernel(barrier_G); - - if (scatter_axis == 0 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 0 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 1 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 2 && gather_axis == 3) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 0) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 1) - kittens::py::launch_kernel>(all_to_all_G); - else if (scatter_axis == 3 && gather_axis == 2) - kittens::py::launch_kernel>(all_to_all_G); - else - TORCH_CHECK(false, "Invalid scatter and gather axes"); - - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_all_to_all", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -ROW_TILE = 16 -COL_TILE = 128 -TILE_ELEMS = ROW_TILE * COL_TILE - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_alltoall_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - """Compile/load extension once; avoid per-call ``dist.barrier()`` in timed hot path.""" - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -def _padded_row_col(rest_elems: int) -> tuple[int, int, int]: - """Return (R, C, padded_rest) with R=16, C multiple of 128, R*C >= rest_elems.""" - num_tiles = (rest_elems + TILE_ELEMS - 1) // TILE_ELEMS - r, c = ROW_TILE, COL_TILE * num_tiles - padded = r * c - return r, c, padded - - -@torch.no_grad() -def solution(tensor: torch.Tensor) -> torch.Tensor: - assert tensor.is_cuda and tensor.is_contiguous() - - world = dist.get_world_size() - assert world == NUM_DEVICES, ( - f"This ThunderKittens kernel is built for NUM_DEVICES={NUM_DEVICES}; " - f"got world_size={world}" - ) - assert tensor.shape[0] == world, ( - f"First dimension ({tensor.shape[0]}) must equal world_size ({world})" - ) - - ext = _ensure_ext_jit() - - original_shape = tensor.shape - original_dtype = tensor.dtype - w = world - - # View input tensor iteratively into block-size chunks - flat = tensor.to(torch.bfloat16).reshape(w, -1).contiguous() - rest = flat.shape[1] - r, c, padded_rest = _padded_row_col(rest) - - # Pad data to align to Hopper TMA restrictions - padded = torch.zeros(w, padded_rest, dtype=torch.bfloat16, device=tensor.device) - padded[:, :rest] = flat - inp_4 = padded.view(w, 1, r, c) - - # Acquire parallel tensor IPC handles (cached across loops effectively avoiding repetitive memory bindings) - input_tk = get_or_create_parallel_tensor( - ext, (w, 1, r, c), torch.bfloat16, multicast=False - ) - output_tk = get_or_create_parallel_tensor( - ext, (1, w, r, c), torch.bfloat16, multicast=False - ) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Copy input into the memory-mapped PGL tensor bounds - n = inp_4.numel() - flat_in = inp_4.reshape(-1) - input_tk.data_.reshape(-1)[:n].copy_(flat_in) - - # Dispatch ThunderKittens Async TMA kernel natively overlapping collective - ext.tk_all_to_all(output_tk, input_tk, barrier_tk, 0, 1) - - # Format unpadded memory into final topology shape - out_flat = ( - output_tk.data_.reshape(-1)[:n].view(1, w, r, c)[0].reshape(w, padded_rest)[:, :rest].contiguous() - ) - return out_flat.reshape(original_shape).to(original_dtype) \ No newline at end of file diff --git a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/9_layernorm_backward_parallelkittens.py b/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/9_layernorm_backward_parallelkittens.py deleted file mode 100755 index 880dd81..0000000 --- a/solutions_parallelkittens_bf16_h100_8_google_gemini-3-pro-preview/9_layernorm_backward_parallelkittens.py +++ /dev/null @@ -1,295 +0,0 @@ -""" -ThunderKittens Fused LayerNorm Backward Aggregation. - -Fuses the local token-wise sum over (dY) and (dY * X_hat) with a device-side -all-reduce using ThunderKittens PGL and NVSwitch multimem/multicast operations. -Optimized for H100 (sm_90a) BF16 workflows. - -Requires: ThunderKittens headers at $THUNDERKITTENS_ROOT/include. -""" - -import os -import torch -import torch.distributed as dist -from utils.cuda_helpers import compile_cuda_extension -from utils.parallelkittens_runtime import ( - get_or_create_barrier, - get_or_create_parallel_tensor, -) - -TK_ROOT = os.environ.get("THUNDERKITTENS_ROOT", "/opt/thunderkittens") - -# --------------------------------------------------------------------------- -# Embedded .cu source (Compute + Barriers + All-Reduce) -# --------------------------------------------------------------------------- -CUDA_SRC = r''' -#include "kittens.cuh" -#include "pyutils/torchutils.cuh" -#include - -using namespace kittens; - -namespace fused_ln_bwd { - -struct compute_globals { - const __nv_bfloat16* X_hat; - const __nv_bfloat16* dY; - __nv_bfloat16* local_out; - int B; - int H; - int padded_H; - int B_chunk; -}; - -__global__ void compute_kernel(const compute_globals G) { - int h = blockIdx.x * blockDim.x + threadIdx.x; - int b_start = blockIdx.y * G.B_chunk; - int b_end = min(b_start + G.B_chunk, G.B); - - if (h >= G.H) return; - - float d_gamma_sum = 0.0f; - float d_beta_sum = 0.0f; - - // Contiguous load: threads in a warp read consecutive h - for (int b = b_start; b < b_end; ++b) { - float x = __bfloat162float(G.X_hat[b * G.H + h]); - float dy = __bfloat162float(G.dY[b * G.H + h]); - d_gamma_sum += x * dy; - d_beta_sum += dy; - } - - if (G.B_chunk < G.B) { - atomicAdd(&G.local_out[h * 2], __float2bfloat16(d_gamma_sum)); - atomicAdd(&G.local_out[h * 2 + 1], __float2bfloat16(d_beta_sum)); - } else { - G.local_out[h * 2] = __float2bfloat16(d_gamma_sum); - G.local_out[h * 2 + 1] = __float2bfloat16(d_beta_sum); - } -} - -} // namespace fused_ln_bwd - -namespace all_reduce { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int MIN_BLOCKS_PER_SM = 8; - static constexpr int NUM_WARPGROUPS = 2; - static constexpr int NUM_WARPS = NUM_WARPGROUPS * WARPGROUP_WARPS; - static constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - static constexpr int NUM_ELEMS_PER_INST = 2; - static constexpr int NUM_ELEMS_PER_BLOCK = config::NUM_THREADS * NUM_ELEMS_PER_INST; - - using parallel_layout = pgl, NUM_DEVICES, true>; - - parallel_layout tensor; - const int dev_idx; - - __host__ inline dim3 grid() const { - return dim3(tensor.numel() / NUM_ELEMS_PER_BLOCK / NUM_DEVICES); - } -}; - -__device__ inline void kernel(const globals &G) { - const size_t N_total = G.tensor.numel(); - const size_t N_per_dev = N_total / globals::NUM_DEVICES; - const size_t idx = N_per_dev * G.dev_idx + - globals::NUM_ELEMS_PER_BLOCK * blockIdx.x + - globals::NUM_ELEMS_PER_INST * threadIdx.x; - - bf16_2 tmp; - multimem::ld_reduce(tmp, reinterpret_cast(&G.tensor.mc_ptr[idx])); - multimem::st(reinterpret_cast(&G.tensor.mc_ptr[idx]), tmp); -} - -} // namespace all_reduce - -namespace all_reduce_barrier { - -struct config { - static constexpr int CLUSTER_SIZE = 1; - static constexpr int NUM_BLOCKS = 1; - static constexpr int NUM_THREADS = 1; - static constexpr int DYNAMIC_SHARED_MEMORY = 0; -}; - -struct globals { - static constexpr int NUM_DEVICES = 8; - barrier_t barrier; - const int dev_idx; -}; - -__device__ inline void kernel(const globals &G) { - barrier_all(G.barrier, {0}, G.dev_idx); -} - -} // namespace all_reduce_barrier - -void entrypoint( - torch::Tensor X_hat, - torch::Tensor dY, - kittens::py::TKParallelTensor &tensor, - kittens::py::TKParallelTensor &barrier -) { - kittens::py::parallel_tensor_check(tensor, barrier); - - int B = X_hat.size(0); - int H = X_hat.size(1); - - TORCH_CHECK(tensor.data_.numel() % (all_reduce::globals::NUM_DEVICES * all_reduce::globals::NUM_ELEMS_PER_BLOCK) == 0, - "The total number of tensor elements must be divisible by NUM_DEVICES * NUM_ELEMS_PER_BLOCK"); - - // 1. Launch compute kernel to reduce tokens -> feature gradients directly into PGL memory - int threads_per_block = 256; - int blocks_x = (H + threads_per_block - 1) / threads_per_block; - - int grid_y = 16; - int B_chunk = (B + grid_y - 1) / grid_y; - if (B_chunk == 0) B_chunk = 1; - grid_y = (B + B_chunk - 1) / B_chunk; - - dim3 grid(blocks_x, grid_y); - dim3 block(threads_per_block); - - fused_ln_bwd::compute_globals compute_G { - .X_hat = reinterpret_cast(X_hat.data_ptr()), - .dY = reinterpret_cast(dY.data_ptr()), - .local_out = reinterpret_cast<__nv_bfloat16*>(tensor.data_.data_ptr()), - .B = B, - .H = H, - .padded_H = static_cast(tensor.data_.size(0)), - .B_chunk = B_chunk - }; - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - fused_ln_bwd::compute_kernel<<>>(compute_G); - - // 2. Local-sync barrier - all_reduce_barrier::globals barrier_G { - .barrier = kittens::py::parallel_tensor_to_pgl>(barrier), - .dev_idx = barrier.local_rank_ - }; - kittens::py::launch_kernel(barrier_G); - - // 3. HW-accelerated cross-GPU all-reduce - all_reduce::globals all_reduce_G { - .tensor = kittens::py::parallel_tensor_to_pgl(tensor), - .dev_idx = tensor.local_rank_ - }; - kittens::py::launch_kernel(all_reduce_G); - - // 4. Global-sync barrier - kittens::py::launch_kernel(barrier_G); -} - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - BIND_TK_PARALLEL_TENSOR(m); - m.def("tk_fused_ln_bwd", &entrypoint); -} -''' - -TK_CUDA_FLAGS = [ - "-std=c++20", - "--use_fast_math", - "--expt-extended-lambda", - "--expt-relaxed-constexpr", - "-DKITTENS_HOPPER", - "-gencode", "arch=compute_90a,code=sm_90a", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "-Xcompiler=-Wno-psabi", - "-Xcompiler=-fno-strict-aliasing", - "-DNDEBUG", -] - -_ext = None -_ext_jit_ready = False - -NUM_DEVICES = 8 -NUM_THREADS = 256 -NUM_ELEMS_PER_INST = 2 -NUM_ELEMS_PER_BLOCK = NUM_THREADS * NUM_ELEMS_PER_INST -ALIGNMENT_ELEMS = NUM_DEVICES * NUM_ELEMS_PER_BLOCK # 4096 elements -ALIGNMENT_H = ALIGNMENT_ELEMS // 2 # 2048 - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension( - "tk_fused_ln_bwd_ext", - CUDA_SRC, - extra_cuda_cflags=TK_CUDA_FLAGS, - extra_include_paths=[ - os.path.join(TK_ROOT, "include"), - os.path.join(TK_ROOT, "prototype"), - ], - extra_ldflags=["-lcuda"], - ) - return _ext - - -def _ensure_ext_jit(): - global _ext_jit_ready - if _ext_jit_ready: - return _get_ext() - rank = dist.get_rank() - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - _ext_jit_ready = True - return ext - - -@torch.no_grad() -def solution( - X_hat: torch.Tensor, - dY: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert X_hat.is_cuda and dY.is_cuda, "Inputs must be CUDA tensors" - assert X_hat.is_contiguous() and dY.is_contiguous(), "Inputs must be contiguous" - assert X_hat.shape == dY.shape, "X_hat and dY must have the same shape" - - world = dist.get_world_size() - assert world == NUM_DEVICES, f"Expected NUM_DEVICES={NUM_DEVICES}, got {world}" - - B, H = X_hat.shape - ext = _ensure_ext_jit() - - original_dtype = X_hat.dtype - - # Hardware target fastpath precision - X_bf16 = X_hat.to(torch.bfloat16) - dY_bf16 = dY.to(torch.bfloat16) - - # Output dimensions padded to satisfy the NVSwitch alignment (NUM_DEVICES * 512 elements) - padded_H = ((H + ALIGNMENT_H - 1) // ALIGNMENT_H) * ALIGNMENT_H - - # Prepare symmetrical device buffers ([padded_H, 2] multiplexes d_gamma and d_beta) - tensor_tk = get_or_create_parallel_tensor(ext, (padded_H, 2), torch.bfloat16, multicast=True) - barrier_tk = get_or_create_barrier(ext, num_devices=world) - - # Erase padding and prev outputs to prep `atomicAdd` aggregation - tensor_tk.data_.zero_() - - # Invoke fused device pipeline (Compute -> Barrier -> Multimem LD/ST -> Barrier) - ext.tk_fused_ln_bwd(X_bf16, dY_bf16, tensor_tk, barrier_tk) - - # Pluck original shape out of the padded symmetrical buffer - out = tensor_tk.data_[:H, :] - d_gamma = out[:, 0].clone() - d_beta = out[:, 1].clone() - - return d_gamma.to(original_dtype), d_beta.to(original_dtype) \ No newline at end of file