diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/10_embedding_lookup_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/10_embedding_lookup_cuda.py deleted file mode 100755 index 32241e9..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/10_embedding_lookup_cuda.py +++ /dev/null @@ -1,257 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -template -__global__ void lookup_scan_all_vec_kernel( - const int64_t* __restrict__ ptrs_meta_arr, - const int64_t* __restrict__ ptrs_queries_arr, - const int64_t* __restrict__ ptrs_out_arr, - const T* __restrict__ local_shard, - int my_rank, - int world_size, - int64_t shard_size, - int64_t embed_dim, - int vec_size -) { - // 2D Grid: blockIdx.y = target remote rank, blockIdx.x = query chunk - int target_rank = blockIdx.y; - - // Read the query count (N_A) for the remote rank we are inspecting - const int64_t* meta_A = reinterpret_cast(ptrs_meta_arr[target_rank]); - int64_t N_A = meta_A[0]; - - const int64_t* queries_A = reinterpret_cast(ptrs_queries_arr[target_rank]); - T* out_A = reinterpret_cast(ptrs_out_arr[target_rank]); - - int num_blocks_x = gridDim.x; - int block_x = blockIdx.x; - - if (vec_size == 8 && embed_dim % 8 == 0) { - int vec_dim = embed_dim / 8; - const uint4* local_shard_vec = reinterpret_cast(local_shard); - uint4* out_A_vec = reinterpret_cast(out_A); - - for (int64_t q = block_x; q < N_A; q += num_blocks_x) { - int64_t global_idx = queries_A[q]; - // Mimic Python's floor division `//` for negative indices - int64_t target = global_idx >= 0 ? (global_idx / shard_size) : ((global_idx - shard_size + 1) / shard_size); - - if (target == my_rank) { - int64_t local_offset = global_idx - my_rank * shard_size; - // Clamp to safe range, analogous to torch.clamp(..., 0, shard_size - 1) - if (local_offset < 0) local_offset = 0; - if (local_offset >= shard_size) local_offset = shard_size - 1; - - // Vectorized peer-to-peer write - for (int d = threadIdx.x; d < vec_dim; d += blockDim.x) { - out_A_vec[q * vec_dim + d] = local_shard_vec[local_offset * vec_dim + d]; - } - } - } - } else if (vec_size == 4 && embed_dim % 4 == 0) { - int vec_dim = embed_dim / 4; - const uint2* local_shard_vec = reinterpret_cast(local_shard); - uint2* out_A_vec = reinterpret_cast(out_A); - - for (int64_t q = block_x; q < N_A; q += num_blocks_x) { - int64_t global_idx = queries_A[q]; - int64_t target = global_idx >= 0 ? (global_idx / shard_size) : ((global_idx - shard_size + 1) / shard_size); - - if (target == my_rank) { - int64_t local_offset = global_idx - my_rank * shard_size; - if (local_offset < 0) local_offset = 0; - if (local_offset >= shard_size) local_offset = shard_size - 1; - - for (int d = threadIdx.x; d < vec_dim; d += blockDim.x) { - out_A_vec[q * vec_dim + d] = local_shard_vec[local_offset * vec_dim + d]; - } - } - } - } else { - // Scalar fallback - for (int64_t q = block_x; q < N_A; q += num_blocks_x) { - int64_t global_idx = queries_A[q]; - int64_t target = global_idx >= 0 ? (global_idx / shard_size) : ((global_idx - shard_size + 1) / shard_size); - - if (target == my_rank) { - int64_t local_offset = global_idx - my_rank * shard_size; - if (local_offset < 0) local_offset = 0; - if (local_offset >= shard_size) local_offset = shard_size - 1; - - for (int d = threadIdx.x; d < embed_dim; d += blockDim.x) { - out_A[q * embed_dim + d] = local_shard[local_offset * embed_dim + d]; - } - } - } - } -} - -void launch_lookup( - torch::Tensor ptrs_meta, - torch::Tensor ptrs_queries, - torch::Tensor ptrs_out, - torch::Tensor local_shard, - int rank, - int world_size, - int64_t shard_size, - int64_t embed_dim -) { - int vec_size = 1; - if (embed_dim % 8 == 0) vec_size = 8; - else if (embed_dim % 4 == 0) vec_size = 4; - - // Distribute queries dynamically across robust grid of 1024 chunks - int num_blocks_x = 1024; - dim3 grid(num_blocks_x, world_size); - dim3 block(128); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const int64_t* p_meta = ptrs_meta.data_ptr(); - const int64_t* p_queries = ptrs_queries.data_ptr(); - const int64_t* p_out = ptrs_out.data_ptr(); - - if (local_shard.scalar_type() == torch::kBFloat16) { - lookup_scan_all_vec_kernel<__nv_bfloat16><<>>( - p_meta, p_queries, p_out, - reinterpret_cast(local_shard.data_ptr()), - rank, world_size, shard_size, embed_dim, vec_size - ); - } else if (local_shard.scalar_type() == torch::kHalf) { - lookup_scan_all_vec_kernel<__half><<>>( - p_meta, p_queries, p_out, - reinterpret_cast(local_shard.data_ptr()), - rank, world_size, shard_size, embed_dim, vec_size - ); - } else if (local_shard.scalar_type() == torch::kFloat32) { - lookup_scan_all_vec_kernel<<>>( - p_meta, p_queries, p_out, - reinterpret_cast(local_shard.data_ptr()), - rank, world_size, shard_size, embed_dim, vec_size - ); - } else { - TORCH_CHECK(false, "Unsupported dtype. Valid types: float32, float16, bfloat16."); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_lookup", &launch_lookup, "UVA P2P scan-all embedding lookup"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("uva_lookup_scan_all_ext", CUDA_SRC) - return _ext - - -class SymmState: - def __init__(self): - # We pre-allocate a generous symmetric buffer (4M queries maximum per rank) to avoid - # any host-device synchronization or collectives strictly on the hot path. - self.capacity = 4194304 - self.embed_dim = -1 - self.dtype = None - self.buf_meta = None - self.hdl_meta = None - self.ptrs_meta = None - self.buf_queries = None - self.hdl_queries = None - self.ptrs_queries = None - self.buf_out = None - self.hdl_out = None - self.ptrs_out = None - self.is_initialized = False - - -_STATE = SymmState() - -def _ensure_initialized(N: int, embed_dim: int, dtype: torch.dtype, device: torch.device): - global _STATE - if _STATE.is_initialized: - return - - _STATE.embed_dim = embed_dim - _STATE.dtype = dtype - - # meta buffer: 1 element to communicate local `N` without host-driven collective - _STATE.buf_meta = symm_mem.empty((1,), dtype=torch.int64, device=device) - _STATE.hdl_meta = symm_mem.rendezvous(_STATE.buf_meta, dist.group.WORLD) - _STATE.ptrs_meta = torch.tensor(_STATE.hdl_meta.buffer_ptrs, dtype=torch.int64, device=device) - - # queries buffer (max `self.capacity` elements) - _STATE.buf_queries = symm_mem.empty((_STATE.capacity,), dtype=torch.int64, device=device) - _STATE.hdl_queries = symm_mem.rendezvous(_STATE.buf_queries, dist.group.WORLD) - _STATE.ptrs_queries = torch.tensor(_STATE.hdl_queries.buffer_ptrs, dtype=torch.int64, device=device) - - # peer output buffer - _STATE.buf_out = symm_mem.empty((_STATE.capacity, embed_dim), dtype=dtype, device=device) - _STATE.hdl_out = symm_mem.rendezvous(_STATE.buf_out, dist.group.WORLD) - _STATE.ptrs_out = torch.tensor(_STATE.hdl_out.buffer_ptrs, dtype=torch.int64, device=device) - - _STATE.is_initialized = True - - -@torch.no_grad() -def solution(indices: torch.Tensor, local_shard: torch.Tensor) -> torch.Tensor: - if not dist.is_initialized(): - return local_shard[torch.clamp(indices, 0, local_shard.shape[0]-1)] - - rank = dist.get_rank() - world_size = dist.get_world_size() - shard_size = local_shard.shape[0] - embed_dim = local_shard.shape[1] - - indices = indices.contiguous() - local_shard = local_shard.contiguous() - N = indices.numel() - - if rank == 0: - _get_ext() - dist.barrier() - _get_ext() - - _ensure_initialized(N, embed_dim, local_shard.dtype, indices.device) - assert N <= _STATE.capacity, f"Query count (N={N}) exceeds generous symmetric capacity limit ({_STATE.capacity})" - - # 1. Provide `N` to meta mapping and stage contiguous queries. - # fill_() / copy_() occur completely asynchronously on the host stream. - _STATE.buf_meta.fill_(N) - if N > 0: - _STATE.buf_queries[:N].copy_(indices) - - # 2. Wait for all ranks to complete uploading queries and configurations. - _STATE.hdl_out.barrier(channel=0) - - # 3. Each rank executes peer evaluation natively pushing resolving elements cleanly backward. - _get_ext().launch_lookup( - _STATE.ptrs_meta, - _STATE.ptrs_queries, - _STATE.ptrs_out, - local_shard, - rank, - world_size, - shard_size, - embed_dim - ) - - # 4. Stream safety wait to seal operations across all peer device actions. - _STATE.hdl_out.barrier(channel=0) - - # 5. Provide discrete slice output tensor mimicking standard isolation behavior. - if N > 0: - output_vectors = _STATE.buf_out[:N].clone() - else: - output_vectors = torch.empty((0, embed_dim), dtype=local_shard.dtype, device=indices.device) - - return output_vectors \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/11_gemm_allgather_AT_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/11_gemm_allgather_AT_cuda.py deleted file mode 100755 index e742ef5..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/11_gemm_allgather_AT_cuda.py +++ /dev/null @@ -1,251 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// Vectorized kernel: reads 8 bfloat16 elements (16 bytes) at once over NVLink -__global__ void pull_a_kernel_vec( - const uint64_t* __restrict__ peer_ptrs, - __nv_bfloat16* __restrict__ out, - int64_t start_row, - int64_t num_rows, - int64_t K_local, - int world_size -) { - int peer_idx = blockIdx.y; - const __nv_bfloat16* src = reinterpret_cast(peer_ptrs[peer_idx]); - - int64_t K_global = K_local * world_size; - int64_t vec_K = K_local / 8; - - int64_t col_vec_start = blockIdx.z * blockDim.x + threadIdx.x; - int64_t row_start = blockIdx.x * blockDim.y + threadIdx.y; - - int64_t row_stride = gridDim.x * blockDim.y; - int64_t col_stride = gridDim.z * blockDim.x; - - for (int64_t row = row_start; row < num_rows; row += row_stride) { - int64_t src_row = start_row + row; - int64_t out_row_offset = row * K_global + peer_idx * K_local; - int64_t src_row_offset = src_row * K_local; - - for (int64_t col_vec = col_vec_start; col_vec < vec_K; col_vec += col_stride) { - int64_t src_idx = src_row_offset + col_vec * 8; - int64_t out_idx = out_row_offset + col_vec * 8; - - float4 val = *reinterpret_cast(&src[src_idx]); - *reinterpret_cast(&out[out_idx]) = val; - } - } -} - -// Scalar kernel fallback for dimensions not perfectly divisible by 8 -__global__ void pull_a_kernel_scalar( - const uint64_t* __restrict__ peer_ptrs, - __nv_bfloat16* __restrict__ out, - int64_t start_row, - int64_t num_rows, - int64_t K_local, - int world_size -) { - int peer_idx = blockIdx.y; - const __nv_bfloat16* src = reinterpret_cast(peer_ptrs[peer_idx]); - - int64_t K_global = K_local * world_size; - - int64_t col_start = blockIdx.z * blockDim.x + threadIdx.x; - int64_t row_start = blockIdx.x * blockDim.y + threadIdx.y; - - int64_t row_stride = gridDim.x * blockDim.y; - int64_t col_stride = gridDim.z * blockDim.x; - - for (int64_t row = row_start; row < num_rows; row += row_stride) { - int64_t src_row = start_row + row; - int64_t out_row_offset = row * K_global + peer_idx * K_local; - int64_t src_row_offset = src_row * K_local; - - for (int64_t col = col_start; col < K_local; col += col_stride) { - int64_t src_idx = src_row_offset + col; - int64_t out_idx = out_row_offset + col; - - out[out_idx] = src[src_idx]; - } - } -} - -void launch_pull_a( - torch::Tensor peer_ptrs_tensor, - torch::Tensor out, - int64_t start_row, - int64_t num_rows, - int64_t K_local, - int world_size, - int64_t stream_ptr -) { - const uint64_t* peer_ptrs = reinterpret_cast(peer_ptrs_tensor.data_ptr()); - __nv_bfloat16* out_ptr = reinterpret_cast<__nv_bfloat16*>(out.data_ptr()); - cudaStream_t stream = reinterpret_cast(stream_ptr); - - // We intentionally cap grid dimensions to only use ~10-20% of the device's SMs. - // This leaves the majority of SMs completely free for overlapping Tensor Core matrix math! - if (K_local % 8 == 0) { - dim3 threads(32, 8); - int64_t vec_K = K_local / 8; - - int blocks_x = std::min((num_rows + threads.y - 1) / threads.y, 8); - int blocks_z = std::min((vec_K + threads.x - 1) / threads.x, 4); - dim3 blocks(blocks_x, world_size, blocks_z); - - pull_a_kernel_vec<<>>( - peer_ptrs, out_ptr, start_row, num_rows, K_local, world_size - ); - } else { - dim3 threads(32, 8); - int blocks_x = std::min((num_rows + threads.y - 1) / threads.y, 8); - int blocks_z = std::min((K_local + threads.x - 1) / threads.x, 16); - dim3 blocks(blocks_x, world_size, blocks_z); - - pull_a_kernel_scalar<<>>( - peer_ptrs, out_ptr, start_row, num_rows, K_local, world_size - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_pull_a", &launch_pull_a, "Pull A chunks from peers over NVLink"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("pull_a_allgather_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(shape, dtype, device): - key = (shape, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - res = (buf, hdl, ptrs) - _symm_cache[key] = res - return res - - -@torch.no_grad() -def solution( - A_local: torch.Tensor, - B: torch.Tensor, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert A_local.is_cuda and B.is_cuda, "Inputs must be CUDA tensors" - - if A_local.dtype != torch.bfloat16: - # Fallback to reference NCCL implementation for arbitrary unsupported types - world_size = dist.get_world_size() - M, K_local = A_local.shape - K_global = world_size * K_local - - A_local_t = A_local.transpose(0, 1).contiguous() - A_t_gather = torch.empty((world_size, K_local, M), device=A_local.device, dtype=A_local.dtype) - dist.all_gather_into_tensor(A_t_gather, A_local_t) - A_global_t = A_t_gather.reshape(K_global, M) - - B = B.contiguous() - C_t = torch.matmul(B.transpose(0, 1), A_global_t) - return C_t.transpose(0, 1) - - global _ext - if _ext is None: - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - - world_size = dist.get_world_size() - M, K_local = A_local.shape - K_global = world_size * K_local - N = B.shape[1] - - A_local = A_local.contiguous() - B = B.contiguous() - - # Expose current rank's input slice via symm_mem - buf, hdl, peer_ptrs = _get_symm_state(A_local.shape, A_local.dtype, A_local.device) - buf.copy_(A_local) - - # Determine optimized pipeline chunk size (prioritizing max tensor-core efficiency limits) - chunk_size = 2048 - if M <= chunk_size: - num_chunks = 1 - chunk_size = M - else: - num_chunks = (M + chunk_size - 1) // chunk_size - if num_chunks > 8: - num_chunks = 8 - chunk_size = (M + num_chunks - 1) // num_chunks - - # Double buffers for pipeline overlap - A_global_0 = torch.empty((chunk_size, K_global), device=A_local.device, dtype=A_local.dtype) - A_global_1 = torch.empty((chunk_size, K_global), device=A_local.device, dtype=A_local.dtype) - C = torch.empty((M, N), device=A_local.device, dtype=A_local.dtype) - - stream_0 = torch.cuda.Stream() - stream_1 = torch.cuda.Stream() - main_stream = torch.cuda.current_stream() - - # Await device-side exposure of remote symm_mems before streams begin NVLink reading - hdl.barrier(channel=0) - stream_0.wait_stream(main_stream) - stream_1.wait_stream(main_stream) - - for i in range(num_chunks): - start_row = i * chunk_size - end_row = min(start_row + chunk_size, M) - actual_M = end_row - start_row - - is_even = (i % 2 == 0) - stream_curr = stream_0 if is_even else stream_1 - buf_curr = A_global_0 if is_even else A_global_1 - - if i == 0: - with torch.cuda.stream(stream_curr): - _get_ext().launch_pull_a( - peer_ptrs, buf_curr, start_row, actual_M, K_local, world_size, stream_curr.cuda_stream - ) - - # Tensor core cuBLAS execution handles chunk i (same stream naturally enforces intra-chunk dependency) - with torch.cuda.stream(stream_curr): - out_slice = C[start_row:end_row] - torch.mm(buf_curr[:actual_M], B, out=out_slice) - - # Push the NEXT chunk's NVLink memory pull to the alternate, concurrently executing stream - if i + 1 < num_chunks: - next_start = (i + 1) * chunk_size - next_actual = min(next_start + chunk_size, M) - next_start - stream_next = stream_1 if is_even else stream_0 - buf_next = A_global_1 if is_even else A_global_0 - - with torch.cuda.stream(stream_next): - _get_ext().launch_pull_a( - peer_ptrs, buf_next, next_start, next_actual, K_local, world_size, stream_next.cuda_stream - ) - - main_stream.wait_stream(stream_0) - main_stream.wait_stream(stream_1) - - # Ensure no rank destructs/re-enters and alters `buf` while a peer continues chunk retrieval - hdl.barrier(channel=0) - - return C \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/12_gemm_allgather_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/12_gemm_allgather_cuda.py deleted file mode 100755 index 7a80c15..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/12_gemm_allgather_cuda.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Strategy: -- **Device-side communication & fusion**: Bypasses `dist.all_gather` and the subsequent `torch.cat`. Uses symmetric memory and a custom CUDA kernel that directly pulls `A_local` shards from all peers over NVLink and natively writes them into their final contiguous positions in `A_global`. -- **Compute-communication overlap**: Partitions the M-dimension of the GEMM into chunks. While the main stream computes the dense Tensor Core matmul for chunk `c`, a concurrent CUDA stream pulls the symmetric memory shards for chunk `c+1`. This perfectly hides communication latency without destroying arithmetic intensity (which would happen if we partitioned along K). -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__global__ void gather_concat_kernel_row_peer( - const uint64_t* __restrict__ peer_ptrs, - T* __restrict__ out, - int M_chunk, - int K_local, - int world_size, - int64_t peer_offset_elements -) { - int total_tasks = M_chunk * world_size; - int task_idx = blockIdx.x * blockDim.y + threadIdx.y; - - if (task_idx < total_tasks) { - int row = task_idx / world_size; - int peer = task_idx % world_size; - - const T* peer_buf = reinterpret_cast(peer_ptrs[peer]) + peer_offset_elements; - T* out_row = out + row * (K_local * world_size) + peer * K_local; - const T* in_row = peer_buf + row * K_local; - - for (int col = threadIdx.x; col < K_local; col += blockDim.x) { - out_row[col] = in_row[col]; - } - } -} - -__global__ void gather_concat_kernel_128_row_peer( - const uint64_t* __restrict__ peer_ptrs, - int4* __restrict__ out, - int M_chunk, - int K_local_128, - int world_size, - int64_t peer_offset_128 -) { - int total_tasks = M_chunk * world_size; - int task_idx = blockIdx.x * blockDim.y + threadIdx.y; - - if (task_idx < total_tasks) { - int row = task_idx / world_size; - int peer = task_idx % world_size; - - const int4* peer_buf = reinterpret_cast(peer_ptrs[peer]) + peer_offset_128; - int4* out_row = out + row * (K_local_128 * world_size) + peer * K_local_128; - const int4* in_row = peer_buf + row * K_local_128; - - for (int col = threadIdx.x; col < K_local_128; col += blockDim.x) { - out_row[col] = in_row[col]; - } - } -} - -void launch_gather_concat( - torch::Tensor peer_ptrs_tensor, - torch::Tensor out, - int M_chunk, - int K_local, - int world_size, - int element_size, - int64_t peer_offset_elements -) { - TORCH_CHECK(out.is_contiguous(), "out must be contiguous"); - const uint64_t* d_ptrs = reinterpret_cast(peer_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - int64_t total_bytes = (int64_t)M_chunk * K_local * world_size * element_size; - - if (total_bytes % 16 == 0 && (K_local * element_size) % 16 == 0 && (peer_offset_elements * element_size) % 16 == 0) { - int K_local_128 = (K_local * element_size) / 16; - int64_t peer_offset_128 = (peer_offset_elements * element_size) / 16; - - dim3 block(32, 8); // 32 threads for inner copy, 8 parallel tasks - int total_tasks = M_chunk * world_size; - int grid = (total_tasks + block.y - 1) / block.y; - - gather_concat_kernel_128_row_peer<<>>( - d_ptrs, reinterpret_cast(out.data_ptr()), M_chunk, K_local_128, world_size, peer_offset_128); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - dim3 block(32, 8); - int total_tasks = M_chunk * world_size; - int grid = (total_tasks + block.y - 1) / block.y; - - if (element_size == 2) { - gather_concat_kernel_row_peer<<>>( - d_ptrs, reinterpret_cast(out.data_ptr()), M_chunk, K_local, world_size, peer_offset_elements); - } else if (element_size == 4) { - gather_concat_kernel_row_peer<<>>( - d_ptrs, reinterpret_cast(out.data_ptr()), M_chunk, K_local, world_size, peer_offset_elements); - } else { - gather_concat_kernel_row_peer<<>>( - d_ptrs, reinterpret_cast(out.data_ptr()), M_chunk * element_size, K_local * element_size, world_size, peer_offset_elements * element_size); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_gather_concat", &launch_gather_concat, "P2P fetch and concatenate chunks"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("p2p_gather_concat_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(n: int, dtype: torch.dtype, device: torch.device): - global _symm_cache - key = (n, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(n, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - _symm_cache[key] = (buf, hdl, ptrs_tensor) - return _symm_cache[key] - -@torch.no_grad() -def solution(A_local: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert A_local.is_cuda and B.is_cuda, "Inputs must be CUDA tensors" - - world_size = dist.get_world_size() - - M, K_local = A_local.shape - K_B, N = B.shape - K_global = world_size * K_local - assert K_B == K_global, f"B must have K dimension = world_size * K_local" - - # Pre-compile - _get_ext() - - # Expose A_local layout to peers - buf, hdl, ptrs_tensor = _get_symm_state(M * K_local, A_local.dtype, A_local.device) - - # Copy data synchronously to symm_mem so peers can safely pull it - buf.copy_(A_local.contiguous().flatten()) - hdl.barrier(channel=0) - - # Barrier completes on current stream. Create an event so the comm stream knows it's safe to start pulling. - comp_stream = torch.cuda.current_stream() - sync_event = torch.cuda.Event() - sync_event.record(comp_stream) - - comm_stream = torch.cuda.Stream() - comm_stream.wait_event(sync_event) - - # Partition M intelligently into chunks to overlap computation with communication - NUM_CHUNKS = 1 - if M % 4 == 0 and (M // 4) >= 128: - NUM_CHUNKS = 4 - elif M % 2 == 0 and (M // 2) >= 128: - NUM_CHUNKS = 2 - - M_chunk = M // NUM_CHUNKS - - # Pre-allocate fully materialized targets - A_global = torch.empty((M, K_global), dtype=A_local.dtype, device=A_local.device) - C_out = torch.empty((M, N), dtype=A_local.dtype, device=A_local.device) - - comm_events = [torch.cuda.Event() for _ in range(NUM_CHUNKS)] - - # Launch purely pipelined communication asynchronously - for c in range(NUM_CHUNKS): - with torch.cuda.stream(comm_stream): - offset = c * M_chunk * K_local - out_slice = A_global[c * M_chunk : (c + 1) * M_chunk, :] - _get_ext().launch_gather_concat( - ptrs_tensor, - out_slice, - M_chunk, - K_local, - world_size, - A_local.element_size(), - offset - ) - comm_events[c].record(comm_stream) - - # Launch pipelined compute as chunks become available - for c in range(NUM_CHUNKS): - # Wait for the chunk's communication to complete - comm_events[c].wait(comp_stream) - out_slice = A_global[c * M_chunk : (c + 1) * M_chunk, :] - C_chunk = C_out[c * M_chunk : (c + 1) * M_chunk, :] - torch.matmul(out_slice, B.contiguous(), out=C_chunk) - - # Prevent successive steps / functions from un-registering or overwriting active buffers - dist.barrier() - - return C_out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/13_gemm_allreduce_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/13_gemm_allreduce_cuda.py deleted file mode 100755 index 9535529..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/13_gemm_allreduce_cuda.py +++ /dev/null @@ -1,372 +0,0 @@ -""" -Strategy: -1. **Device-side Communication**: Uses Hopper NVSwitch `multimem.ld_reduce` and `multimem.st` instructions to perform an in-network fused all-reduce broadcast directly on symmetric bfloat16 tensors. A fast P2P pointer-based fallback is implemented for other dtypes. -2. **Compute-Communication Overlap**: The GEMM is chunked along the M dimension. The `torch.matmul` computation (cuBLAS) is launched on the default stream, while the chunked all-reduce kernels operate asynchronously on a dedicated communication stream. -3. **Pipelining**: A reusable custom device-side blockwise barrier inside symmetric memory ensures safe, chunk-level synchronization without host intervention. This perfectly hides chunk $i$'s all-reduce latency behind chunk $i+1$'s GEMM computation. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Reusable blockwise barrier in Symmetric Memory -// --------------------------------------------------------------------------- -__device__ void blockwise_barrier_reusable( - const uint64_t* __restrict__ sync_ptrs, - uint64_t barrier_idx, - uint64_t chunk_id, - uint64_t block_id, - int rank, - int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - - uint64_t local_base = sync_ptrs[rank]; - uint64_t remote_base = sync_ptrs[flat_tid]; - - // offset in uint32_t elements - // MAX_CHUNKS = 128, MAX_BLOCKS = 32 - uint64_t offset = (((barrier_idx * 128 + chunk_id) * gridDim.x) + block_id) * world_size; - - uint32_t* send_addr = reinterpret_cast(remote_base) + offset + rank; - uint32_t* wait_addr = reinterpret_cast(local_base) + offset + flat_tid; - - uint32_t tmp; - // Send signal (self-cleans from 0 to 1) - do { - asm volatile( - "atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) - : "l"(send_addr) - : "memory"); - } while (tmp != 0u); - - // Wait signal and reset (self-cleans from 1 back to 0) - do { - asm volatile( - "atom.global.acquire.sys.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) - : "l"(wait_addr) - : "memory"); - } while (tmp != 1u); -} - -// --------------------------------------------------------------------------- -// Multimem chunked all-reduce (BF16) -// --------------------------------------------------------------------------- -__global__ void multimem_allreduce_bf16_chunked_kernel( - uint64_t multicast_base, - const uint64_t* __restrict__ sync_ptrs, - int64_t chunk_offset_128, - int64_t chunk_numel_128, - int world_size, - int rank, - int block_stride, - int chunk_id -) { - const uint64_t block_id = static_cast(blockIdx.x); - - // Wait for all ranks to complete GEMM for this chunk - blockwise_barrier_reusable(sync_ptrs, 0, chunk_id, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = (chunk_numel_128 + world_size - 1) / world_size; - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = block_id * block_stride; - block_start < numel_per_rank; - block_start += num_programs * block_stride) - { - const int64_t offsets = block_start + tid; - if (offsets >= numel_per_rank) continue; - - const int64_t idx = rank * numel_per_rank + offsets; - if (idx < chunk_numel_128) { - uint64_t* ptrs = reinterpret_cast(multicast_base) + (chunk_offset_128 + idx) * 2; - uint32_t x, y, z, w; - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(x), "=r"(y), "=r"(z), "=r"(w) - : "l"(ptrs) - : "memory"); - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"(ptrs), "r"(x), "r"(y), "r"(z), "r"(w) - : "memory"); - } - } - - __syncthreads(); - // Barrier to ensure no early exits before multimem accesses complete - blockwise_barrier_reusable(sync_ptrs, 1, chunk_id, block_id, rank, world_size); -} - -// --------------------------------------------------------------------------- -// P2P chunked all-reduce (Fallback) -// --------------------------------------------------------------------------- -__global__ void p2p_allreduce_chunked_kernel( - const uint64_t* __restrict__ sync_ptrs, - const long long* __restrict__ ptrs, - int64_t chunk_offset, - int64_t chunk_numel, - int world_size, - int rank, - int chunk_id, - int dtype_enum -) { - const uint64_t block_id = blockIdx.x; - - // Wait for all ranks to complete GEMM for this chunk - blockwise_barrier_reusable(sync_ptrs, 0, chunk_id, block_id, rank, world_size); - __syncthreads(); - - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < chunk_numel; idx += gridDim.x * blockDim.x) { - if (dtype_enum == 0) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const __nv_bfloat16* src = (const __nv_bfloat16*)ptrs[r]; - sum += __bfloat162float(src[chunk_offset + idx]); - } - __nv_bfloat16* out = (__nv_bfloat16*)ptrs[rank]; - out[chunk_offset + idx] = __float2bfloat16(sum); - } else if (dtype_enum == 1) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const float* src = (const float*)ptrs[r]; - sum += src[chunk_offset + idx]; - } - float* out = (float*)ptrs[rank]; - out[chunk_offset + idx] = sum; - } else if (dtype_enum == 2) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const __half* src = (const __half*)ptrs[r]; - sum += __half2float(src[chunk_offset + idx]); - } - __half* out = (__half*)ptrs[rank]; - out[chunk_offset + idx] = __float2half(sum); - } - } - - __syncthreads(); - // Barrier to ensure no rank modifies the symmetric buffer before others finish reading - blockwise_barrier_reusable(sync_ptrs, 1, chunk_id, block_id, rank, world_size); -} - -// --------------------------------------------------------------------------- -// Python Launchers -// --------------------------------------------------------------------------- -void launch_multimem_chunked( - uint64_t multicast_ptr, - torch::Tensor sync_ptrs_tensor, - int64_t chunk_offset_128, - int64_t chunk_numel_128, - int world_size, - int rank, - int num_blocks, - int block_size, - int block_stride, - int chunk_id, - uint64_t stream_ptr -) { - const uint64_t* d_sync = reinterpret_cast(sync_ptrs_tensor.data_ptr()); - cudaStream_t stream = reinterpret_cast(stream_ptr); - - multimem_allreduce_bf16_chunked_kernel<<>>( - multicast_ptr, d_sync, chunk_offset_128, chunk_numel_128, world_size, rank, block_stride, chunk_id - ); -} - -void launch_p2p_chunked( - torch::Tensor sync_ptrs_tensor, - torch::Tensor ptrs_tensor, - int64_t chunk_offset, - int64_t chunk_numel, - int world_size, - int rank, - int chunk_id, - int dtype_enum, - uint64_t stream_ptr -) { - const uint64_t* d_sync = reinterpret_cast(sync_ptrs_tensor.data_ptr()); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - - cudaStream_t stream = reinterpret_cast(stream_ptr); - - int threads = 512; - int blocks = 32; - if (chunk_numel < blocks * threads) { - blocks = (chunk_numel + threads - 1) / threads; - if (blocks == 0) blocks = 1; - } - - p2p_allreduce_chunked_kernel<<>>( - d_sync, d_ptrs, chunk_offset, chunk_numel, world_size, rank, chunk_id, dtype_enum - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_chunked", &launch_multimem_chunked); - m.def("launch_p2p_chunked", &launch_p2p_chunked); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("chunked_gemm_allreduce", CUDA_SRC) - return _ext - -BYTES_PER_THREAD = 16 -MAX_NUM_BLOCKS = 4 -MAX_BLOCK_SIZE = 1024 - -def _multimem_launch_config(numel: int, world_size: int) -> tuple[int, int, int]: - numel_per_thread = 8 - num_threads = (numel // numel_per_thread + world_size - 1) // world_size - if num_threads < MAX_BLOCK_SIZE: - block_size = max(32, world_size) # Guarantee enough threads for blockwise barriers - while block_size < num_threads and block_size < MAX_BLOCK_SIZE: - block_size *= 2 - num_blocks = 1 - else: - block_size = MAX_BLOCK_SIZE - num_blocks = min((num_threads + MAX_BLOCK_SIZE - 1) // MAX_BLOCK_SIZE, MAX_NUM_BLOCKS) - return num_blocks, block_size, block_size - -_resource_cache = {} - -def _get_resources(shape, dtype, device, world_size): - key = (shape, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - C_symm = symm_mem.empty(shape, device=device, dtype=dtype) - C_hdl = symm_mem.rendezvous(C_symm, dist.group.WORLD) - C_ptrs = torch.tensor(C_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - # Preallocate self-cleaning blockwise barriers: - # Max barriers = 2, Max chunks = 128, Max blocks = 32 - sync_numel = 2 * 128 * 32 * world_size - sync_buf = symm_mem.empty((sync_numel,), device=device, dtype=torch.int32) - sync_buf.zero_() # Cleared once at creation; barriers handle local 1->0 cleanup - sync_hdl = symm_mem.rendezvous(sync_buf, dist.group.WORLD) - sync_ptrs = torch.tensor(sync_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - comm_stream = torch.cuda.Stream() - events = [torch.cuda.Event() for _ in range(128)] - - res = { - "C_symm": C_symm, "C_hdl": C_hdl, "C_ptrs": C_ptrs, - "sync_buf": sync_buf, "sync_ptrs": sync_ptrs, - "comm_stream": comm_stream, "events": events - } - _resource_cache[key] = res - return res - -def get_num_chunks(M: int) -> int: - # Heuristics for overlapping - if M <= 512: - return 1 - elif M <= 2048: - return 2 - else: - return min(4, (M + 1023) // 1024) - -@torch.no_grad() -def solution(A_local: torch.Tensor, B_local: torch.Tensor) -> torch.Tensor: - if not dist.is_initialized(): - return torch.matmul(A_local, B_local) - - M, K = A_local.shape - K_B, N = B_local.shape - assert K == K_B - - if M == 0 or N == 0 or K == 0: - C = torch.matmul(A_local, B_local) - dist.all_reduce(C, op=dist.ReduceOp.SUM) - return C - - rank = dist.get_rank() - world_size = dist.get_world_size() - - # Pre-compile on rank 0 sequentially to prevent NCCL/CUDA timeout issues - if rank == 0: - _get_ext() - dist.barrier() - - num_chunks = get_num_chunks(M) - chunk_size = (M + num_chunks - 1) // num_chunks - num_chunks = (M + chunk_size - 1) // chunk_size # Evict empty chunks - - res = _get_resources((M, N), A_local.dtype, A_local.device, world_size) - C_symm = res["C_symm"] - C_hdl = res["C_hdl"] - C_ptrs = res["C_ptrs"] - sync_ptrs = res["sync_ptrs"] - comm_stream = res["comm_stream"] - events = res["events"] - - A_local = A_local.contiguous() - B_local = B_local.contiguous() - - use_multimem = (A_local.dtype == torch.bfloat16) and (C_symm.numel() % 8 == 0) and hasattr(C_hdl, "multicast_ptr") - if use_multimem: - multicast_ptr = int(C_hdl.multicast_ptr) - - for i in range(num_chunks): - start_m = i * chunk_size - end_m = min(M, (i + 1) * chunk_size) - if start_m >= M: - break - - chunk_m = end_m - start_m - chunk_numel = chunk_m * N - chunk_offset = start_m * N - - # 1. Compute GEMM chunk natively overlapping previous iteration's async all-reduce - torch.matmul(A_local[start_m:end_m, :], B_local, out=C_symm[start_m:end_m, :]) - events[i].record(torch.cuda.current_stream()) - - # 2. Asynchronous All-Reduce - comm_stream.wait_event(events[i]) - with torch.cuda.stream(comm_stream): - if use_multimem and (chunk_numel % 8 == 0) and (chunk_offset % 8 == 0): - chunk_numel_128 = chunk_numel // 8 - chunk_offset_128 = chunk_offset // 8 - num_blocks, block_size, block_stride = _multimem_launch_config(chunk_numel, world_size) - - _get_ext().launch_multimem_chunked( - multicast_ptr, sync_ptrs, chunk_offset_128, chunk_numel_128, - world_size, rank, num_blocks, block_size, block_stride, i, comm_stream.cuda_stream - ) - else: - dtype_enum = 0 if A_local.dtype == torch.bfloat16 else (1 if A_local.dtype == torch.float32 else 2) - _get_ext().launch_p2p_chunked( - sync_ptrs, C_ptrs, chunk_offset, chunk_numel, - world_size, rank, i, dtype_enum, comm_stream.cuda_stream - ) - - torch.cuda.current_stream().wait_stream(comm_stream) - - # Strict wait ensuring safe buffer extraction before next `solution` call zeroes/touches local dependencies - dist.barrier() - return C_symm.clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/14_gemm_allscatter_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/14_gemm_allscatter_cuda.py deleted file mode 100755 index 930bf63..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/14_gemm_allscatter_cuda.py +++ /dev/null @@ -1,284 +0,0 @@ -""" -Distributed GEMM with All-Gather (scatter). - -Strategy: -- Maximize overlap by chunking the M dimension. Local compute (A_chunk @ B) is overlapped - with asynchronous peer-to-peer multicast of the previous chunk. -- Avoids NCCL by using a custom CUDA kernel that writes directly to peers' memory via - UVA and symmetric memory buffers. -- Exploits Hopper NVSwitch `multimem.st` (Hardware Broadcast) to write the output to all - peers simultaneously with a single instruction, drastically reducing NVLink traffic. -- Leverages device-side barriers (`hdl.barrier`) for lightning-fast GPU stream synchronization. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include - -__device__ __forceinline__ void multimem_st_128(const uint64_t* addr, uint4 val) { - // Hardware broadcast: write 128 bits to all multicast peers simultaneously - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"(addr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) - : "memory"); -} - -__global__ void push_multimem_128( - const uint4* __restrict__ src, - uint64_t multicast_ptr, - int chunk_rows, - int N_local_128, - int N_128, - int start_col_128 -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total_elements = chunk_rows * N_local_128; - if (idx >= total_elements) return; - - int row = idx / N_local_128; - int col = idx % N_local_128; - - uint4 val = src[idx]; - int out_offset = row * N_128 + start_col_128 + col; - - // cast to uint64_t* which is 8 bytes, so we multiply offset by 2 to step by 16 bytes - uint64_t* dst = reinterpret_cast(multicast_ptr) + out_offset * 2; - multimem_st_128(dst, val); -} - -__global__ void push_p2p_128( - const uint4* __restrict__ src, - const uint64_t* __restrict__ peer_ptrs, - int world_size, - int chunk_rows, - int N_local_128, - int N_128, - int start_col_128 -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total_elements = chunk_rows * N_local_128; - if (idx >= total_elements) return; - - int row = idx / N_local_128; - int col = idx % N_local_128; - - uint4 val = src[idx]; - int out_offset = row * N_128 + start_col_128 + col; - - #pragma unroll - for (int p = 0; p < world_size; ++p) { - uint4* dst = reinterpret_cast(peer_ptrs[p]); - dst[out_offset] = val; - } -} - -template -__global__ void push_p2p_scalar( - const T* __restrict__ src, - const uint64_t* __restrict__ peer_ptrs, - int world_size, - int chunk_rows, - int N_local, - int N, - int start_col -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total_elements = chunk_rows * N_local; - if (idx >= total_elements) return; - - int row = idx / N_local; - int col = idx % N_local; - - T val = src[idx]; - int out_offset = row * N + start_col + col; - - #pragma unroll - for (int p = 0; p < world_size; ++p) { - T* dst = reinterpret_cast(peer_ptrs[p]); - dst[out_offset] = val; - } -} - -void launch_push( - torch::Tensor src, - torch::Tensor peer_ptrs, - int64_t multicast_ptr_int, - int chunk_rows, - int N_local, - int N, - int start_col -) { - int element_size = src.element_size(); - bool use_128 = ((N_local * element_size) % 16 == 0) && - ((N * element_size) % 16 == 0) && - ((start_col * element_size) % 16 == 0); - - int threads = 256; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - uint64_t multicast_ptr = static_cast(multicast_ptr_int); - - if (use_128) { - int N_local_128 = (N_local * element_size) / 16; - int N_128 = (N * element_size) / 16; - int start_col_128 = (start_col * element_size) / 16; - int total = chunk_rows * N_local_128; - if (total == 0) return; - - int blocks = (total + threads - 1) / threads; - const uint4* src_ptr = reinterpret_cast(src.data_ptr()); - - if (multicast_ptr != 0) { - push_multimem_128<<>>( - src_ptr, multicast_ptr, chunk_rows, N_local_128, N_128, start_col_128 - ); - } else { - const uint64_t* ptrs = reinterpret_cast(peer_ptrs.data_ptr()); - int world_size = peer_ptrs.size(0); - push_p2p_128<<>>( - src_ptr, ptrs, world_size, chunk_rows, N_local_128, N_128, start_col_128 - ); - } - } else { - int total = chunk_rows * N_local; - if (total == 0) return; - - int blocks = (total + threads - 1) / threads; - const uint64_t* ptrs = reinterpret_cast(peer_ptrs.data_ptr()); - int world_size = peer_ptrs.size(0); - - if (element_size == 4) { - push_p2p_scalar<<>>( - reinterpret_cast(src.data_ptr()), ptrs, world_size, chunk_rows, N_local, N, start_col - ); - } else if (element_size == 2) { - push_p2p_scalar<<>>( - reinterpret_cast(src.data_ptr()), ptrs, world_size, chunk_rows, N_local, N, start_col - ); - } else { - push_p2p_scalar<<>>( - reinterpret_cast(src.data_ptr()), ptrs, world_size, chunk_rows, N_local*element_size, N*element_size, start_col*element_size - ); - } - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push", &launch_push, "Asynchronously push chunk of C_local to all peers"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - if dist.get_rank() == 0: - _ext = compile_cuda_extension("gemm_allscatter_push_ext", CUDA_SRC) - dist.barrier() - if dist.get_rank() != 0: - _ext = compile_cuda_extension("gemm_allscatter_push_ext", CUDA_SRC) - return _ext - -_resource_cache = {} - -def _get_resources(M: int, N_local: int, N: int, dtype: torch.dtype, device: torch.device): - key = (M, N_local, N, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - # Global C symmetric buffer. Written natively via P2P/Multimem ST - C_symm = symm_mem.empty((M, N), dtype=dtype, device=device) - hdl = symm_mem.rendezvous(C_symm, dist.group.WORLD) - - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - multicast_ptr = getattr(hdl, "multicast_ptr", None) - multicast_ptr_int = int(multicast_ptr) if multicast_ptr is not None else 0 - - # Preallocate compute target buffer to avoid recurrent allocations inside hot loop - C_local_buffer = torch.empty((M, N_local), dtype=dtype, device=device) - - comm_stream = torch.cuda.Stream(device=device) - events = [torch.cuda.Event() for _ in range(64)] - - res = (C_symm, hdl, ptrs_tensor, multicast_ptr_int, C_local_buffer, comm_stream, events) - _resource_cache[key] = res - return res - -@torch.no_grad() -def solution(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - if not dist.is_initialized(): - return torch.matmul(A, B) - - rank = dist.get_rank() - world_size = dist.get_world_size() - M, K = A.shape - K_B, N_local = B.shape - N = world_size * N_local - - _get_ext() - - (C_symm, hdl, ptrs_tensor, multicast_ptr_int, - C_local_buffer, comm_stream, events) = _get_resources(M, N_local, N, A.dtype, A.device) - - compute_stream = torch.cuda.current_stream() - - # Determine schedule (chunking) to allow comm stream to hide behind compute stream - chunk_size = max(256, M // 4) - if M <= 512: - chunk_size = M - num_chunks = (M + chunk_size - 1) // chunk_size - if num_chunks > len(events): - chunk_size = (M + len(events) - 1) // len(events) - num_chunks = len(events) - - start_col = rank * N_local - - for i in range(num_chunks): - start_m = i * chunk_size - end_m = min((i + 1) * chunk_size, M) - chunk_rows = end_m - start_m - if chunk_rows <= 0: - break - - # Slicing creates views into preallocated buffers - A_chunk = A[start_m:end_m, :] - C_local_chunk = C_local_buffer[start_m:end_m, :] - - # Step 1: Execute dense math on the main stream (uses Tensor Cores) - torch.matmul(A_chunk, B, out=C_local_chunk) - events[i].record(compute_stream) - - # Step 2: Push result to the rest of the world concurrently via separate stream - comm_stream.wait_event(events[i]) - with torch.cuda.stream(comm_stream): - _get_ext().launch_push( - C_local_chunk, - ptrs_tensor, - multicast_ptr_int, - chunk_rows, - N_local, - N, - start_col - ) - - # Ensure compute stream tracks memory-copy stream - compute_stream.wait_stream(comm_stream) - - # Fast hardware-assisted stream barrier 0: Ensure all peers finished pushing data out - hdl.barrier(channel=0) - - # Once everyone is synchronized, form the local copy of the complete output - out = C_symm.clone() - - # Fast hardware-assisted stream barrier 1: Ensure local reading logic finishes - # before returning, securing the `C_symm` buffer against being corrupted in the NEXT iteration - hdl.barrier(channel=1) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/15_combined_sharded_gemms_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/15_combined_sharded_gemms_cuda.py deleted file mode 100755 index b819ecc..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/15_combined_sharded_gemms_cuda.py +++ /dev/null @@ -1,221 +0,0 @@ -""" -Strategy: -- **Algorithmic Reduction:** The reference code all-gathers `x` and computes the MLP on the full sequence, then throws away `(world_size - 1)/world_size` of the result via a masked reduce-scatter. We replace this entire sequence with a direct **All-to-All** routing of the required `M_local` sequence chunks to each rank, followed by the MLP computed only on the local shard. This mathematically identical approach reduces FLOPs and communication volume by `world_size`x. -- **Device-Side P2P Symmetric Memory:** We eliminate the PyTorch collective overhead by allocating an `x_full_loc` receiver buffer via `torch.distributed._symmetric_memory` and explicitly scattering the input tensors directly into peer memory using a custom vectorized CUDA kernel over NVLink. -- **Compute-Communication Pipelining:** We split the sequence dimension `M_local` into multiple chunks. Using dual CUDA streams and atomic device-side barriers (`hdl.barrier`), we overlap the cross-NVLink transfer of chunk `c+1` with the local GEMM and SiLU computations of chunk `c`, hiding the network transfer time behind the compute. -""" - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void all_to_all_chunk_kernel( - const __nv_bfloat16* __restrict__ x_local, - const long long* __restrict__ dest_ptrs, - int M_local, - int H_local, - int H, - int rank, - int world_size, - int chunk_start, - int chunk_size -) { - int64_t total_elements = (int64_t)chunk_size * H_local * world_size; - int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; - - // Fast vectorized path (128-bit loads/stores) when aligned - if (H_local % 8 == 0) { - int64_t total_vecs = total_elements / 8; - for (int64_t idx = tid; idx < total_vecs; idx += gridDim.x * blockDim.x) { - int64_t elem_idx = idx * 8; - int dest_rank = elem_idx / (chunk_size * H_local); - int rem = elem_idx % (chunk_size * H_local); - int r = rem / H_local; - int c = rem % H_local; - - int src_r = dest_rank * M_local + chunk_start + r; - int dst_r = chunk_start + r; - - const uint4* src = (const uint4*)x_local; - uint4* dest = (uint4*)dest_ptrs[dest_rank]; - - int src_vec_idx = (src_r * H_local + c) / 8; - int dst_vec_idx = (dst_r * H + rank * H_local + c) / 8; - - dest[dst_vec_idx] = src[src_vec_idx]; - } - } else { - // Scalar fallback path - for (int64_t idx = tid; idx < total_elements; idx += gridDim.x * blockDim.x) { - int dest_rank = idx / (chunk_size * H_local); - int rem = idx % (chunk_size * H_local); - int r = rem / H_local; - int c = rem % H_local; - - int src_r = dest_rank * M_local + chunk_start + r; - int dst_r = chunk_start + r; - - const __nv_bfloat16* src = x_local; - __nv_bfloat16* dest = (__nv_bfloat16*)dest_ptrs[dest_rank]; - - dest[dst_r * H + rank * H_local + c] = src[src_r * H_local + c]; - } - } -} - -void launch_all_to_all_chunk( - torch::Tensor x_local, - torch::Tensor ptrs_tensor, - int M_local, - int H_local, - int H, - int rank, - int world_size, - int chunk_start, - int chunk_size -) { - int64_t total_elements = (int64_t)chunk_size * H_local * world_size; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - if (H_local % 8 == 0) { - blocks = ((total_elements / 8) + threads - 1) / threads; - } - if (blocks == 0) blocks = 1; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - all_to_all_chunk_kernel<<>>( - reinterpret_cast(x_local.data_ptr()), - reinterpret_cast(ptrs_tensor.data_ptr()), - M_local, - H_local, - H, - rank, - world_size, - chunk_start, - chunk_size - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_all_to_all_chunk", &launch_all_to_all_chunk, "All-to-all chunk copy kernel"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("all_to_all_pipeline_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_buf(M_local, H, dtype, device): - global _symm_cache - key = (M_local, H, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty((M_local, H), dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _symm_cache[key] = (buf, hdl, ptrs_tensor) - return _symm_cache[key] - -_stream_cache = None -def _get_stream1(): - global _stream_cache - if _stream_cache is None: - _stream_cache = torch.cuda.Stream() - return _stream_cache - -_events_cache = [] -def _get_events(n): - global _events_cache - while len(_events_cache) < n: - _events_cache.append(torch.cuda.Event()) - return _events_cache[:n] - - -@torch.no_grad() -def solution( - x_local: torch.Tensor, - W1: torch.Tensor, - W2: torch.Tensor, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert x_local.is_cuda and W1.is_cuda and W2.is_cuda, "Inputs must be CUDA tensors" - - rank = dist.get_rank() - world_size = dist.get_world_size() - - if _ext is None: - if rank == 0: - _get_ext() - dist.barrier() - _get_ext() - - M, H_local = x_local.shape - H, ffn_dim = W1.shape - M_local = M // world_size - - x_local = x_local.contiguous() - - # Allocate symmetric receiver buffer for this rank's block of the sequence - buf, hdl, ptrs_tensor = _get_symm_buf(M_local, H, x_local.dtype, x_local.device) - - stream1 = _get_stream1() - # Make the comm stream wait for any pending x_local preparation on the default stream - stream1.wait_stream(torch.cuda.current_stream()) - - # Pipeline chunks for overlapping matmul computation with NVLink peer-to-peer copies - num_chunks = 2 if (M_local >= 128 and M_local % 2 == 0) else 1 - chunk_size = M_local // num_chunks - events = _get_events(num_chunks) - - # Launch communication sequence entirely on stream1 - with torch.cuda.stream(stream1): - # Strict pre-write sync to ensure peers finished matmuls from previous steps - hdl.barrier(channel=0) - - for c in range(num_chunks): - chunk_start = c * chunk_size - _get_ext().launch_all_to_all_chunk( - x_local, ptrs_tensor, M_local, H_local, H, - rank, world_size, chunk_start, chunk_size - ) - # Ensure chunk c has fully arrived on all ranks before computation - hdl.barrier(channel=0) - events[c].record(stream1) - - y_local = torch.empty((M_local, H), dtype=x_local.dtype, device=x_local.device) - - # Launch computation on the default stream, synced with stream1 chunks - for c in range(num_chunks): - torch.cuda.current_stream().wait_event(events[c]) - chunk_start = c * chunk_size - - # Pull the fully assembled row block of x_full - x_chunk = buf[chunk_start : chunk_start + chunk_size, :] - - # Execute MLP exclusively on this rank's required sequence shard - z_chunk = torch.matmul(x_chunk, W1) - a_chunk = F.silu(z_chunk) - y_local[chunk_start : chunk_start + chunk_size, :] = torch.matmul(a_chunk, W2) - - # Prevent stream1 resources from being prematurely cleaned up - torch.cuda.current_stream().wait_stream(stream1) - - # Matching the reference spec synchronization - dist.barrier() - return y_local \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/16_gemm_reducescatter_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/16_gemm_reducescatter_cuda.py deleted file mode 100755 index 2fe3876..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/16_gemm_reducescatter_cuda.py +++ /dev/null @@ -1,279 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// Vectorized helper for software reduction fallback -__device__ __forceinline__ void sum_bf16_8(float* acc, const uint4& val) { - const __nv_bfloat162* v = reinterpret_cast(&val); - #pragma unroll - for(int i = 0; i < 4; ++i) { - float2 f = __bfloat1622float2(v[i]); - acc[2 * i] += f.x; - acc[2 * i + 1] += f.y; - } -} - -__device__ __forceinline__ uint4 pack_bf16_8(const float* acc) { - uint4 res; - __nv_bfloat162* v = reinterpret_cast<__nv_bfloat162*>(&res); - #pragma unroll - for(int i = 0; i < 4; ++i) { - v[i] = __floats2bfloat162_rn(acc[2 * i], acc[2 * i + 1]); - } - return res; -} - -__global__ void reduce_scatter_fallback_kernel( - const uint64_t* __restrict__ symm_C_ptrs, - __nv_bfloat16* __restrict__ out_C, - uint32_t* __restrict__ my_flags, - uint32_t expected_value, - int chunk_idx, - int64_t chunk_size, - int world_size -) { - // 1. Wait for all peers to signal they have finished this chunk - if (threadIdx.x == 0) { - for (int p = 0; p < world_size; ++p) { - uint32_t val = 0; - do { - asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(val) : "l"(my_flags + p) : "memory"); - } while (val < expected_value); - } - } - __syncthreads(); - - // 2. Reduce the chunk - int64_t offset = (int64_t)chunk_idx * chunk_size; - int64_t num_vecs = chunk_size / 8; - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - for (int64_t i = tid; i < num_vecs; i += stride) { - float sum[8] = {0.0f}; - for (int p = 0; p < world_size; ++p) { - const uint64_t byte_offset = (offset + i * 8) * sizeof(__nv_bfloat16); - const uint4* peer_C_vec = reinterpret_cast(symm_C_ptrs[p] + byte_offset); - uint4 val = peer_C_vec[0]; - sum_bf16_8(sum, val); - } - uint4 out_val = pack_bf16_8(sum); - reinterpret_cast(out_C)[i] = out_val; - } - - // Tail reduction for remaining elements if chunk_size is not perfectly divisible by 8 - if (tid == 0) { - for (int64_t i = num_vecs * 8; i < chunk_size; ++i) { - float sum = 0.0f; - for (int p = 0; p < world_size; ++p) { - const __nv_bfloat16* peer_C = reinterpret_cast(symm_C_ptrs[p]); - sum += __bfloat162float(peer_C[offset + i]); - } - out_C[i] = __float2bfloat16(sum); - } - } -} - -__global__ void reduce_scatter_multimem_kernel( - uint64_t multicast_base, - __nv_bfloat16* __restrict__ out_C, - uint32_t* __restrict__ my_flags, - uint32_t expected_value, - int chunk_idx, - int64_t chunk_size, - int world_size -) { - // 1. Wait for all peers to signal they have finished this chunk - if (threadIdx.x == 0) { - for (int p = 0; p < world_size; ++p) { - uint32_t val = 0; - do { - asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(val) : "l"(my_flags + p) : "memory"); - } while (val < expected_value); - } - } - __syncthreads(); - - // 2. Hardware NVSwitch multimem reduction - int64_t offset = (int64_t)chunk_idx * chunk_size; - int64_t num_vecs = chunk_size / 8; - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - for (int64_t i = tid; i < num_vecs; i += stride) { - uint64_t byte_offset = (offset + i * 8) * sizeof(__nv_bfloat16); - uint64_t ptr = multicast_base + byte_offset; - - uint32_t r0, r1, r2, r3; - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "l"(ptr) - : "memory"); - - uint32_t* out_dst = reinterpret_cast(out_C + i * 8); - out_dst[0] = r0; - out_dst[1] = r1; - out_dst[2] = r2; - out_dst[3] = r3; - } -} - -__global__ void send_signal_kernel(uint32_t* target_flags, int index, uint32_t value) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - // Release consistency ensures prior symm_C matmul stores are visible to target device - asm volatile("st.release.sys.global.u32 [%0], %1;" :: "l"(target_flags + index), "r"(value) : "memory"); - } -} - -void launch_send_signal(uint64_t target_flags_ptr, int index, uint32_t value) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - send_signal_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(target_flags_ptr), index, value); -} - -void launch_reduce_scatter( - uint64_t multicast_ptr, - torch::Tensor symm_C_ptrs_tensor, - torch::Tensor out_C, - uint64_t my_flags_ptr, - uint32_t expected_value, - int chunk_idx, - int64_t chunk_size, - int world_size -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 512; - int blocks = std::min((int)((chunk_size / 8 + threads - 1) / threads), 4096); - if (blocks == 0) blocks = 1; - - bool use_multimem = (multicast_ptr != 0) && (chunk_size % 8 == 0); - - if (use_multimem) { - reduce_scatter_multimem_kernel<<>>( - multicast_ptr, - (__nv_bfloat16*)out_C.data_ptr(), - reinterpret_cast(my_flags_ptr), - expected_value, - chunk_idx, - chunk_size, - world_size - ); - } else { - const uint64_t* ptrs = (const uint64_t*)symm_C_ptrs_tensor.data_ptr(); - reduce_scatter_fallback_kernel<<>>( - ptrs, - (__nv_bfloat16*)out_C.data_ptr(), - reinterpret_cast(my_flags_ptr), - expected_value, - chunk_idx, - chunk_size, - world_size - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_send_signal", &launch_send_signal); - m.def("launch_reduce_scatter", &launch_reduce_scatter); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gemm_reduce_scatter_ext", CUDA_SRC) - return _ext - -_resource_cache = {} -_buf_idx = 0 -_invocation_count = 0 - -def _get_resources(shape, dtype, device, world_size): - global _resource_cache - key = (shape, dtype, device, world_size) - if key in _resource_cache: - return _resource_cache[key] - - buffers = [] - for _ in range(2): # Double-buffering prevents overwriting during tight pipelined loops - symm_C = symm_mem.empty(shape, device=device, dtype=dtype) - symm_C_hdl = symm_mem.rendezvous(symm_C, dist.group.WORLD) - symm_C_ptrs = torch.tensor(symm_C_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - symm_flags = symm_mem.empty((world_size,), dtype=torch.int32, device=device) - symm_flags.zero_() - flags_hdl = symm_mem.rendezvous(symm_flags, dist.group.WORLD) - flags_ptrs = flags_hdl.buffer_ptrs - - buffers.append((symm_C, symm_C_hdl, symm_C_ptrs, symm_flags, flags_hdl, flags_ptrs)) - - torch.cuda.synchronize() - dist.barrier() - _resource_cache[key] = buffers - return buffers - -@torch.no_grad() -def solution(A_local: torch.Tensor, B_local: torch.Tensor) -> torch.Tensor: - global _buf_idx, _invocation_count - - W = dist.get_world_size() - rank = dist.get_rank() - - M, K_local = A_local.shape - _, N = B_local.shape - M_local = M // W - - _invocation_count += 1 - invoc_id = _invocation_count - - buffers = _get_resources((M, N), A_local.dtype, A_local.device, W) - symm_C, symm_C_hdl, symm_C_ptrs, symm_flags, flags_hdl, flags_ptrs = buffers[_buf_idx] - _buf_idx = (_buf_idx + 1) % 2 - - ext = _get_ext() - - A_contig = A_local.contiguous() - B_contig = B_local.contiguous() - - # Chunked overlap loop: Staggered chunks pipelined to minimize peer spin-wait. - for i in range(W): - c = (rank + i) % W - start_row = c * M_local - end_row = start_row + M_local - - A_slice = A_contig[start_row:end_row, :] - C_slice = symm_C[start_row:end_row, :] - - # Output directly to asymmetric/symmetric shared device memory segment - torch.matmul(A_slice, B_contig, out=C_slice) - - # Async kernel queuing: Release device-resident chunk-level completion signal - ext.launch_send_signal(flags_ptrs[c], rank, invoc_id) - - out_C = torch.empty((M_local, N), dtype=A_local.dtype, device=A_local.device) - chunk_size = M_local * N - - multicast_ptr = int(symm_C_hdl.multicast_ptr) if hasattr(symm_C_hdl, 'multicast_ptr') and symm_C_hdl.multicast_ptr is not None else 0 - - # Launch spin-wait device-side reduction; pulls via NVSwitch MMU multimem pointers where possible - ext.launch_reduce_scatter( - multicast_ptr, - symm_C_ptrs, - out_C, - symm_flags.data_ptr(), - invoc_id, - rank, # Focus purely on chunk subset - chunk_size, - W - ) - - return out_C \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/17_rope_allgather_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/17_rope_allgather_cuda.py deleted file mode 100755 index b354b63..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/17_rope_allgather_cuda.py +++ /dev/null @@ -1,313 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Tuple -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__device__ __forceinline__ void compute_rope( - void* q_v, void* k_v, void* q_p, void* k_p, void* c_v, void* s_v, - void* qo_v, void* ko_v, int d_idx, int half_D -) { - __nv_bfloat16* q = (__nv_bfloat16*)q_v; - __nv_bfloat16* k = (__nv_bfloat16*)k_v; - __nv_bfloat16* qp = (__nv_bfloat16*)q_p; - __nv_bfloat16* kp = (__nv_bfloat16*)k_p; - __nv_bfloat16* c = (__nv_bfloat16*)c_v; - __nv_bfloat16* s = (__nv_bfloat16*)s_v; - __nv_bfloat16* qo = (__nv_bfloat16*)qo_v; - __nv_bfloat16* ko = (__nv_bfloat16*)ko_v; - - #pragma unroll - for (int i = 0; i < V; ++i) { - float q_f = __bfloat162float(q[i]); - float k_f = __bfloat162float(k[i]); - float qp_f = __bfloat162float(qp[i]); - float kp_f = __bfloat162float(kp[i]); - float c_f = __bfloat162float(c[i]); - float s_f = __bfloat162float(s[i]); - - float q_rot = (d_idx < half_D) ? (-qp_f) : (qp_f); - float k_rot = (d_idx < half_D) ? (-kp_f) : (kp_f); - - qo[i] = __float2bfloat16(q_f * c_f + q_rot * s_f); - ko[i] = __float2bfloat16(k_f * c_f + k_rot * s_f); - } -} - -template -__global__ void rope_multicast_kernel( - const uint8_t* __restrict__ q_local, - const uint8_t* __restrict__ k_local, - const uint8_t* __restrict__ cos_local, - const uint8_t* __restrict__ sin_local, - uint64_t mcast_q, - const uint64_t* __restrict__ peer_ptrs_q, - uint64_t mcast_k, - const uint64_t* __restrict__ peer_ptrs_k, - int B, int S_local, int H, int D, - int world_size, int rank -) { - int num_elements = B * S_local * H * D; - int num_vecs = num_elements / VEC_SIZE; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; - - int half_D = D / 2; - - for (int i = tid; i < num_vecs; i += stride) { - int d_idx = (i * VEC_SIZE) % D; - int h_idx = ((i * VEC_SIZE) / D) % H; - int s_idx = ((i * VEC_SIZE) / (D * H)) % S_local; - int b_idx = ((i * VEC_SIZE) / (D * H * S_local)); - - int partner_d_idx = (d_idx < half_D) ? (d_idx + half_D) : (d_idx - half_D); - - size_t offset_main = (size_t)i * VEC_SIZE * 2; - size_t offset_partner = ((size_t)b_idx * S_local * H * D + (size_t)s_idx * H * D + (size_t)h_idx * D + partner_d_idx) * 2; - size_t offset_cos_sin = ((size_t)b_idx * S_local * D + (size_t)s_idx * D + d_idx) * 2; - - int s_global = rank * S_local + s_idx; - size_t offset_global = ((size_t)b_idx * (world_size * S_local) * H * D + (size_t)s_global * H * D + (size_t)h_idx * D + d_idx) * 2; - - if constexpr (VEC_SIZE == 8) { - uint4 q_vec = *(uint4*)(q_local + offset_main); - uint4 k_vec = *(uint4*)(k_local + offset_main); - uint4 q_partner = *(uint4*)(q_local + offset_partner); - uint4 k_partner = *(uint4*)(k_local + offset_partner); - uint4 cos_vec = *(uint4*)(cos_local + offset_cos_sin); - uint4 sin_vec = *(uint4*)(sin_local + offset_cos_sin); - - uint4 q_out, k_out; - compute_rope<8>(&q_vec, &k_vec, &q_partner, &k_partner, &cos_vec, &sin_vec, &q_out, &k_out, d_idx, half_D); - - if (mcast_q != 0) { - uint64_t addr_q = mcast_q + offset_global; - uint64_t addr_k = mcast_k + offset_global; - asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" :: "l"(addr_q), "r"(q_out.x), "r"(q_out.y), "r"(q_out.z), "r"(q_out.w) : "memory"); - asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" :: "l"(addr_k), "r"(k_out.x), "r"(k_out.y), "r"(k_out.z), "r"(k_out.w) : "memory"); - } else { - for (int r = 0; r < world_size; ++r) { - *(uint4*)(peer_ptrs_q[r] + offset_global) = q_out; - *(uint4*)(peer_ptrs_k[r] + offset_global) = k_out; - } - } - } - else if constexpr (VEC_SIZE == 4) { - uint2 q_vec = *(uint2*)(q_local + offset_main); - uint2 k_vec = *(uint2*)(k_local + offset_main); - uint2 q_partner = *(uint2*)(q_local + offset_partner); - uint2 k_partner = *(uint2*)(k_local + offset_partner); - uint2 cos_vec = *(uint2*)(cos_local + offset_cos_sin); - uint2 sin_vec = *(uint2*)(sin_local + offset_cos_sin); - - uint2 q_out, k_out; - compute_rope<4>(&q_vec, &k_vec, &q_partner, &k_partner, &cos_vec, &sin_vec, &q_out, &k_out, d_idx, half_D); - - if (mcast_q != 0) { - uint64_t addr_q = mcast_q + offset_global; - uint64_t addr_k = mcast_k + offset_global; - asm volatile("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1, %2};" :: "l"(addr_q), "r"(q_out.x), "r"(q_out.y) : "memory"); - asm volatile("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1, %2};" :: "l"(addr_k), "r"(k_out.x), "r"(k_out.y) : "memory"); - } else { - for (int r = 0; r < world_size; ++r) { - *(uint2*)(peer_ptrs_q[r] + offset_global) = q_out; - *(uint2*)(peer_ptrs_k[r] + offset_global) = k_out; - } - } - } - else if constexpr (VEC_SIZE == 2) { - uint32_t q_vec = *(uint32_t*)(q_local + offset_main); - uint32_t k_vec = *(uint32_t*)(k_local + offset_main); - uint32_t q_partner = *(uint32_t*)(q_local + offset_partner); - uint32_t k_partner = *(uint32_t*)(k_local + offset_partner); - uint32_t cos_vec = *(uint32_t*)(cos_local + offset_cos_sin); - uint32_t sin_vec = *(uint32_t*)(sin_local + offset_cos_sin); - - uint32_t q_out, k_out; - compute_rope<2>(&q_vec, &k_vec, &q_partner, &k_partner, &cos_vec, &sin_vec, &q_out, &k_out, d_idx, half_D); - - if (mcast_q != 0) { - uint64_t addr_q = mcast_q + offset_global; - uint64_t addr_k = mcast_k + offset_global; - asm volatile("multimem.st.relaxed.sys.global.f32 [%0], %1;" :: "l"(addr_q), "r"(q_out) : "memory"); - asm volatile("multimem.st.relaxed.sys.global.f32 [%0], %1;" :: "l"(addr_k), "r"(k_out) : "memory"); - } else { - for (int r = 0; r < world_size; ++r) { - *(uint32_t*)(peer_ptrs_q[r] + offset_global) = q_out; - *(uint32_t*)(peer_ptrs_k[r] + offset_global) = k_out; - } - } - } - else if constexpr (VEC_SIZE == 1) { - uint16_t q_vec = *(uint16_t*)(q_local + offset_main); - uint16_t k_vec = *(uint16_t*)(k_local + offset_main); - uint16_t q_partner = *(uint16_t*)(q_local + offset_partner); - uint16_t k_partner = *(uint16_t*)(k_local + offset_partner); - uint16_t cos_vec = *(uint16_t*)(cos_local + offset_cos_sin); - uint16_t sin_vec = *(uint16_t*)(sin_local + offset_cos_sin); - - uint16_t q_out, k_out; - compute_rope<1>(&q_vec, &k_vec, &q_partner, &k_partner, &cos_vec, &sin_vec, &q_out, &k_out, d_idx, half_D); - - for (int r = 0; r < world_size; ++r) { - *(uint16_t*)(peer_ptrs_q[r] + offset_global) = q_out; - *(uint16_t*)(peer_ptrs_k[r] + offset_global) = k_out; - } - } - } -} - -void launch_rope_multicast( - torch::Tensor q_local, - torch::Tensor k_local, - torch::Tensor cos_local, - torch::Tensor sin_local, - uint64_t mcast_q, - torch::Tensor peer_ptrs_q_tensor, - uint64_t mcast_k, - torch::Tensor peer_ptrs_k_tensor, - int world_size, - int rank -) { - int B = q_local.size(0); - int S_local = q_local.size(1); - int H = q_local.size(2); - int D = q_local.size(3); - - const uint64_t* peer_ptrs_q = (const uint64_t*)peer_ptrs_q_tensor.data_ptr(); - const uint64_t* peer_ptrs_k = (const uint64_t*)peer_ptrs_k_tensor.data_ptr(); - - int threads = 256; - int num_elements = B * S_local * H * D; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (D % 16 == 0) { - int blocks = (num_elements / 8 + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - rope_multicast_kernel<8><<>>( - (const uint8_t*)q_local.data_ptr(), (const uint8_t*)k_local.data_ptr(), - (const uint8_t*)cos_local.data_ptr(), (const uint8_t*)sin_local.data_ptr(), - mcast_q, peer_ptrs_q, mcast_k, peer_ptrs_k, - B, S_local, H, D, world_size, rank - ); - } else if (D % 8 == 0) { - int blocks = (num_elements / 4 + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - rope_multicast_kernel<4><<>>( - (const uint8_t*)q_local.data_ptr(), (const uint8_t*)k_local.data_ptr(), - (const uint8_t*)cos_local.data_ptr(), (const uint8_t*)sin_local.data_ptr(), - mcast_q, peer_ptrs_q, mcast_k, peer_ptrs_k, - B, S_local, H, D, world_size, rank - ); - } else if (D % 4 == 0) { - int blocks = (num_elements / 2 + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - rope_multicast_kernel<2><<>>( - (const uint8_t*)q_local.data_ptr(), (const uint8_t*)k_local.data_ptr(), - (const uint8_t*)cos_local.data_ptr(), (const uint8_t*)sin_local.data_ptr(), - mcast_q, peer_ptrs_q, mcast_k, peer_ptrs_k, - B, S_local, H, D, world_size, rank - ); - } else { - int blocks = (num_elements + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - rope_multicast_kernel<1><<>>( - (const uint8_t*)q_local.data_ptr(), (const uint8_t*)k_local.data_ptr(), - (const uint8_t*)cos_local.data_ptr(), (const uint8_t*)sin_local.data_ptr(), - mcast_q, peer_ptrs_q, mcast_k, peer_ptrs_k, - B, S_local, H, D, world_size, rank - ); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_rope_multicast", &launch_rope_multicast, "Fused RoPE and Multicast All-Gather"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("rope_multicast_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(shape, dtype, device): - key = (shape, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf_q = symm_mem.empty(shape, dtype=dtype, device=device) - hdl_q = symm_mem.rendezvous(buf_q, dist.group.WORLD) - - buf_k = symm_mem.empty(shape, dtype=dtype, device=device) - hdl_k = symm_mem.rendezvous(buf_k, dist.group.WORLD) - - ptrs_q = torch.tensor(hdl_q.buffer_ptrs, device=device, dtype=torch.int64) - ptrs_k = torch.tensor(hdl_k.buffer_ptrs, device=device, dtype=torch.int64) - - mcast_q = int(hdl_q.multicast_ptr) if hasattr(hdl_q, 'multicast_ptr') and hdl_q.multicast_ptr else 0 - mcast_k = int(hdl_k.multicast_ptr) if hasattr(hdl_k, 'multicast_ptr') and hdl_k.multicast_ptr else 0 - - res = (buf_q, hdl_q, ptrs_q, mcast_q, buf_k, hdl_k, ptrs_k, mcast_k) - _symm_cache[key] = res - return res - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - half_dim = x.shape[-1] // 2 - x1, x2 = x[..., :half_dim], x[..., half_dim:] - return torch.cat((-x2, x1), dim=-1) - -@torch.no_grad() -def solution( - q_local: torch.Tensor, - k_local: torch.Tensor, - cos_local: torch.Tensor, - sin_local: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - - if not dist.is_initialized(): - cos = cos_local.unsqueeze(2) - sin = sin_local.unsqueeze(2) - q_embed_local = (q_local * cos) + (rotate_half(q_local) * sin) - k_embed_local = (k_local * cos) + (rotate_half(k_local) * sin) - return q_embed_local, k_embed_local - - world_size = dist.get_world_size() - rank = dist.get_rank() - - if rank == 0: - _get_ext() - dist.barrier() - - B, S_local, H, D = q_local.shape - global_shape = (B, S_local * world_size, H, D) - - buf_q, hdl_q, ptrs_q, mcast_q, buf_k, hdl_k, ptrs_k, mcast_k = _get_symm_state(global_shape, q_local.dtype, q_local.device) - - # Isolate cross-GPU traffic dependencies via lightweight symmetric memory device barriers - hdl_q.barrier(channel=0) - hdl_k.barrier(channel=0) - - _get_ext().launch_rope_multicast( - q_local.contiguous(), k_local.contiguous(), - cos_local.contiguous(), sin_local.contiguous(), - mcast_q, ptrs_q, mcast_k, ptrs_k, - world_size, rank - ) - - hdl_q.barrier(channel=0) - hdl_k.barrier(channel=0) - - return buf_q.clone(), buf_k.clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/18_tp_rms_norm_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/18_tp_rms_norm_cuda.py deleted file mode 100755 index 893807b..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/18_tp_rms_norm_cuda.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -Strategy: -- **Kernel Fusion**: Combines local sum-of-squares computation, all-reduce communication, and scaling/normalization into a single optimized CUDA kernel. This prevents repeated round-trips to HBM for the massive hidden-states tensor. -- **Device-Side Communication via UVA**: Bypasses heavy NCCL collectives by using `torch.distributed._symmetric_memory` to allocate P2P-accessible buffers over NVLink. The kernel directly writes local sums and sequence flags into peer memory. -- **Compute-Communication Overlap**: Employs a persistent thread-block grid where each block handles independent rows. Thread 0 busy-waits on peer flags to fetch global sums, overlapping the fast intra-row communication directly with adjacent row computations and avoiding global barrier deadlocks. -- **Hardware Barrier Safety**: Utilizes `hdl.barrier(channel=0)`—a fast device-side hardware stream sync—before kernel launch to safely reuse persistent sync buffers across calls without risking race conditions or Python-side serialization delays. -- **Vectorized Memory Access**: Checks alignment dynamically to employ `uint4` (128-bit) vectorized loads and stores when processing `bfloat16` data, effectively doubling memory throughput. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void fused_rmsnorm_kernel( - const __nv_bfloat16* __restrict__ input, - const __nv_bfloat16* __restrict__ weight, - __nv_bfloat16* __restrict__ output, - const uint64_t* __restrict__ peer_sums_ptrs, - const uint64_t* __restrict__ peer_flags_ptrs, - int M, - int local_hidden_size, - float epsilon, - int world_size, - int rank, - uint64_t seq, - int global_hidden_size, - bool aligned -) { - for (int row = blockIdx.x; row < M; row += gridDim.x) { - const __nv_bfloat16* row_in = input + row * local_hidden_size; - __nv_bfloat16* row_out = output + row * local_hidden_size; - - float local_sum = 0.0f; - int tail_start = 0; - - if (aligned) { - int col = threadIdx.x * 8; - int limit = local_hidden_size; // guaranteed multiple of 8 if aligned - - for (; col < limit; col += blockDim.x * 8) { - uint4 vals = *(reinterpret_cast(&row_in[col])); - __nv_bfloat162* halfs = reinterpret_cast<__nv_bfloat162*>(&vals); - - #pragma unroll - for (int i = 0; i < 4; i++) { - float2 f2 = __bfloat1622float2(halfs[i]); - local_sum += f2.x * f2.x; - local_sum += f2.y * f2.y; - } - } - tail_start = limit; - } - - for (int c = tail_start + threadIdx.x; c < local_hidden_size; c += blockDim.x) { - float val = __bfloat162float(row_in[c]); - local_sum += val * val; - } - - // Warp block reduce local_sum - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); - } - - __shared__ float s_data[32]; - int lane = threadIdx.x % 32; - int warp = threadIdx.x / 32; - - if (lane == 0) { - s_data[warp] = local_sum; - } - __syncthreads(); - - if (warp == 0) { - float val = (lane < (blockDim.x / 32)) ? s_data[lane] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (lane == 0) { - float* my_sums = reinterpret_cast(peer_sums_ptrs[rank]); - uint64_t* my_flags = reinterpret_cast(peer_flags_ptrs[rank]); - - my_sums[row] = val; - asm volatile("fence.acq_rel.sys;" ::: "memory"); - asm volatile("st.global.release.sys.b64 [%0], %1;" :: "l"(&my_flags[row]), "l"(seq) : "memory"); - } - } - __syncthreads(); - - // Cross-GPU sync and summation - __shared__ float s_global_sum; - if (threadIdx.x == 0) { - float g_sum = 0.0f; - for (int p = 0; p < world_size; ++p) { - float* peer_sums = reinterpret_cast(peer_sums_ptrs[p]); - uint64_t* peer_flags = reinterpret_cast(peer_flags_ptrs[p]); - - uint64_t flag_val = 0; - do { - asm volatile("ld.global.acquire.sys.b64 %0, [%1];" : "=l"(flag_val) : "l"(&peer_flags[row]) : "memory"); - } while (flag_val != seq); - - g_sum += peer_sums[row]; - } - s_global_sum = g_sum; - } - __syncthreads(); - - float global_sum = s_global_sum; - float variance = global_sum / static_cast(global_hidden_size); - float rsqrt_var = rsqrtf(variance + epsilon); - - // Scale and Output - if (aligned) { - int col = threadIdx.x * 8; - int limit = local_hidden_size; - - for (; col < limit; col += blockDim.x * 8) { - uint4 in_vals = *(reinterpret_cast(&row_in[col])); - uint4 w_vals = *(reinterpret_cast(&weight[col])); - - __nv_bfloat162* in_halfs = reinterpret_cast<__nv_bfloat162*>(&in_vals); - __nv_bfloat162* w_halfs = reinterpret_cast<__nv_bfloat162*>(&w_vals); - - uint4 out_vals; - __nv_bfloat162* out_halfs = reinterpret_cast<__nv_bfloat162*>(&out_vals); - - #pragma unroll - for (int i = 0; i < 4; i++) { - float2 f_in = __bfloat1622float2(in_halfs[i]); - float2 f_w = __bfloat1622float2(w_halfs[i]); - - float2 f_out; - f_out.x = f_in.x * rsqrt_var * f_w.x; - f_out.y = f_in.y * rsqrt_var * f_w.y; - out_halfs[i] = __float22bfloat162_rn(f_out); - } - *(reinterpret_cast(&row_out[col])) = out_vals; - } - } else { - for (int c = threadIdx.x; c < local_hidden_size; c += blockDim.x) { - float val = __bfloat162float(row_in[c]); - float w = __bfloat162float(weight[c]); - row_out[c] = __float2bfloat16(val * rsqrt_var * w); - } - } - } -} - -void launch_fused_rmsnorm( - torch::Tensor input, - torch::Tensor weight, - torch::Tensor output, - torch::Tensor peer_sums_ptrs, - torch::Tensor peer_flags_ptrs, - float epsilon, - int world_size, - int rank, - int64_t seq -) { - int M = input.numel() / input.size(-1); - int local_hidden_size = input.size(-1); - int global_hidden_size = local_hidden_size * world_size; - - bool aligned = (local_hidden_size % 8 == 0) && - ((uintptr_t)input.data_ptr() % 16 == 0) && - ((uintptr_t)weight.data_ptr() % 16 == 0) && - ((uintptr_t)output.data_ptr() % 16 == 0); - - int threads = 256; - // Cap blocks to guarantee thread co-residency, thus preventing deadlocks on persistent flags - int blocks = M < 128 ? M : 128; - if (blocks == 0) blocks = 1; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - fused_rmsnorm_kernel<<>>( - reinterpret_cast(input.data_ptr()), - reinterpret_cast(weight.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), - reinterpret_cast(peer_sums_ptrs.data_ptr()), - reinterpret_cast(peer_flags_ptrs.data_ptr()), - M, - local_hidden_size, - epsilon, - world_size, - rank, - static_cast(seq), - global_hidden_size, - aligned - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_fused_rmsnorm", &launch_fused_rmsnorm, "Fused RMSNorm with P2P allreduce"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_rmsnorm_ext", CUDA_SRC) - return _ext - - -class SymmMemManager: - def __init__(self): - self.cache = {} - - def get_buffers(self, M, device): - if M in self.cache: - return self.cache[M] - - # Calculate offsets considering 8-byte alignment bounds for uint64 flags - sums_bytes = (M * 4 + 7) & ~7 - flags_bytes = M * 8 - total_bytes = sums_bytes + flags_bytes - - buf = symm_mem.empty(total_bytes, dtype=torch.int8, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - sums_ptrs = [] - flags_ptrs = [] - for p in hdl.buffer_ptrs: - sums_ptrs.append(p) - flags_ptrs.append(p + sums_bytes) - - sums_ptrs_t = torch.tensor(sums_ptrs, dtype=torch.int64, device=device) - flags_ptrs_t = torch.tensor(flags_ptrs, dtype=torch.int64, device=device) - - self.cache[M] = (buf, hdl, sums_ptrs_t, flags_ptrs_t) - return self.cache[M] - -_symm_manager = SymmMemManager() -_seq_counter = 1 - - -@torch.no_grad() -def solution(local_hidden_states: torch.Tensor, local_weight: torch.Tensor, variance_epsilon: float) -> torch.Tensor: - input_dtype = local_hidden_states.dtype - - # Pure PyTorch fallback if disconnected or handling alternate dtypes - if input_dtype != torch.bfloat16 or not dist.is_initialized() or dist.get_world_size() == 1: - fp32_states = local_hidden_states.to(torch.float32) - local_sum_squares = fp32_states.pow(2).sum(dim=-1, keepdim=True) - - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - if world_size > 1: - dist.all_reduce(local_sum_squares, op=dist.ReduceOp.SUM) - - global_hidden_size = local_hidden_states.shape[-1] * world_size - variance = local_sum_squares / global_hidden_size - out = local_hidden_states * torch.rsqrt(variance + variance_epsilon) - return local_weight * out.to(input_dtype) - - global _seq_counter - if _ext is None: - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - _get_ext() - - input_tensor = local_hidden_states.contiguous() - weight_tensor = local_weight.to(torch.bfloat16).contiguous() - output_tensor = torch.empty_like(input_tensor) - - M = input_tensor.numel() // input_tensor.size(-1) - rank = dist.get_rank() - world_size = dist.get_world_size() - - seq = _seq_counter - _seq_counter += 1 - - _, hdl, sums_ptrs_t, flags_ptrs_t = _symm_manager.get_buffers(M, input_tensor.device) - - # Device-side stream barrier enforcing order against previously resident uses of this persistent buffer - hdl.barrier(channel=0) - - _get_ext().launch_fused_rmsnorm( - input_tensor, - weight_tensor, - output_tensor, - sums_ptrs_t, - flags_ptrs_t, - float(variance_epsilon), - world_size, - rank, - seq - ) - - return output_tensor \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/19_blocked_fp8_quantize_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/19_blocked_fp8_quantize_cuda.py deleted file mode 100755 index f6cc0e8..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/19_blocked_fp8_quantize_cuda.py +++ /dev/null @@ -1,273 +0,0 @@ -""" -Strategy: -1. **Fused Compute & Push**: We replace the separate Triton quantization and NCCL all-gather with a single custom CUDA kernel. Each GPU locally quantizes its BF16 data to FP8, computes the FP32 scales, and directly pushes the results to all peers' memory via NVSwitch and UVA pointers. -2. **Zero-Copy All-Gather**: Symmetric memory buffers are allocated for the *entire* global output tensors. Each rank pushes its computed chunks directly to its respective non-overlapping slice (`rank * local_numel`) in every peer's buffer. This completely eliminates write conflicts and the need for a secondary communication pass. -3. **Compute–Communication Overlap & Vectorization**: The kernel processes data in 512-element tiles via shared memory. Packed 128-bit (`uint4`) stores are issued over NVLink to maximize cross-device bandwidth. The hardware overlaps these stores with the arithmetic (abs max reduction, scaling) of subsequent loops. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Tuple -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -__global__ void quantize_and_push_kernel( - const __nv_bfloat16* __restrict__ x, - const long long* __restrict__ y_ptrs, - const long long* __restrict__ s_ptrs, - int64_t numel, - int64_t block_size, - int64_t local_y_offset, - int64_t local_s_offset, - int world_size -) { - int64_t chunk_idx = blockIdx.x; - int64_t num_chunks = numel / block_size; - int tile_size = 512; - - // Shared memory for quantizing a tile and block reductions - __shared__ uint8_t shared_fp8[512]; - __shared__ float shared_max[32]; - - for (int64_t c = chunk_idx; c < num_chunks; c += gridDim.x) { - int64_t base_idx = c * block_size; - - // 1. Find max abs for the block - float local_max = 0.0f; - for (int64_t i = threadIdx.x; i < block_size; i += blockDim.x) { - float val = __bfloat162float(x[base_idx + i]); - local_max = fmaxf(local_max, fabsf(val)); - } - - // Warp reduce max - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, offset)); - } - - int lane = threadIdx.x % 32; - int wid = threadIdx.x / 32; - - if (lane == 0) shared_max[wid] = local_max; - __syncthreads(); - - if (wid == 0) { - float val = (lane < (blockDim.x + 31) / 32) ? shared_max[lane] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); - } - if (lane == 0) shared_max[0] = val; - } - __syncthreads(); - - float max_val = shared_max[0]; - float scale = max_val / 448.0f; - if (scale == 0.0f) scale = 1.0f; // Prevent division by zero - float inv_scale = 1.0f / scale; - - // 2. Quantize and Push via UVA in tiles - for (int64_t tile_start = 0; tile_start < block_size; tile_start += tile_size) { - int64_t current_tile_size = min((int64_t)tile_size, block_size - tile_start); - - // Quantize tile into shared memory - for (int64_t i = threadIdx.x; i < current_tile_size; i += blockDim.x) { - float val = __bfloat162float(x[base_idx + tile_start + i]); - float q_val = val * inv_scale; - __nv_fp8_e4m3 fp8_val(q_val); - shared_fp8[i] = *(reinterpret_cast(&fp8_val)); - } - __syncthreads(); - - // Check if 16-byte aligned for uint4 vectorized stores - bool can_vectorize = ((local_y_offset + base_idx + tile_start) % 16 == 0); - - if (can_vectorize) { - int num_uint4 = current_tile_size / 16; - for (int i = threadIdx.x; i < num_uint4; i += blockDim.x) { - uint4 packed = reinterpret_cast(shared_fp8)[i]; - int64_t global_offset = local_y_offset + base_idx + tile_start + i * 16; - - #pragma unroll - for (int r = 0; r < world_size; ++r) { - uint8_t* y_ptr = reinterpret_cast(y_ptrs[r]); - reinterpret_cast(y_ptr + global_offset)[0] = packed; - } - } - - // Handle remainder if tile size is not a multiple of 16 - int remainder_start = num_uint4 * 16; - if (current_tile_size > remainder_start) { - for (int i = remainder_start + threadIdx.x; i < current_tile_size; i += blockDim.x) { - uint8_t val = shared_fp8[i]; - int64_t global_offset = local_y_offset + base_idx + tile_start + i; - - #pragma unroll - for (int r = 0; r < world_size; ++r) { - uint8_t* y_ptr = reinterpret_cast(y_ptrs[r]); - y_ptr[global_offset] = val; - } - } - } - } else { - // Scalar fallback if unaligned - for (int64_t i = threadIdx.x; i < current_tile_size; i += blockDim.x) { - uint8_t val = shared_fp8[i]; - int64_t global_offset = local_y_offset + base_idx + tile_start + i; - - #pragma unroll - for (int r = 0; r < world_size; ++r) { - uint8_t* y_ptr = reinterpret_cast(y_ptrs[r]); - y_ptr[global_offset] = val; - } - } - } - __syncthreads(); - } - - // 3. Write scale to symmetric memory - if (threadIdx.x == 0) { - int64_t global_s_idx = local_s_offset + c; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - float* s_ptr = reinterpret_cast(s_ptrs[r]); - s_ptr[global_s_idx] = scale; - } - } - } -} - -void launch_quantize_and_push( - torch::Tensor x, - torch::Tensor y_ptrs_tensor, - torch::Tensor s_ptrs_tensor, - int64_t block_size, - int64_t local_y_offset, - int64_t local_s_offset, - int world_size -) { - int64_t numel = x.numel(); - int threads = 256; - // Launch enough blocks for full SM occupancy - int blocks = std::min((int)(numel / block_size), 1024); - if (blocks == 0) blocks = 1; - - const __nv_bfloat16* d_x = reinterpret_cast(x.data_ptr()); - const long long* d_y_ptrs = reinterpret_cast(y_ptrs_tensor.data_ptr()); - const long long* d_s_ptrs = reinterpret_cast(s_ptrs_tensor.data_ptr()); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - quantize_and_push_kernel<<>>( - d_x, d_y_ptrs, d_s_ptrs, numel, block_size, local_y_offset, local_s_offset, world_size - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_quantize_and_push", &launch_quantize_and_push, "Fused block quantize to FP8 and NVLink push to peers"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_fp8_quant_push", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(shape_y, shape_s, device): - key = (shape_y, shape_s, device) - if key in _symm_cache: - return _symm_cache[key] - - # Allocate full global tensors symmetrically - buf_y = symm_mem.empty(shape_y, dtype=torch.uint8, device=device) - hdl_y = symm_mem.rendezvous(buf_y, dist.group.WORLD) - - buf_s = symm_mem.empty(shape_s, dtype=torch.float32, device=device) - hdl_s = symm_mem.rendezvous(buf_s, dist.group.WORLD) - - # Track UVA pointers into tensors to feed the kernel array lookup - y_ptrs_tensor = torch.tensor(hdl_y.buffer_ptrs, dtype=torch.int64, device=device) - s_ptrs_tensor = torch.tensor(hdl_s.buffer_ptrs, dtype=torch.int64, device=device) - - res = (buf_y, y_ptrs_tensor, buf_s, s_ptrs_tensor) - _symm_cache[key] = res - return res - -@torch.no_grad() -def solution(local_tensor: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Fused Multi-GPU Block FP8 Quantization and All-Gather. - Executes a custom CUDA kernel that locally calculates block scales, quantizes to FP8, - and leverages UVA symmetric memory to push data over NVLink to all peers in one pass. - """ - assert local_tensor.is_contiguous(), "Input tensor must be contiguous" - assert local_tensor.size(-1) % block_size == 0, "Last dimension must be divisible by block_size" - - if local_tensor.dtype != torch.bfloat16: - local_tensor = local_tensor.to(torch.bfloat16) - - device = local_tensor.device - - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - world_size = 1 - rank = 0 - - # Determine the post-concatenation global shapes - local_shape = list(local_tensor.shape) - global_shape_y = [world_size * local_shape[0]] + local_shape[1:] if len(local_shape) > 0 else [world_size] - - local_shape_s = list(local_shape) - local_shape_s[-1] = local_shape_s[-1] // block_size - global_shape_s = [world_size * local_shape_s[0]] + local_shape_s[1:] if len(local_shape_s) > 0 else [world_size] - - if world_size > 1: - buf_y, y_ptrs_tensor, buf_s, s_ptrs_tensor = _get_symm_state( - tuple(global_shape_y), tuple(global_shape_s), device - ) - else: - buf_y = torch.empty(tuple(global_shape_y), dtype=torch.uint8, device=device) - buf_s = torch.empty(tuple(global_shape_s), dtype=torch.float32, device=device) - y_ptrs_tensor = torch.tensor([buf_y.data_ptr()], dtype=torch.int64, device=device) - s_ptrs_tensor = torch.tensor([buf_s.data_ptr()], dtype=torch.int64, device=device) - - local_numel = local_tensor.numel() - local_s_numel = local_numel // block_size - - # Since concatenation happens along dim=0, global flattened offsets scale perfectly with rank - local_y_offset = rank * local_numel - local_s_offset = rank * local_s_numel - - # Launch purely fused compute and NVLink scatter pass - _get_ext().launch_quantize_and_push( - local_tensor, - y_ptrs_tensor, - s_ptrs_tensor, - block_size, - local_y_offset, - local_s_offset, - world_size - ) - - if world_size > 1: - # Guarantee writes emitted from the stream are physically visible to all peers before read - torch.cuda.current_stream().synchronize() - dist.barrier() - - y_global = buf_y.view(torch.float8_e4m3fn) - return y_global, buf_s \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/1_allreduce_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/1_allreduce_cuda.py deleted file mode 100755 index ae193e1..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/1_allreduce_cuda.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -Strategy: -1. We use `torch.distributed._symmetric_memory` to allocate identical device buffers across ranks. -2. The buffer is safely padded to a multiple of 16 bytes (8 bf16 elements). This ensures we can unconditionally use the Hopper NVSwitch `multimem` (hardware-accelerated multicast and in-switch reduction) for all BF16 inputs, avoiding slow fallback paths. -3. The custom CUDA multimem kernel perfectly distributes the reduction workload across all GPUs. Each GPU loads a disjoint slice of the mapped multicast window using `multimem.ld_reduce.v4.bf16x2`, letting the NVSwitch execute the actual reductions automatically. The result is written back via `multimem.st.v4.f32` (multicast store). -4. A device-side grid barrier (`blockwise_barrier_acq_rel`) synchronizes execution globally, ensuring no rank exits the kernel or overwrites the symmetric buffer before the collective completely finishes. -5. A custom CUDA fallback with native template dispatch handles arbitrary numeric types (fp32, fp16, int32) using direct P2P memory access. This completely replaces standard PyTorch NCCL collectives on the hot path. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Blockwise barrier across symmetric signal pads -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void send_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.relaxed.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.sys.relaxed.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 1u); -} - -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 1u); -} - -__device__ void blockwise_barrier_relaxed( - const uint64_t* __restrict__ signal_pad_ptrs, - uint64_t block_id, int rank, int world_size) -{ - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast( - remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast( - local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_relaxed(send_addr); - wait_signal_relaxed(wait_addr); -} - -__device__ void blockwise_barrier_acq_rel( - const uint64_t* __restrict__ signal_pad_ptrs, - uint64_t block_id, int rank, int world_size) -{ - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast( - remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast( - local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_acq_rel(send_addr); - wait_signal_acq_rel(wait_addr); -} - -// --------------------------------------------------------------------------- -// NVSwitch Multimem operations -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void multimem_ld_reduce_bf16x4( - const uint64_t* addr, uint32_t& r0, uint32_t& r1, uint32_t& r2, uint32_t& r3) -{ - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "l"(addr) : "memory"); -} - -__device__ __forceinline__ void multimem_st_bf16x4( - const uint64_t* addr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) -{ - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : : "l"(addr), "r"(x), "r"(y), "r"(z), "r"(w) : "memory"); -} - -__global__ void multimem_allreduce_bf16_kernel( - uint64_t multicast_base, - const uint64_t* __restrict__ signal_pad_ptrs, - int64_t numel_128, - int world_size, - int rank, - int block_stride) -{ - const uint64_t block_id = static_cast(blockIdx.x); - blockwise_barrier_relaxed(signal_pad_ptrs, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = (numel_128 + world_size - 1) / world_size; - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = (int64_t)block_id * block_stride; - block_start < numel_per_rank; - block_start += (int64_t)num_programs * block_stride) - { - const int64_t offset = block_start + tid; - if (offset >= numel_per_rank) continue; - - const int64_t idx = rank * numel_per_rank + offset; - if (idx < numel_128) { - uint64_t* ptrs = reinterpret_cast(multicast_base) + idx * 2; - uint32_t x, y, z, w; - multimem_ld_reduce_bf16x4(ptrs, x, y, z, w); - multimem_st_bf16x4(ptrs, x, y, z, w); - } - } - - __syncthreads(); - blockwise_barrier_acq_rel(signal_pad_ptrs, block_id, rank, world_size); -} - -// --------------------------------------------------------------------------- -// Scalar fallback for any dtype -// --------------------------------------------------------------------------- - -template -__global__ void allreduce_fallback_kernel( - const long long* __restrict__ ptrs, - scalar_t* __restrict__ out, - int world_size, int64_t n) -{ - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - double sum = 0.0; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const scalar_t* src = (const scalar_t*)ptrs[r]; - sum += static_cast(src[idx]); - } - out[idx] = static_cast(sum); - } -} - -void launch_multimem_allreduce_bf16( - uint64_t multicast_ptr, - torch::Tensor signal_pad_ptrs_tensor, - int64_t numel_128, - int world_size, - int rank, - int num_blocks, - int block_size, - int block_stride) -{ - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - multimem_allreduce_bf16_kernel<<>>( - multicast_ptr, d_signal, numel_128, world_size, rank, block_stride); -} - -void launch_allreduce_fallback( - torch::Tensor ptrs_tensor, - torch::Tensor out, - int64_t n) -{ - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - int threads = 512; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, out.scalar_type(), "allreduce_fallback", [&] { - allreduce_fallback_kernel<<>>( - d_ptrs, out.data_ptr(), world_size, n); - }); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_allreduce_bf16", &launch_multimem_allreduce_bf16); - m.def("launch_allreduce_fallback", &launch_allreduce_fallback); -} -''' - -_ext = None -_ext_compiled = False - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("allreduce_cuda_bf16_h100_ext", CUDA_SRC) - return _ext - -def _compile_ext(): - global _ext_compiled - if not _ext_compiled: - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - _get_ext() - _ext_compiled = True - -def _multimem_launch_config(numel: int, world_size: int) -> tuple[int, int, int]: - numel_per_thread = 16 // 2 # 8 bf16 elements per 128-bit chunk - num_threads = (numel // numel_per_thread + world_size - 1) // world_size - - block_size = 32 # Minimum 32 threads required for the kernel's blockwise barrier to support world_size <= 32 - while block_size < num_threads and block_size < 1024: - block_size *= 2 - - if num_threads <= 1024: - num_blocks = 1 - else: - num_blocks = min((num_threads + 1024 - 1) // 1024, 4) - - return num_blocks, block_size, block_size - -_resource_cache = {} - -def _get_resources(numel: int, dtype: torch.dtype, device: torch.device): - key = (numel, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - pad_numel = (numel + 7) // 8 * 8 - buf = symm_mem.empty(pad_numel, device=device, dtype=dtype) - buf.zero_() - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - out = torch.empty(numel, device=device, dtype=dtype) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, out, ptrs_tensor, pad_numel) - _resource_cache[key] = res - return res - -@torch.no_grad() -def solution(tensor: torch.Tensor) -> torch.Tensor: - if not dist.is_initialized(): - return tensor.clone() - - input_tensor = tensor.contiguous() - n = input_tensor.numel() - if n == 0: - return input_tensor.clone() - - dtype = input_tensor.dtype - - if not _ext_compiled: - _compile_ext() - - buf, hdl, out, ptrs_tensor, pad_numel = _get_resources(n, dtype, input_tensor.device) - - buf[:n].copy_(input_tensor.view(-1)) - if pad_numel > n: - buf[n:].zero_() - - # Stream-ordered device barrier ensures safe visibility of all copies to peers before kernel starts - hdl.barrier(channel=0) - - if dtype == torch.bfloat16: - numel_128 = pad_numel // 8 - num_blocks, block_size, block_stride = _multimem_launch_config(pad_numel, hdl.world_size) - - multicast_ptr = int(hdl.multicast_ptr) - signal_dev = hdl.signal_pad_ptrs_dev - - _get_ext().launch_multimem_allreduce_bf16( - multicast_ptr, signal_dev, numel_128, hdl.world_size, - hdl.rank, num_blocks, block_size, block_stride - ) - # The blockwise_barrier in the kernel natively guarantees completion, no extra barrier required - return buf[:n].view(input_tensor.shape).clone() - else: - _get_ext().launch_allreduce_fallback(ptrs_tensor, out, n) - # Add post-kernel barrier to avoid immediate next-iteration overwrites of `buf` - hdl.barrier(channel=0) - return out.view(input_tensor.shape) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/20_blocked_fp8_dequantize_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/20_blocked_fp8_dequantize_cuda.py deleted file mode 100755 index dd1493c..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/20_blocked_fp8_dequantize_cuda.py +++ /dev/null @@ -1,322 +0,0 @@ -import math -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension -import triton -import triton.language as tl - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -// ------------------------------------------------------------------------- -// 1-byte element kernel (FP8, int8, uint8) -// ------------------------------------------------------------------------- -template -__global__ void dequant_alltoall_kernel_1byte( - const uint8_t* __restrict__ y_ptr, - const float* __restrict__ s_ptr, - const int64_t* __restrict__ peer_ptrs, - int64_t chunk_numel, - int blocks_per_chunk, - int block_size, - int rank, - int world_size -) { - int64_t block_idx = blockIdx.x; - int dest_rank = block_idx / blocks_per_chunk; - int chunk_block_idx = block_idx % blocks_per_chunk; - - float* dest_ptr = reinterpret_cast(peer_ptrs[dest_rank]); - - // We write to the rank-th chunk of the destination buffer - int64_t dest_offset = (int64_t)rank * chunk_numel + (int64_t)chunk_block_idx * block_size; - int64_t src_offset = block_idx * block_size; - - float scale = s_ptr[block_idx]; - - int tid = threadIdx.x; - int stride = blockDim.x; - - if (block_size % 16 == 0) { - int num_vec = block_size / 16; - for (int i = tid; i < num_vec; i += stride) { - uint4 packed_y = *(reinterpret_cast(y_ptr + src_offset + i * 16)); - - uint8_t bytes[16]; - *(reinterpret_cast(bytes)) = packed_y; - - float out[16]; - #pragma unroll - for (int j = 0; j < 16; ++j) { - T val; - memcpy(&val, &bytes[j], 1); - out[j] = static_cast(val) * scale; - } - - float4* dest_float4 = reinterpret_cast(dest_ptr + dest_offset + i * 16); - dest_float4[0] = make_float4(out[0], out[1], out[2], out[3]); - dest_float4[1] = make_float4(out[4], out[5], out[6], out[7]); - dest_float4[2] = make_float4(out[8], out[9], out[10], out[11]); - dest_float4[3] = make_float4(out[12], out[13], out[14], out[15]); - } - } else { - for (int i = tid; i < block_size; i += stride) { - T val; - memcpy(&val, y_ptr + src_offset + i, 1); - dest_ptr[dest_offset + i] = static_cast(val) * scale; - } - } -} - -// ------------------------------------------------------------------------- -// 2-byte element kernel (BFloat16, Float16 fallbacks) -// ------------------------------------------------------------------------- -template -__global__ void dequant_alltoall_kernel_2byte( - const uint16_t* __restrict__ y_ptr, - const float* __restrict__ s_ptr, - const int64_t* __restrict__ peer_ptrs, - int64_t chunk_numel, - int blocks_per_chunk, - int block_size, - int rank, - int world_size -) { - int64_t block_idx = blockIdx.x; - int dest_rank = block_idx / blocks_per_chunk; - int chunk_block_idx = block_idx % blocks_per_chunk; - - float* dest_ptr = reinterpret_cast(peer_ptrs[dest_rank]); - int64_t dest_offset = (int64_t)rank * chunk_numel + (int64_t)chunk_block_idx * block_size; - int64_t src_offset = block_idx * block_size; - - float scale = s_ptr[block_idx]; - - int tid = threadIdx.x; - int stride = blockDim.x; - - if (block_size % 8 == 0) { - int num_vec = block_size / 8; - for (int i = tid; i < num_vec; i += stride) { - uint4 packed_y = *(reinterpret_cast(y_ptr + src_offset + i * 8)); - - uint16_t bytes[8]; - *(reinterpret_cast(bytes)) = packed_y; - - float out[8]; - #pragma unroll - for (int j = 0; j < 8; ++j) { - T val; - memcpy(&val, &bytes[j], 2); - out[j] = static_cast(val) * scale; - } - - float4* dest_float4 = reinterpret_cast(dest_ptr + dest_offset + i * 8); - dest_float4[0] = make_float4(out[0], out[1], out[2], out[3]); - dest_float4[1] = make_float4(out[4], out[5], out[6], out[7]); - } - } else { - for (int i = tid; i < block_size; i += stride) { - T val; - memcpy(&val, y_ptr + src_offset + i, 2); - dest_ptr[dest_offset + i] = static_cast(val) * scale; - } - } -} - -// ------------------------------------------------------------------------- -// 4-byte element kernel (Float32 fallback) -// ------------------------------------------------------------------------- -template -__global__ void dequant_alltoall_kernel_4byte( - const float* __restrict__ y_ptr, - const float* __restrict__ s_ptr, - const int64_t* __restrict__ peer_ptrs, - int64_t chunk_numel, - int blocks_per_chunk, - int block_size, - int rank, - int world_size -) { - int64_t block_idx = blockIdx.x; - int dest_rank = block_idx / blocks_per_chunk; - int chunk_block_idx = block_idx % blocks_per_chunk; - - float* dest_ptr = reinterpret_cast(peer_ptrs[dest_rank]); - int64_t dest_offset = (int64_t)rank * chunk_numel + (int64_t)chunk_block_idx * block_size; - int64_t src_offset = block_idx * block_size; - - float scale = s_ptr[block_idx]; - - int tid = threadIdx.x; - int stride = blockDim.x; - - if (block_size % 4 == 0) { - int num_vec = block_size / 4; - for (int i = tid; i < num_vec; i += stride) { - float4 packed_y = *(reinterpret_cast(y_ptr + src_offset + i * 4)); - - float4 out; - out.x = packed_y.x * scale; - out.y = packed_y.y * scale; - out.z = packed_y.z * scale; - out.w = packed_y.w * scale; - - float4* dest_float4 = reinterpret_cast(dest_ptr + dest_offset + i * 4); - dest_float4[0] = out; - } - } else { - for (int i = tid; i < block_size; i += stride) { - dest_ptr[dest_offset + i] = y_ptr[src_offset + i] * scale; - } - } -} - -// ------------------------------------------------------------------------- -// Launcher Dispatch -// ------------------------------------------------------------------------- -void launch_dequant_alltoall( - torch::Tensor y, - torch::Tensor s, - torch::Tensor peer_ptrs_tensor, - int64_t chunk_numel, - int blocks_per_chunk, - int block_size, - int rank, - int world_size -) { - auto s_ptr = s.data_ptr(); - auto peer_ptrs = peer_ptrs_tensor.data_ptr(); - - int num_blocks = world_size * blocks_per_chunk; - - int threads = 128; - if (block_size % 16 == 0) { - int num_vec = block_size / 16; - if (num_vec <= 32) threads = 32; - else if (num_vec <= 64) threads = 64; - else if (num_vec <= 128) threads = 128; - else threads = 256; - } else { - if (block_size <= 32) threads = 32; - else if (block_size <= 64) threads = 64; - else if (block_size <= 128) threads = 128; - else threads = 256; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - auto scalar_type = y.scalar_type(); - - // Dispatch based on scalar type, natively supporting FP8 - if (scalar_type == torch::kFloat8_e4m3fn) { - auto y_ptr = reinterpret_cast(y.data_ptr()); - dequant_alltoall_kernel_1byte<<>>( - y_ptr, s_ptr, peer_ptrs, chunk_numel, blocks_per_chunk, block_size, rank, world_size); - } else if (scalar_type == torch::kFloat8_e5m2) { - auto y_ptr = reinterpret_cast(y.data_ptr()); - dequant_alltoall_kernel_1byte<<>>( - y_ptr, s_ptr, peer_ptrs, chunk_numel, blocks_per_chunk, block_size, rank, world_size); - } else if (scalar_type == torch::kInt8) { - auto y_ptr = reinterpret_cast(y.data_ptr()); - dequant_alltoall_kernel_1byte<<>>( - y_ptr, s_ptr, peer_ptrs, chunk_numel, blocks_per_chunk, block_size, rank, world_size); - } else if (scalar_type == torch::kUInt8) { - auto y_ptr = reinterpret_cast(y.data_ptr()); - dequant_alltoall_kernel_1byte<<>>( - y_ptr, s_ptr, peer_ptrs, chunk_numel, blocks_per_chunk, block_size, rank, world_size); - } else if (scalar_type == torch::kBFloat16) { - auto y_ptr = reinterpret_cast(y.data_ptr()); - dequant_alltoall_kernel_2byte<<>>( - y_ptr, s_ptr, peer_ptrs, chunk_numel, blocks_per_chunk, block_size, rank, world_size); - } else if (scalar_type == torch::kFloat16) { - auto y_ptr = reinterpret_cast(y.data_ptr()); - dequant_alltoall_kernel_2byte<<>>( - y_ptr, s_ptr, peer_ptrs, chunk_numel, blocks_per_chunk, block_size, rank, world_size); - } else if (scalar_type == torch::kFloat32) { - auto y_ptr = reinterpret_cast(y.data_ptr()); - dequant_alltoall_kernel_4byte<<>>( - y_ptr, s_ptr, peer_ptrs, chunk_numel, blocks_per_chunk, block_size, rank, world_size); - } else { - TORCH_CHECK(false, "Unsupported dtype for FP8 dequant_alltoall."); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_dequant_alltoall", &launch_dequant_alltoall, "Fused FP8 dequantize and alltoall via symmetric memory"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_dequant_alltoall_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(shape, dtype, device): - global _symm_cache - key = (tuple(shape), dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - n = math.prod(shape) - buf = symm_mem.empty(n, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - _symm_cache[key] = (buf, hdl, ptrs_tensor) - return _symm_cache[key] - -@torch.no_grad() -def solution( - local_y: torch.Tensor, - local_s: torch.Tensor, - block_size: int = 128, -) -> torch.Tensor: - if local_y.numel() == 0: - return torch.empty_like(local_y, dtype=torch.float32) - - assert dist.is_initialized(), "torch.distributed must be initialized" - world_size = dist.get_world_size() - rank = dist.get_rank() - - # Determine chunking invariants mapped directly into the CUDA dispatch logic - chunk_numel = local_y.numel() // world_size - blocks_per_chunk = chunk_numel // block_size - - # Isolate compilation serialization - if rank == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - - # Set up symmetric memory arrays returning expected Float32 dtype - output_shape = local_y.shape - out_dtype = torch.float32 - buf, hdl, ptrs_tensor = _get_symm_state(output_shape, out_dtype, local_y.device) - - # Secure the symmetric buffer is safe to write remotely - hdl.barrier(channel=0) - - # Launch optimized UVA-mediated direct writes - ext.launch_dequant_alltoall( - local_y, local_s, ptrs_tensor, - chunk_numel, blocks_per_chunk, block_size, - rank, world_size - ) - - # Assure completion of all peer-issued NVLink transfers mapping into rank local buf memory - hdl.barrier(channel=0) - - # Return a new tensor ensuring isolation from the subsequent cache lifecycle updates - out = buf.view(output_shape).clone() - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/21_clip_grad_norm_no_ep_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/21_clip_grad_norm_no_ep_cuda.py deleted file mode 100755 index d29d483..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/21_clip_grad_norm_no_ep_cuda.py +++ /dev/null @@ -1,287 +0,0 @@ -""" -Optimized L2 clip_grad_norm for FSDP using custom CUDA and Symmetric Memory. - -Strategy: -1. Device-Side Communication: Replaced `dist.all_reduce` with a direct UVA read - of peer memory. We accumulate each rank's local sum of squares into a 1-float - symmetric memory buffer, then one kernel computes the global sum directly - across NVLink via UVA pointers. -2. Complete Compute-Communication Overlap: The entire process (squaring, - summation, global norm compute, and scaling) is pushed to the GPU stream. - There are no CPU syncs (no `.item()` calls). The scaling kernel is launched - asynchronously, conditionally executing only if `total_norm > max_norm`. - Launch overhead is minimized by traversing and launching tensors inside C++. -""" - -import math -from typing import List, Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Kernel 1: Local norm squared accumulation -// --------------------------------------------------------------------------- -template -__global__ void add_norm_sq_kernel(const T* __restrict__ data, float* acc, int64_t numel) { - float local_sum = 0.0f; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < numel; idx += (int64_t)gridDim.x * blockDim.x) { - float val = static_cast(data[idx]); - local_sum += val * val; - } - - // Warp reduce - unsigned int mask = 0xffffffff; - for (int offset = 16; offset > 0; offset /= 2) { - local_sum += __shfl_down_sync(mask, local_sum, offset); - } - - // Block reduce - __shared__ float shared_sum[32]; - int lane = threadIdx.x % 32; - int warp_id = threadIdx.x / 32; - if (lane == 0) { - shared_sum[warp_id] = local_sum; - } - __syncthreads(); - - if (threadIdx.x < blockDim.x / 32) { - local_sum = shared_sum[threadIdx.x]; - } else { - local_sum = 0.0f; - } - - // Final warp reduce on the shared sums - if (warp_id == 0) { - for (int offset = 16; offset > 0; offset /= 2) { - local_sum += __shfl_down_sync(mask, local_sum, offset); - } - if (threadIdx.x == 0) { - atomicAdd(acc, local_sum); - } - } -} - -void compute_local_norm_sq(std::vector tensors, int64_t buf_ptr) { - float* acc = reinterpret_cast(buf_ptr); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - // Asynchronously zero the local symmetric memory accumulator - cudaMemsetAsync(acc, 0, sizeof(float), stream); - - for (const auto& t : tensors) { - if (!t.defined() || t.numel() == 0) continue; - int64_t numel = t.numel(); - int threads = 256; - int blocks = std::min((int)((numel + threads - 1) / threads), 1024); - - if (t.dtype() == torch::kBFloat16) { - add_norm_sq_kernel<<>>( - t.data_ptr(), acc, numel); - } else if (t.dtype() == torch::kFloat32) { - add_norm_sq_kernel<<>>( - t.data_ptr(), acc, numel); - } else if (t.dtype() == torch::kFloat16) { - add_norm_sq_kernel<<>>( - t.data_ptr(), acc, numel); - } - } -} - -// --------------------------------------------------------------------------- -// Kernels 2 & 3: Read peers, global norm computation, and scaling -// --------------------------------------------------------------------------- -__global__ void compute_total_norm_kernel( - const int64_t* __restrict__ peer_ptrs, - float* __restrict__ out_total_norm, - int group_size -) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - float total_sq = 0.0f; - for (int i = 0; i < group_size; ++i) { - const float* peer_buf = reinterpret_cast(peer_ptrs[i]); - total_sq += *peer_buf; - } - *out_total_norm = sqrtf(total_sq); - } -} - -template -__global__ void scale_gradients_kernel( - T* __restrict__ data, - const float* __restrict__ total_norm_ptr, - float max_norm, - int64_t numel -) { - float total_norm = *total_norm_ptr; - // Condition is purely device-side, avoiding CPU-GPU synchronization entirely - if (total_norm > max_norm) { - float coef = max_norm / total_norm; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < numel; idx += (int64_t)gridDim.x * blockDim.x) { - float val = static_cast(data[idx]); - data[idx] = static_cast(val * coef); - } - } -} - -void compute_global_norm_and_scale( - std::vector tensors, - at::Tensor peer_ptrs_tensor, - float max_norm, - int group_size, - at::Tensor out_total_norm -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const int64_t* peer_ptrs = peer_ptrs_tensor.data_ptr(); - float* out_norm = out_total_norm.data_ptr(); - - // One thread accumulates all peers' partials over P2P UVA pointers - compute_total_norm_kernel<<<1, 1, 0, stream>>>(peer_ptrs, out_norm, group_size); - - for (const auto& t : tensors) { - if (!t.defined() || t.numel() == 0) continue; - int64_t numel = t.numel(); - int threads = 256; - int blocks = std::min((int)((numel + threads - 1) / threads), 1024); - - if (t.dtype() == torch::kBFloat16) { - scale_gradients_kernel<<>>( - t.data_ptr(), out_norm, max_norm, numel); - } else if (t.dtype() == torch::kFloat32) { - scale_gradients_kernel<<>>( - t.data_ptr(), out_norm, max_norm, numel); - } else if (t.dtype() == torch::kFloat16) { - scale_gradients_kernel<<>>( - t.data_ptr(), out_norm, max_norm, numel); - } - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("compute_local_norm_sq", &compute_local_norm_sq, "Compute local sum of squares"); - m.def("compute_global_norm_and_scale", &compute_global_norm_and_scale, "Compute global norm and scale in-place"); -} -''' - -_ext = None - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("clip_grad_norm_uva_ext", CUDA_SRC) - return _ext - - -_symm_cache = {} - - -def _get_symm_state(device: torch.device, group: Optional[dist.ProcessGroup]): - group_id = id(group) if group is not None else 0 - if group_id in _symm_cache: - return _symm_cache[group_id] - - buf = symm_mem.empty((1,), dtype=torch.float32, device=device) - hdl = symm_mem.rendezvous(buf, group=group if group is not None else dist.group.WORLD) - peer_ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - state = (buf, hdl, peer_ptrs) - _symm_cache[group_id] = state - return state - - -def fallback_solution( - grad_tensors: List[torch.Tensor], - max_norm: float, - norm_type: float, - fsdp_group: Optional[dist.ProcessGroup] -) -> torch.Tensor: - """Stock PyTorch fallback path.""" - p = float(norm_type) - dev = None - acc = None - for g in grad_tensors: - if g is None: - continue - if dev is None: - dev = g.device - acc = torch.tensor(0.0, device=dev, dtype=torch.float32) - gn = torch.norm(g.detach().to(torch.float32), p=p) - acc = acc + (gn ** p) - - if acc is None: - acc = torch.tensor(0.0, device=torch.device("cuda", torch.cuda.current_device()), dtype=torch.float32) - - if fsdp_group is not None: - dist.all_reduce(acc, op=dist.ReduceOp.SUM, group=fsdp_group) - elif dist.is_initialized(): - dist.all_reduce(acc, op=dist.ReduceOp.SUM) - - total_norm = acc ** (1.0 / p) - - if total_norm > max_norm: - coef = max_norm / total_norm - for t in grad_tensors: - if t is not None: - t.mul_(coef.to(t.device)) - - return total_norm - - -@torch.no_grad() -def solution( - grad_tensors: List[torch.Tensor], - max_norm: float, - norm_type: float = 2.0, - fsdp_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - """ - Computes global L2 norm across all ranks locally & distributed, then scales. - Zero-overhead FSDP L2 norm clipping using symmetric memory & UVA buffers. - """ - if not dist.is_initialized() or float(norm_type) != 2.0: - return fallback_solution(grad_tensors, max_norm, norm_type, fsdp_group) - - valid_tensors = [t for t in grad_tensors if t is not None] - - device = valid_tensors[0].device if valid_tensors else torch.device("cuda", torch.cuda.current_device()) - - buf, hdl, peer_ptrs = _get_symm_state(device, fsdp_group) - out_total_norm = torch.empty((), dtype=torch.float32, device=device) - - ext = _get_ext() - - # Kernel 1: Calculate local sq norm asynchronously in 1 float - ext.compute_local_norm_sq(valid_tensors, buf.data_ptr()) - - # Memory wall 0: Ensures peers finished updating their local symmetric buffers - hdl.barrier(channel=0) - - # Kernel 2 + 3: Sum the buffers over UVA, calculate root and scale asynchronously - ext.compute_global_norm_and_scale( - valid_tensors, - peer_ptrs, - float(max_norm), - len(hdl.buffer_ptrs), - out_total_norm - ) - - # Memory wall 1: Ensures peers read this rank's symmetric buffer - # before returning and allowing a new iteration to memset it. - hdl.barrier(channel=1) - - return out_total_norm \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/22_clip_grad_norm_ep_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/22_clip_grad_norm_ep_cuda.py deleted file mode 100755 index 869409b..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/22_clip_grad_norm_ep_cuda.py +++ /dev/null @@ -1,379 +0,0 @@ -import math -from typing import List, Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -__global__ void fused_scale_norm_kernel( - const void** __restrict__ ptrs, - const int64_t* __restrict__ sizes, - const int* __restrict__ dtypes, - int num_tensors, - float scale, - float* __restrict__ out_sum -) { - extern __shared__ float sdata[]; - float sum = 0.0f; - int tid = threadIdx.x; - - for (int t = 0; t < num_tensors; t++) { - int64_t size = sizes[t]; - int dtype = dtypes[t]; - - if (dtype == 0) { - __nv_bfloat16* data = (__nv_bfloat16*)ptrs[t]; - for (int64_t i = blockIdx.x * blockDim.x + tid; i < size; i += gridDim.x * blockDim.x) { - float val = __bfloat162float(data[i]); - if (scale != 1.0f) { - val *= scale; - data[i] = __float2bfloat16(val); - } - sum += val * val; - } - } else if (dtype == 1) { - float* data = (float*)ptrs[t]; - for (int64_t i = blockIdx.x * blockDim.x + tid; i < size; i += gridDim.x * blockDim.x) { - float val = data[i]; - if (scale != 1.0f) { - val *= scale; - data[i] = val; - } - sum += val * val; - } - } else if (dtype == 2) { - __half* data = (__half*)ptrs[t]; - for (int64_t i = blockIdx.x * blockDim.x + tid; i < size; i += gridDim.x * blockDim.x) { - float val = __half2float(data[i]); - if (scale != 1.0f) { - val *= scale; - data[i] = __float2half(val); - } - sum += val * val; - } - } - } - - sdata[tid] = sum; - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s >>= 1) { - if (tid < s) { - sdata[tid] += sdata[tid + s]; - } - __syncthreads(); - } - - if (tid == 0) { - atomicAdd(out_sum, sdata[0]); - } -} - -__global__ void clip_scale_kernel( - const void** __restrict__ ptrs, - const int64_t* __restrict__ sizes, - const int* __restrict__ dtypes, - int num_tensors, - float scale -) { - int tid = threadIdx.x; - for (int t = 0; t < num_tensors; t++) { - int64_t size = sizes[t]; - int dtype = dtypes[t]; - - if (dtype == 0) { - __nv_bfloat16* data = (__nv_bfloat16*)ptrs[t]; - for (int64_t i = blockIdx.x * blockDim.x + tid; i < size; i += gridDim.x * blockDim.x) { - float val = __bfloat162float(data[i]); - data[i] = __float2bfloat16(val * scale); - } - } else if (dtype == 1) { - float* data = (float*)ptrs[t]; - for (int64_t i = blockIdx.x * blockDim.x + tid; i < size; i += gridDim.x * blockDim.x) { - data[i] *= scale; - } - } else if (dtype == 2) { - __half* data = (__half*)ptrs[t]; - for (int64_t i = blockIdx.x * blockDim.x + tid; i < size; i += gridDim.x * blockDim.x) { - float val = __half2float(data[i]); - data[i] = __float2half(val * scale); - } - } - } -} - -__global__ void uva_reduce_step1_kernel( - const uint64_t* __restrict__ symm_ptrs, - const int* __restrict__ fsdp_ranks, int num_fsdp, - const int* __restrict__ ep_fsdp_ranks, int num_ep_fsdp, - float* __restrict__ out_non_ep, - int my_rank -) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - float non_ep_total = 0.0f; - for (int i = 0; i < num_fsdp; i++) { - int r = fsdp_ranks[i]; - float* peer_buf = (float*)symm_ptrs[r]; - non_ep_total += peer_buf[0]; - } - - float ep_fsdp_total = 0.0f; - for (int i = 0; i < num_ep_fsdp; i++) { - int r = ep_fsdp_ranks[i]; - float* peer_buf = (float*)symm_ptrs[r]; - ep_fsdp_total += peer_buf[1]; - } - - float* my_buf = (float*)symm_ptrs[my_rank]; - my_buf[2] = ep_fsdp_total; - *out_non_ep = non_ep_total; - } -} - -__global__ void uva_reduce_step2_kernel( - const uint64_t* __restrict__ symm_ptrs, - const int* __restrict__ ep_ranks, int num_ep, - float* __restrict__ out_ep -) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - float ep_total = 0.0f; - for (int i = 0; i < num_ep; i++) { - int r = ep_ranks[i]; - float* peer_buf = (float*)symm_ptrs[r]; - ep_total += peer_buf[2]; - } - *out_ep = ep_total; - } -} - -void compute_norm_and_scale( - torch::Tensor ptrs_tensor, - torch::Tensor sizes_tensor, - torch::Tensor dtypes_tensor, - int num_tensors, - float scale, - torch::Tensor out_sum -) { - if (num_tensors == 0) return; - int threads = 256; - int blocks = std::max(1, std::min(1024, (int)num_tensors * 4)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - fused_scale_norm_kernel<<>>( - (const void**)ptrs_tensor.data_ptr(), - sizes_tensor.data_ptr(), - dtypes_tensor.data_ptr(), - num_tensors, - scale, - out_sum.data_ptr() - ); -} - -void apply_clip_scale( - torch::Tensor ptrs_tensor, - torch::Tensor sizes_tensor, - torch::Tensor dtypes_tensor, - int num_tensors, - float scale -) { - if (num_tensors == 0) return; - int threads = 256; - int blocks = std::max(1, std::min(1024, (int)num_tensors * 4)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - clip_scale_kernel<<>>( - (const void**)ptrs_tensor.data_ptr(), - sizes_tensor.data_ptr(), - dtypes_tensor.data_ptr(), - num_tensors, - scale - ); -} - -void launch_uva_reduce_step1( - torch::Tensor symm_ptrs, - torch::Tensor fsdp_ranks, - torch::Tensor ep_fsdp_ranks, - torch::Tensor out_non_ep, - int my_rank -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - uva_reduce_step1_kernel<<<1, 1, 0, stream>>>( - (const uint64_t*)symm_ptrs.data_ptr(), - fsdp_ranks.data_ptr(), fsdp_ranks.size(0), - ep_fsdp_ranks.data_ptr(), ep_fsdp_ranks.size(0), - out_non_ep.data_ptr(), - my_rank - ); -} - -void launch_uva_reduce_step2( - torch::Tensor symm_ptrs, - torch::Tensor ep_ranks, - torch::Tensor out_ep -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - uva_reduce_step2_kernel<<<1, 1, 0, stream>>>( - (const uint64_t*)symm_ptrs.data_ptr(), - ep_ranks.data_ptr(), ep_ranks.size(0), - out_ep.data_ptr() - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("compute_norm_and_scale", &compute_norm_and_scale); - m.def("apply_clip_scale", &apply_clip_scale); - m.def("launch_uva_reduce_step1", &launch_uva_reduce_step1); - m.def("launch_uva_reduce_step2", &launch_uva_reduce_step2); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("clip_grad_norm_uva_ext", CUDA_SRC) - return _ext - -_symm_state = None -def _get_symm_state(device): - global _symm_state - if _symm_state is not None: - return _symm_state - - # [non_ep_sum, ep_sum, ep_sum_round1, padding] - buf = symm_mem.empty(4, dtype=torch.float32, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - symm_ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _symm_state = (buf, hdl, symm_ptrs) - return _symm_state - -_tensor_info_cache = {} -def _get_tensor_info_cached(tensors, key, device): - tensors = [t for t in tensors if t is not None] - if not tensors: - return None, 0 - - cache_key = (key, tuple((t.data_ptr(), t.numel(), t.dtype) for t in tensors)) - if cache_key in _tensor_info_cache: - return _tensor_info_cache[cache_key] - - if len(_tensor_info_cache) > 100: - _tensor_info_cache.clear() - - ptrs = [t.data_ptr() for t in tensors] - sizes = [t.numel() for t in tensors] - dtypes = [] - for t in tensors: - if t.dtype == torch.bfloat16: dtypes.append(0) - elif t.dtype == torch.float32: dtypes.append(1) - elif t.dtype == torch.float16: dtypes.append(2) - else: raise ValueError(f"Unsupported dtype: {t.dtype}") - - ptrs_t = torch.tensor(ptrs, dtype=torch.int64, device=device) - sizes_t = torch.tensor(sizes, dtype=torch.int64, device=device) - dtypes_t = torch.tensor(dtypes, dtype=torch.int32, device=device) - res = ((ptrs_t, sizes_t, dtypes_t), len(tensors)) - _tensor_info_cache[cache_key] = res - return res - -_ranks_cache = {} -def _get_ranks_tensor_cached(group, device): - if group is None: - ranks = [dist.get_rank()] if dist.is_initialized() else [0] - else: - ranks = dist.get_process_group_ranks(group) - - key = tuple(ranks) - if key not in _ranks_cache: - _ranks_cache[key] = torch.tensor(ranks, dtype=torch.int32, device=device) - return _ranks_cache[key] - -@torch.no_grad() -def solution( - non_ep_grad_tensors: List[torch.Tensor], - ep_grad_tensors: List[torch.Tensor], - max_norm: float, - norm_type: float = 2.0, - ep_size: int = 1, - fsdp_group: Optional[dist.ProcessGroup] = None, - ep_fsdp_group: Optional[dist.ProcessGroup] = None, - ep_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - ext = _get_ext() - dev = next((t.device for t in non_ep_grad_tensors + ep_grad_tensors if t is not None), torch.device("cuda")) - - non_ep_info, non_ep_count = _get_tensor_info_cached(non_ep_grad_tensors, "non_ep", dev) - ep_info, ep_count = _get_tensor_info_cached(ep_grad_tensors, "ep", dev) - - if not dist.is_initialized(): - # Single GPU Fallback - buf = torch.zeros(2, dtype=torch.float32, device=dev) - if non_ep_count > 0: - ext.compute_norm_and_scale(non_ep_info[0], non_ep_info[1], non_ep_info[2], non_ep_count, 1.0, buf[0:1]) - if ep_count > 0: - ep_scale = 1.0 / float(ep_size) if ep_size > 1 else 1.0 - ext.compute_norm_and_scale(ep_info[0], ep_info[1], ep_info[2], ep_count, ep_scale, buf[1:2]) - - total_norm_tensor = torch.sqrt(buf[0] + buf[1]) - total_norm_val = total_norm_tensor.item() - if total_norm_val > max_norm: - coef = max_norm / total_norm_val - if non_ep_count > 0: - ext.apply_clip_scale(non_ep_info[0], non_ep_info[1], non_ep_info[2], non_ep_count, coef) - if ep_count > 0: - ext.apply_clip_scale(ep_info[0], ep_info[1], ep_info[2], ep_count, coef) - return total_norm_tensor - - buf, hdl, symm_ptrs = _get_symm_state(dev) - my_rank = dist.get_rank() - - fsdp_ranks = _get_ranks_tensor_cached(fsdp_group, dev) - ep_fsdp_ranks = _get_ranks_tensor_cached(ep_fsdp_group, dev) - ep_ranks = _get_ranks_tensor_cached(ep_group, dev) - - buf[:2].zero_() - - if non_ep_count > 0: - ext.compute_norm_and_scale(non_ep_info[0], non_ep_info[1], non_ep_info[2], non_ep_count, 1.0, buf[0:1]) - - if ep_count > 0: - ep_scale = 1.0 / float(ep_size) if ep_size > 1 else 1.0 - ext.compute_norm_and_scale(ep_info[0], ep_info[1], ep_info[2], ep_count, ep_scale, buf[1:2]) - - out_norms = torch.empty(2, dtype=torch.float32, device=dev) - - # Barrier 0: Blocks device stream until local sums are computed - hdl.barrier(channel=0) - - ext.launch_uva_reduce_step1(symm_ptrs, fsdp_ranks, ep_fsdp_ranks, out_norms[0:1], my_rank) - - # Barrier 1: Blocks device stream until step1 writes (ep_fsdp sub-totals) are visible - hdl.barrier(channel=1) - - ext.launch_uva_reduce_step2(symm_ptrs, ep_ranks, out_norms[1:2]) - - # Barrier 2: Ensures all streams have consumed the buffer before the next iteration can `zero_()` it - hdl.barrier(channel=2) - - total_norm_tensor = torch.sqrt(out_norms[0] + out_norms[1]) - total_norm_val = total_norm_tensor.item() - - if total_norm_val > max_norm: - coef = max_norm / total_norm_val - if non_ep_count > 0: - ext.apply_clip_scale(non_ep_info[0], non_ep_info[1], non_ep_info[2], non_ep_count, coef) - if ep_count > 0: - ext.apply_clip_scale(ep_info[0], ep_info[1], ep_info[2], ep_count, coef) - - return total_norm_tensor \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/23_grad_acc_loss_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/23_grad_acc_loss_cuda.py deleted file mode 100755 index ccaa7ab..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/23_grad_acc_loss_cuda.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Strategy: -For scalar (or very small) metrics like loss, the critical path is latency—specifically PyTorch elementwise kernel launches and NCCL host syncs. -By fusing the forward pass (NaN checks, multiplication, inter-GPU reduction, division) and the backward pass into a single custom UVA CUDA kernel, we eliminate all PyTorch overhead and NCCL host roundtrips. -We use `torch.distributed._symmetric_memory` to allocate symmetric buffers and custom signal pads. A single kernel reads inputs, exchanges data directly via NVLink using `atom.global.release.sys`/`acquire.sys` flip-flop barriers, and writes the final outputs. This maximizes compute-communication overlap by keeping the entire operation on-device without returning to the host. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Tuple, Optional -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// Flip-flop barriers using system-wide acquire/release atomics for memory consistency across NVLink -__device__ __forceinline__ void send_signal(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.acquire.sys.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 1u); -} - -__device__ void blockwise_barrier( - const uint64_t* __restrict__ signal_pad_ptrs, - int rank, - int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) { - return; - } - uint32_t* remote_pad = reinterpret_cast(signal_pad_ptrs[flat_tid]); - uint32_t* send_addr = &remote_pad[rank]; - - uint32_t* local_pad = reinterpret_cast(signal_pad_ptrs[rank]); - uint32_t* wait_addr = &local_pad[flat_tid]; - - send_signal(send_addr); - wait_signal(wait_addr); -} - -template -struct CudaTypeTraits; - -template <> -struct CudaTypeTraits { - static __device__ __forceinline__ float to_float(float x) { return x; } - static __device__ __forceinline__ float from_float(float x) { return x; } -}; - -template <> -struct CudaTypeTraits<__nv_bfloat16> { - static __device__ __forceinline__ float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } - static __device__ __forceinline__ __nv_bfloat16 from_float(float x) { return __float2bfloat16(x); } -}; - -template <> -struct CudaTypeTraits<__half> { - static __device__ __forceinline__ float to_float(__half x) { return __half2float(x); } - static __device__ __forceinline__ __half from_float(float x) { return __float2half(x); } -}; - -template -__global__ void fused_loss_fw_bw_kernel( - const T* __restrict__ loss, - const T* __restrict__ local_valid_tokens, - const T* __restrict__ global_valid_tokens, - const T* __restrict__ grad_normalized_loss, - const T* __restrict__ grad_loss_sum, - T* __restrict__ normalized_loss_out, - T* __restrict__ loss_sum_out, - T* __restrict__ grad_loss_out, - const uint64_t* __restrict__ symm_buffer_ptrs, - const uint64_t* __restrict__ signal_pad_ptrs, - int rank, - int world_size, - int numel -) { - float lvt = CudaTypeTraits::to_float(local_valid_tokens[0]); - float gvt = CudaTypeTraits::to_float(global_valid_tokens[0]); - - for (int idx = threadIdx.x; idx < numel; idx += blockDim.x) { - float l = CudaTypeTraits::to_float(loss[idx]); - if (lvt == 0.0f) { - if (isnan(l) || isinf(l)) l = 0.0f; - } - float l_sum = l * lvt; - float* my_symm_buf = reinterpret_cast(symm_buffer_ptrs[rank]); - my_symm_buf[idx] = l_sum; - } - - __syncthreads(); - blockwise_barrier(signal_pad_ptrs, rank, world_size); - __syncthreads(); - - for (int idx = threadIdx.x; idx < numel; idx += blockDim.x) { - float total_l_sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - float* peer_symm_buf = reinterpret_cast(symm_buffer_ptrs[r]); - total_l_sum += peer_symm_buf[idx]; - } - - float norm_loss = total_l_sum / gvt; - float gnl = CudaTypeTraits::to_float(grad_normalized_loss[idx]); - float gls = 0.0f; - if (grad_loss_sum != nullptr) { - gls = CudaTypeTraits::to_float(grad_loss_sum[idx]); - } - - float grad_from_norm = gnl * lvt / gvt; - float grad_from_sum = gls * lvt; - float grad_l = grad_from_norm + grad_from_sum; - - normalized_loss_out[idx] = CudaTypeTraits::from_float(norm_loss); - loss_sum_out[idx] = CudaTypeTraits::from_float(total_l_sum); - grad_loss_out[idx] = CudaTypeTraits::from_float(grad_l); - } - - __syncthreads(); - blockwise_barrier(signal_pad_ptrs, rank, world_size); -} - -void launch_fused_loss( - torch::Tensor loss, - torch::Tensor local_valid_tokens, - torch::Tensor global_valid_tokens, - torch::Tensor grad_normalized_loss, - c10::optional grad_loss_sum, - torch::Tensor normalized_loss_out, - torch::Tensor loss_sum_out, - torch::Tensor grad_loss_out, - torch::Tensor symm_buffer_ptrs, - torch::Tensor signal_pad_ptrs, - int rank, - int world_size -) { - int numel = loss.numel(); - int threads = 256; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (loss.dtype() == torch::kBFloat16) { - fused_loss_fw_bw_kernel<__nv_bfloat16><<<1, threads, 0, stream>>>( - reinterpret_cast(loss.data_ptr()), - reinterpret_cast(local_valid_tokens.data_ptr()), - reinterpret_cast(global_valid_tokens.data_ptr()), - reinterpret_cast(grad_normalized_loss.data_ptr()), - grad_loss_sum.has_value() ? reinterpret_cast(grad_loss_sum.value().data_ptr()) : nullptr, - reinterpret_cast<__nv_bfloat16*>(normalized_loss_out.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(loss_sum_out.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(grad_loss_out.data_ptr()), - reinterpret_cast(symm_buffer_ptrs.data_ptr()), - reinterpret_cast(signal_pad_ptrs.data_ptr()), - rank, - world_size, - numel - ); - } else if (loss.dtype() == torch::kFloat32) { - fused_loss_fw_bw_kernel<<<1, threads, 0, stream>>>( - loss.data_ptr(), - local_valid_tokens.data_ptr(), - global_valid_tokens.data_ptr(), - grad_normalized_loss.data_ptr(), - grad_loss_sum.has_value() ? grad_loss_sum.value().data_ptr() : nullptr, - normalized_loss_out.data_ptr(), - loss_sum_out.data_ptr(), - grad_loss_out.data_ptr(), - reinterpret_cast(symm_buffer_ptrs.data_ptr()), - reinterpret_cast(signal_pad_ptrs.data_ptr()), - rank, - world_size, - numel - ); - } else if (loss.dtype() == torch::kFloat16) { - fused_loss_fw_bw_kernel<__half><<<1, threads, 0, stream>>>( - reinterpret_cast(loss.data_ptr()), - reinterpret_cast(local_valid_tokens.data_ptr()), - reinterpret_cast(global_valid_tokens.data_ptr()), - reinterpret_cast(grad_normalized_loss.data_ptr()), - grad_loss_sum.has_value() ? reinterpret_cast(grad_loss_sum.value().data_ptr()) : nullptr, - reinterpret_cast<__half*>(normalized_loss_out.data_ptr()), - reinterpret_cast<__half*>(loss_sum_out.data_ptr()), - reinterpret_cast<__half*>(grad_loss_out.data_ptr()), - reinterpret_cast(symm_buffer_ptrs.data_ptr()), - reinterpret_cast(signal_pad_ptrs.data_ptr()), - rank, - world_size, - numel - ); - } else { - TORCH_CHECK(false, "Unsupported dtype: only float32, float16, and bfloat16 are supported"); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_fused_loss", &launch_fused_loss, "Fused loss fw bw with symm_mem allreduce"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - if dist.is_initialized(): - if dist.get_rank() == 0: - _ext = compile_cuda_extension("fused_loss_acc", CUDA_SRC) - dist.barrier() - if dist.get_rank() != 0: - _ext = compile_cuda_extension("fused_loss_acc", CUDA_SRC) - else: - _ext = compile_cuda_extension("fused_loss_acc", CUDA_SRC) - return _ext - -_resource_cache = {} - -def _get_resources(numel: int, device: torch.device, world_size: int): - key = (numel, device, world_size) - if key in _resource_cache: - return _resource_cache[key] - - buf = symm_mem.empty(numel, dtype=torch.float32, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - buf_ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - pad = symm_mem.empty((world_size,), dtype=torch.int32, device=device) - pad_hdl = symm_mem.rendezvous(pad, dist.group.WORLD) - pad.zero_() - pad_ptrs = torch.tensor(pad_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - # Guarantee that pad zeroes are globally visible before they are actively requested by kernels - dist.barrier() - - res = (buf, buf_ptrs, pad_ptrs) - _resource_cache[key] = res - return res - -@torch.no_grad() -def solution( - loss: torch.Tensor, - local_valid_tokens: torch.Tensor, - global_valid_tokens: torch.Tensor, - grad_normalized_loss: torch.Tensor, - grad_loss_sum: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - if not dist.is_initialized() or dist.get_world_size() == 1: - if local_valid_tokens.item() == 0: - loss = torch.nan_to_num(loss) - loss_sum = loss * local_valid_tokens - normalized_loss = loss_sum / global_valid_tokens - - grad_from_norm = grad_normalized_loss * local_valid_tokens / global_valid_tokens - if grad_loss_sum is not None: - grad_from_sum = grad_loss_sum * local_valid_tokens - else: - grad_from_sum = torch.zeros_like(grad_normalized_loss) - - return normalized_loss, loss_sum, grad_from_norm + grad_from_sum - - loss = loss.contiguous() - local_valid_tokens = local_valid_tokens.contiguous() - global_valid_tokens = global_valid_tokens.contiguous() - grad_normalized_loss = grad_normalized_loss.contiguous() - if grad_loss_sum is not None: - grad_loss_sum = grad_loss_sum.contiguous() - - world_size = dist.get_world_size() - rank = dist.get_rank() - numel = loss.numel() - - ext = _get_ext() - buf, buf_ptrs, pad_ptrs = _get_resources(numel, loss.device, world_size) - - normalized_loss_out = torch.empty_like(loss) - loss_sum_out = torch.empty_like(loss) - grad_loss_out = torch.empty_like(loss) - - ext.launch_fused_loss( - loss, - local_valid_tokens, - global_valid_tokens, - grad_normalized_loss, - grad_loss_sum, - normalized_loss_out, - loss_sum_out, - grad_loss_out, - buf_ptrs, - pad_ptrs, - rank, - world_size - ) - - return normalized_loss_out, loss_sum_out, grad_loss_out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/24_load_balancing_loss_fn_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/24_load_balancing_loss_fn_cuda.py deleted file mode 100755 index 3c7fb0f..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/24_load_balancing_loss_fn_cuda.py +++ /dev/null @@ -1,455 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Union, Tuple, Optional -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// ------------------------------------------------------------------------- -// 1. Forward Kernel: Single pass to compute local M_e, C_e, and W_total -// ------------------------------------------------------------------------- -__global__ void moe_load_balance_warp_kernel( - const __nv_bfloat16* __restrict__ logits, - const float* __restrict__ mask, - float* __restrict__ global_m_e, - float* __restrict__ global_c_e, - float* __restrict__ global_w_total, - int total_tokens, - int mask_size, - int num_experts, - int top_k -) { - extern __shared__ float smem[]; - float* s_m_e = smem; - float* s_c_e = smem + num_experts; - float* s_w_total = smem + 2 * num_experts; - - for (int i = threadIdx.x; i < 2 * num_experts + 1; i += blockDim.x) { - smem[i] = 0.0f; - } - __syncthreads(); - - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - int token_idx = blockIdx.x * (blockDim.x / 32) + warp_id; - - if (token_idx < total_tokens) { - float w = 1.0f; - if (mask != nullptr && mask_size > 0) { - w = mask[token_idx % mask_size]; - } - - if (w > 0.0f) { - float max_val = -1e20f; - const __nv_bfloat16* token_logits = logits + token_idx * num_experts; - - // Pass 1: Max - for (int i = lane_id; i < num_experts; i += 32) { - float val = __bfloat162float(token_logits[i]); - if (val > max_val) max_val = val; - } - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - max_val = max(max_val, __shfl_down_sync(0xffffffff, max_val, offset)); - } - max_val = __shfl_sync(0xffffffff, max_val, 0); - - // Pass 2: Sum Exp - float sum_exp = 0.0f; - for (int i = lane_id; i < num_experts; i += 32) { - float val = __bfloat162float(token_logits[i]); - sum_exp += expf(val - max_val); - } - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sum_exp += __shfl_down_sync(0xffffffff, sum_exp, offset); - } - sum_exp = __shfl_sync(0xffffffff, sum_exp, 0); - - // Pass 3: Softmax Probabilities and accumulation - float thread_probs[128]; // Safe up to 4096 experts per warp - int thread_indices[128]; - int num_items = 0; - - for (int i = lane_id; i < num_experts; i += 32) { - float val = __bfloat162float(token_logits[i]); - float prob = expf(val - max_val) / sum_exp; - if (num_items < 128) { - thread_probs[num_items] = prob; - thread_indices[num_items] = i; - num_items++; - } - atomicAdd(&s_c_e[i], w * prob); - } - - // Select Top-K with consistent tie-breaking behavior - for (int k = 0; k < top_k; k++) { - float local_max_prob = -1.0f; - int local_max_idx = -1; - for (int i = 0; i < num_items; i++) { - if (thread_probs[i] > local_max_prob) { - local_max_prob = thread_probs[i]; - local_max_idx = thread_indices[i]; - } - } - - float warp_max_prob = local_max_prob; - int warp_max_idx = local_max_idx; - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - float other_prob = __shfl_down_sync(0xffffffff, warp_max_prob, offset); - int other_idx = __shfl_down_sync(0xffffffff, warp_max_idx, offset); - if (other_prob > warp_max_prob || (other_prob == warp_max_prob && other_idx < warp_max_idx)) { - warp_max_prob = other_prob; - warp_max_idx = other_idx; - } - } - - warp_max_prob = __shfl_sync(0xffffffff, warp_max_prob, 0); - warp_max_idx = __shfl_sync(0xffffffff, warp_max_idx, 0); - - if (lane_id == 0) { - atomicAdd(&s_m_e[warp_max_idx], w); - } - - // Mask out selected probability for the next K iteration - for (int i = 0; i < num_items; i++) { - if (thread_indices[i] == warp_max_idx) { - thread_probs[i] = -2.0f; - } - } - } - - if (lane_id == 0) atomicAdd(s_w_total, w); - } - } - - __syncthreads(); - - // Flush shared memory to global arrays - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { - if (s_m_e[i] > 0.0f) atomicAdd(&global_m_e[i], s_m_e[i]); - if (s_c_e[i] > 0.0f) atomicAdd(&global_c_e[i], s_c_e[i]); - } - if (threadIdx.x == 0) { - if (s_w_total[0] > 0.0f) atomicAdd(global_w_total, s_w_total[0]); - } -} - -// ------------------------------------------------------------------------- -// 2. Compute Local Loss -// ------------------------------------------------------------------------- -__global__ void compute_local_loss_kernel( - const float* __restrict__ m_e, - const float* __restrict__ c_e, - const float* __restrict__ w_total, - float* __restrict__ local_loss, - int num_experts -) { - float w = w_total[0]; - if (w <= 0.0f) w = 1.0f; // Avoid NaN division if completely masked out - - float sum = 0.0f; - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { - sum += (m_e[i] / w) * (c_e[i] / w); - } - - static __shared__ float shared_sum[32]; - int lane = threadIdx.x % 32; - int warp = threadIdx.x / 32; - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - - if (lane == 0) shared_sum[warp] = sum; - __syncthreads(); - - if (warp == 0) { - sum = (lane < (blockDim.x / 32)) ? shared_sum[lane] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - if (lane == 0) { - local_loss[0] = sum * num_experts; - } - } -} - -// ------------------------------------------------------------------------- -// 3. UVA Symmetric Memory Scalar All-Reduce -// ------------------------------------------------------------------------- -__global__ void symm_allreduce_scalar_kernel( - const long long* __restrict__ peer_ptrs, - float* __restrict__ out_val, - int world_size -) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - float sum = 0.0f; - for (int i = 0; i < world_size; i++) { - const float* ptr = reinterpret_cast(peer_ptrs[i]); - sum += *ptr; - } - out_val[0] = sum / world_size; - } -} - -// ------------------------------------------------------------------------- -// 4. Backward Pass Kernel: Recompute probabilities and emit exact analytical grad -// ------------------------------------------------------------------------- -__global__ void moe_load_balance_backward_kernel( - const __nv_bfloat16* __restrict__ logits, - const float* __restrict__ mask, - const float* __restrict__ m_e, - float w_total, - __nv_bfloat16* __restrict__ grad_logits, - float grad_output_scaled, - int total_tokens, - int mask_size, - int num_experts -) { - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - int token_idx = blockIdx.x * (blockDim.x / 32) + warp_id; - - if (token_idx >= total_tokens) return; - - float w = 1.0f; - if (mask != nullptr && mask_size > 0) w = mask[token_idx % mask_size]; - - __nv_bfloat16* out_ptr = grad_logits + token_idx * num_experts; - - if (w <= 0.0f) { - for (int i = lane_id; i < num_experts; i += 32) out_ptr[i] = __float2bfloat16(0.0f); - return; - } - - float W_sq = w_total * w_total; - if (W_sq <= 0.0f) W_sq = 1.0f; - float scale = (float(num_experts) / W_sq) * grad_output_scaled * w; - - const __nv_bfloat16* token_logits = logits + token_idx * num_experts; - - float max_val = -1e20f; - for (int i = lane_id; i < num_experts; i += 32) { - float val = __bfloat162float(token_logits[i]); - if (val > max_val) max_val = val; - } - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) max_val = max(max_val, __shfl_down_sync(0xffffffff, max_val, offset)); - max_val = __shfl_sync(0xffffffff, max_val, 0); - - float sum_exp = 0.0f; - for (int i = lane_id; i < num_experts; i += 32) { - float val = __bfloat162float(token_logits[i]); - sum_exp += expf(val - max_val); - } - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) sum_exp += __shfl_down_sync(0xffffffff, sum_exp, offset); - sum_exp = __shfl_sync(0xffffffff, sum_exp, 0); - - // Compute Analytical Expected Grad component - float expected_G = 0.0f; - for (int i = lane_id; i < num_experts; i += 32) { - float val = __bfloat162float(token_logits[i]); - float p = expf(val - max_val) / sum_exp; - expected_G += p * m_e[i]; - } - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) expected_G += __shfl_down_sync(0xffffffff, expected_G, offset); - expected_G = __shfl_sync(0xffffffff, expected_G, 0); - - // Emit final grad components safely - for (int i = lane_id; i < num_experts; i += 32) { - float val = __bfloat162float(token_logits[i]); - float p = expf(val - max_val) / sum_exp; - float dx = scale * p * (m_e[i] - expected_G); - out_ptr[i] = __float2bfloat16(dx); - } -} - -// ------------------------------------------------------------------------- -// Bindings -// ------------------------------------------------------------------------- -void launch_moe_load_balance( - torch::Tensor logits, std::optional mask, - torch::Tensor m_e, torch::Tensor c_e, torch::Tensor w_total, int top_k -) { - int total_tokens = logits.size(0); - int num_experts = logits.size(1); - int mask_size = mask.has_value() ? mask->size(0) : 0; - const float* mask_ptr = mask.has_value() ? mask->data_ptr() : nullptr; - - int threads = 256; - int warps_per_block = threads / 32; - int blocks = (total_tokens + warps_per_block - 1) / warps_per_block; - int smem_size = (2 * num_experts + 1) * sizeof(float); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - moe_load_balance_warp_kernel<<>>( - reinterpret_cast(logits.data_ptr()), mask_ptr, - m_e.data_ptr(), c_e.data_ptr(), w_total.data_ptr(), - total_tokens, mask_size, num_experts, top_k - ); -} - -void launch_compute_local_loss(torch::Tensor m_e, torch::Tensor c_e, torch::Tensor w_total, torch::Tensor local_loss) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - compute_local_loss_kernel<<<1, 256, 0, stream>>>( - m_e.data_ptr(), c_e.data_ptr(), w_total.data_ptr(), - local_loss.data_ptr(), m_e.size(0) - ); -} - -void launch_symm_allreduce(torch::Tensor peer_ptrs, torch::Tensor out_val) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - symm_allreduce_scalar_kernel<<<1, 32, 0, stream>>>( - reinterpret_cast(peer_ptrs.data_ptr()), - out_val.data_ptr(), peer_ptrs.size(0) - ); -} - -void launch_moe_load_balance_backward( - torch::Tensor logits, std::optional mask, torch::Tensor m_e, - float w_total, torch::Tensor grad_logits, float grad_output_scaled, int num_experts -) { - int total_tokens = logits.size(0); - int mask_size = mask.has_value() ? mask->size(0) : 0; - const float* mask_ptr = mask.has_value() ? mask->data_ptr() : nullptr; - - int threads = 256; - int warps_per_block = threads / 32; - int blocks = (total_tokens + warps_per_block - 1) / warps_per_block; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - moe_load_balance_backward_kernel<<>>( - reinterpret_cast(logits.data_ptr()), mask_ptr, - m_e.data_ptr(), w_total, - reinterpret_cast<__nv_bfloat16*>(grad_logits.data_ptr()), - grad_output_scaled, total_tokens, mask_size, num_experts - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_moe_load_balance", &launch_moe_load_balance); - m.def("launch_compute_local_loss", &launch_compute_local_loss); - m.def("launch_symm_allreduce", &launch_symm_allreduce); - m.def("launch_moe_load_balance_backward", &launch_moe_load_balance_backward); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("moe_load_balance_fast_ext", CUDA_SRC, extra_compile_args={'nvcc': ['-O3']}) - return _ext - - -_symm_cache = {} -def _get_symm_state(device: torch.device): - if device in _symm_cache: - return _symm_cache[device] - - buf = symm_mem.empty((1,), device=device, dtype=torch.float32) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - out = torch.empty((1,), device=device, dtype=torch.float32) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - _symm_cache[device] = (buf, hdl, out, ptrs_tensor) - return _symm_cache[device] - - -class CustomMoELoadBalanceLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, gate_logits, attention_mask, num_experts, top_k): - _get_ext() # warm up JIT - - if isinstance(gate_logits, (tuple, list)): - ctx.is_tuple = True - ctx.shapes = [g.shape for g in gate_logits] - ctx.devices = [g.device for g in gate_logits] - compute_device = gate_logits[0].device - logits = torch.cat([g.to(compute_device) for g in gate_logits], dim=0).contiguous() - else: - ctx.is_tuple = False - logits = gate_logits.contiguous() - - mask_tensor = None - if attention_mask is not None: - mask_tensor = attention_mask.reshape(-1).float().contiguous() - - ctx.save_for_backward(logits, mask_tensor) - ctx.num_experts = num_experts - - m_e = torch.zeros(num_experts, device=logits.device, dtype=torch.float32) - c_e = torch.zeros(num_experts, device=logits.device, dtype=torch.float32) - w_total = torch.zeros(1, device=logits.device, dtype=torch.float32) - - _get_ext().launch_moe_load_balance(logits, mask_tensor, m_e, c_e, w_total, top_k) - - ctx.m_e = m_e - ctx.w_total = w_total - - is_dist = dist.is_available() and dist.is_initialized() - ctx.world_size = dist.get_world_size() if is_dist else 1 - - if is_dist: - buf, hdl, out, ptrs_tensor = _get_symm_state(logits.device) - _get_ext().launch_compute_local_loss(m_e, c_e, w_total, buf) - - # Using channel-ordered barrier directly on the symmetric handle - hdl.barrier(channel=0) - _get_ext().launch_symm_allreduce(ptrs_tensor, out) - global_loss = out.clone().reshape(()) - else: - out = torch.empty(1, device=logits.device, dtype=torch.float32) - _get_ext().launch_compute_local_loss(m_e, c_e, w_total, out) - global_loss = out.reshape(()) - - return global_loss - - @staticmethod - def backward(ctx, grad_output): - logits, mask_tensor = ctx.saved_tensors - m_e = ctx.m_e - w_total_val = ctx.w_total.item() - num_experts = ctx.num_experts - - # Because we used a scalar sum all-reduce internally on forward (Loss = sum(Li)/N), - # its backward gradient passes locally out purely scaled by local distribution. - grad_output_scaled = grad_output.item() / ctx.world_size - grad_logits = torch.empty_like(logits) - - _get_ext().launch_moe_load_balance_backward( - logits, mask_tensor, m_e, w_total_val, grad_logits, - grad_output_scaled, num_experts - ) - - if ctx.is_tuple: - grads = [] - offset = 0 - for shape, device in zip(ctx.shapes, ctx.devices): - size = shape[0] - grads.append(grad_logits[offset:offset+size].to(device)) - offset += size - return tuple(grads), None, None, None - else: - return grad_logits, None, None, None - - -def solution( - gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...]], - num_experts: int, - top_k: int = 2, - attention_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return CustomMoELoadBalanceLoss.apply(gate_logits, attention_mask, num_experts, top_k) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/25_importance_sampling_loss_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/25_importance_sampling_loss_cuda.py deleted file mode 100755 index 79a9220..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/25_importance_sampling_loss_cuda.py +++ /dev/null @@ -1,353 +0,0 @@ -""" -Strategy: -1. Kernel Fusion: Combines elementwise importance sampling ops (exp, clamp, mask) and local reductions (sum, min, max, entropy, kl) into a single custom CUDA kernel. This minimizes memory bandwidth and replaces 7 separate NCCL all-reduces. -2. Device-Side Communication: The local kernel atomically accumulates reductions directly into a `torch.distributed._symmetric_memory` buffer. A lightweight UVA kernel (`reduce_global_stats_kernel`) gathers and reduces the 7 scalar metrics from all peers via direct peer-to-peer access, avoiding host-driven `dist.all_reduce` overhead. -3. Compute-Communication Overlap: The cross-rank UVA reduction and symmetric memory barrier are offloaded to a dedicated communication stream. Concurrently, the default stream computes the PyTorch `local_surrogate_sum` (necessary for autograd). The streams synchronize right before assembling the final loss, effectively hiding the barrier and peer-read latency behind independent local computation. -""" - -import torch -import torch.nn.functional as F -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Tuple, Any -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__device__ __forceinline__ void atomicMinFloat(float* addr, float val) { - if (isnan(val)) return; - int* addr_as_i = (int*)addr; - int old = *addr_as_i, assumed; - do { - assumed = old; - if (__int_as_float(assumed) <= val) break; - old = atomicCAS(addr_as_i, assumed, __float_as_int(val)); - } while (assumed != old); -} - -__device__ __forceinline__ void atomicMaxFloat(float* addr, float val) { - if (isnan(val)) return; - int* addr_as_i = (int*)addr; - int old = *addr_as_i, assumed; - do { - assumed = old; - if (__int_as_float(assumed) >= val) break; - old = atomicCAS(addr_as_i, assumed, __float_as_int(val)); - } while (assumed != old); -} - -__global__ void init_stats_kernel(float* stats) { - if (threadIdx.x == 0) { - stats[0] = 0.0f; // n_valid - stats[1] = 0.0f; // pg_sum - stats[2] = 0.0f; // ratio_sum - stats[3] = INFINITY; // ratio_min - stats[4] = -INFINITY; // ratio_max - stats[5] = 0.0f; // k3_sum - stats[6] = 0.0f; // entropy_sum - stats[7] = 0.0f; // padding - } -} - -template -__global__ void compute_stats_kernel( - const scalar_t* __restrict__ per_token_ce, - const int64_t* __restrict__ labels, - const scalar_t* __restrict__ old_logprobs, - const scalar_t* __restrict__ advantages, - scalar_t* __restrict__ w_out, - scalar_t* __restrict__ per_token_pg_out, - scalar_t* __restrict__ per_token_logprobs_out, - float* __restrict__ local_stats, - int ignore_index, - int n -) { - __shared__ float s_n_valid[32]; - __shared__ float s_pg_sum[32]; - __shared__ float s_ratio_sum[32]; - __shared__ float s_ratio_min[32]; - __shared__ float s_ratio_max[32]; - __shared__ float s_k3_sum[32]; - __shared__ float s_entropy_sum[32]; - - int tid = threadIdx.x; - int lane = tid % 32; - int warp = tid / 32; - - float l_n_valid = 0; - float l_pg_sum = 0; - float l_ratio_sum = 0; - float l_ratio_min = INFINITY; - float l_ratio_max = -INFINITY; - float l_k3_sum = 0; - float l_entropy_sum = 0; - - int idx = blockIdx.x * blockDim.x + tid; - if (idx < n) { - int64_t label = labels[idx]; - float ce = static_cast(per_token_ce[idx]); - float old_lp = static_cast(old_logprobs[idx]); - float adv = static_cast(advantages[idx]); - - float new_lp = -ce; - per_token_logprobs_out[idx] = static_cast(new_lp); - - if (label != ignore_index) { - l_n_valid = 1.0f; - float delta = new_lp - old_lp; - delta = fmaxf(-20.0f, fminf(20.0f, delta)); - float ratio = expf(delta); - float pg = -ratio * adv; - - w_out[idx] = static_cast(ratio * adv); - per_token_pg_out[idx] = static_cast(pg); - - l_pg_sum = pg; - l_ratio_sum = ratio; - l_ratio_min = ratio; - l_ratio_max = ratio; - l_k3_sum = ratio - delta - 1.0f; - l_entropy_sum = ce; - } else { - w_out[idx] = static_cast(0.0f); - per_token_pg_out[idx] = static_cast(0.0f); - } - } - - // Warp reduction - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - l_n_valid += __shfl_down_sync(0xffffffff, l_n_valid, offset); - l_pg_sum += __shfl_down_sync(0xffffffff, l_pg_sum, offset); - l_ratio_sum += __shfl_down_sync(0xffffffff, l_ratio_sum, offset); - l_ratio_min = fminf(l_ratio_min, __shfl_down_sync(0xffffffff, l_ratio_min, offset)); - l_ratio_max = fmaxf(l_ratio_max, __shfl_down_sync(0xffffffff, l_ratio_max, offset)); - l_k3_sum += __shfl_down_sync(0xffffffff, l_k3_sum, offset); - l_entropy_sum += __shfl_down_sync(0xffffffff, l_entropy_sum, offset); - } - - if (lane == 0) { - s_n_valid[warp] = l_n_valid; - s_pg_sum[warp] = l_pg_sum; - s_ratio_sum[warp] = l_ratio_sum; - s_ratio_min[warp] = l_ratio_min; - s_ratio_max[warp] = l_ratio_max; - s_k3_sum[warp] = l_k3_sum; - s_entropy_sum[warp] = l_entropy_sum; - } - __syncthreads(); - - // Block reduction - if (warp == 0) { - int num_warps = blockDim.x / 32; - l_n_valid = (lane < num_warps) ? s_n_valid[lane] : 0; - l_pg_sum = (lane < num_warps) ? s_pg_sum[lane] : 0; - l_ratio_sum = (lane < num_warps) ? s_ratio_sum[lane] : 0; - l_ratio_min = (lane < num_warps) ? s_ratio_min[lane] : INFINITY; - l_ratio_max = (lane < num_warps) ? s_ratio_max[lane] : -INFINITY; - l_k3_sum = (lane < num_warps) ? s_k3_sum[lane] : 0; - l_entropy_sum = (lane < num_warps) ? s_entropy_sum[lane] : 0; - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - l_n_valid += __shfl_down_sync(0xffffffff, l_n_valid, offset); - l_pg_sum += __shfl_down_sync(0xffffffff, l_pg_sum, offset); - l_ratio_sum += __shfl_down_sync(0xffffffff, l_ratio_sum, offset); - l_ratio_min = fminf(l_ratio_min, __shfl_down_sync(0xffffffff, l_ratio_min, offset)); - l_ratio_max = fmaxf(l_ratio_max, __shfl_down_sync(0xffffffff, l_ratio_max, offset)); - l_k3_sum += __shfl_down_sync(0xffffffff, l_k3_sum, offset); - l_entropy_sum += __shfl_down_sync(0xffffffff, l_entropy_sum, offset); - } - - if (lane == 0) { - atomicAdd(&local_stats[0], l_n_valid); - atomicAdd(&local_stats[1], l_pg_sum); - atomicAdd(&local_stats[2], l_ratio_sum); - if (l_ratio_min < INFINITY) atomicMinFloat(&local_stats[3], l_ratio_min); - if (l_ratio_max > -INFINITY) atomicMaxFloat(&local_stats[4], l_ratio_max); - atomicAdd(&local_stats[5], l_k3_sum); - atomicAdd(&local_stats[6], l_entropy_sum); - } - } -} - -__global__ void reduce_global_stats_kernel( - const long long* __restrict__ peer_ptrs, - float* __restrict__ global_out, - int world_size -) { - int tid = threadIdx.x; - if (tid >= 7) return; - - float val; - if (tid == 0 || tid == 1 || tid == 2 || tid == 5 || tid == 6) val = 0.0f; - else if (tid == 3) val = INFINITY; - else if (tid == 4) val = -INFINITY; - - for (int r = 0; r < world_size; r++) { - const float* peer_stats = (const float*)peer_ptrs[r]; - float p_val = peer_stats[tid]; - if (tid == 0 || tid == 1 || tid == 2 || tid == 5 || tid == 6) { - val += p_val; - } else if (tid == 3) { - val = fminf(val, p_val); - } else if (tid == 4) { - val = fmaxf(val, p_val); - } - } - - global_out[tid] = val; -} - -void launch_compute_stats( - torch::Tensor per_token_ce, - torch::Tensor labels, - torch::Tensor old_logprobs, - torch::Tensor advantages, - torch::Tensor w_out, - torch::Tensor per_token_pg_out, - torch::Tensor per_token_logprobs_out, - torch::Tensor local_stats, - int ignore_index -) { - int n = per_token_ce.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - init_stats_kernel<<<1, 1, 0, stream>>>(local_stats.data_ptr()); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, per_token_ce.scalar_type(), "compute_stats_kernel", ([&] { - compute_stats_kernel<<>>( - per_token_ce.data_ptr(), - labels.data_ptr(), - old_logprobs.data_ptr(), - advantages.data_ptr(), - w_out.data_ptr(), - per_token_pg_out.data_ptr(), - per_token_logprobs_out.data_ptr(), - local_stats.data_ptr(), - ignore_index, - n - ); - })); -} - -void launch_reduce_global_stats( - torch::Tensor peer_ptrs, - torch::Tensor global_out -) { - int world_size = peer_ptrs.size(0); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - reduce_global_stats_kernel<<<1, 32, 0, stream>>>( - (const long long*)peer_ptrs.data_ptr(), - global_out.data_ptr(), - world_size - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_compute_stats", &launch_compute_stats, "Compute local elementwise metrics and local reductions"); - m.def("launch_reduce_global_stats", &launch_reduce_global_stats, "Reduce global metrics from peers via UVA"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("grpo_importance_sampling_ext", CUDA_SRC) - return _ext - -_resource_cache = None -def _get_resources(device): - global _resource_cache - if _resource_cache is not None: - return _resource_cache - - buf = symm_mem.empty((8,), device=device, dtype=torch.float32) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - global_out = torch.empty((8,), device=device, dtype=torch.float32) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - comm_stream = torch.cuda.Stream(device=device) - - _resource_cache = (buf, hdl, global_out, ptrs_tensor, comm_stream) - return _resource_cache - -def solution( - hidden_states: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - old_logprobs: torch.Tensor, - advantages: torch.Tensor, - ignore_index: int = -100, -) -> Tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor]: - - # 1. Compute logits and per-token cross entropy via heavily-optimized PyTorch ops - logits = F.linear(hidden_states, weight) - logits_flat = logits.view(-1, logits.size(-1)) - labels_flat = labels.contiguous().view(-1) - - # per_token_ce retains the grad_fn necessary for surrogate loss backpropagation - per_token_ce = F.cross_entropy(logits_flat, labels_flat, ignore_index=ignore_index, reduction='none') - per_token_ce_contig = per_token_ce.contiguous() if not per_token_ce.is_contiguous() else per_token_ce - - old_logprobs_flat = old_logprobs.contiguous().view(-1) - advantages_flat = advantages.contiguous().view(-1) - - w = torch.empty_like(per_token_ce_contig) - per_token_pg = torch.empty_like(per_token_ce_contig) - per_token_logprobs = torch.empty_like(per_token_ce_contig) - - buf, hdl, global_out, ptrs_tensor, comm_stream = _get_resources(hidden_states.device) - ext = _get_ext() - - # 2. Fuse all elementwise operations and local reductions into a single kernel run - ext.launch_compute_stats( - per_token_ce_contig, labels_flat, old_logprobs_flat, advantages_flat, - w, per_token_pg, per_token_logprobs, buf, ignore_index - ) - - ready_event = torch.cuda.Event() - ready_event.record() - - # 3. OVERLAP: Compute the autograd-tracked surrogate sum completely concurrently with cross-rank communication - # The kernel inherently detached `w`. `per_token_ce` will trace back through the graph. - local_surrogate_sum = (w * per_token_ce).sum() - - # 4. Device-side communication for global stats handling via UVA - with torch.cuda.stream(comm_stream): - comm_stream.wait_event(ready_event) - # Block-level sync over symmetric memory ensures all peers have populated local_stats - hdl.barrier(channel=0) - # Gather and reduce the 7 scalar metrics directly via UVA peer pointers - ext.launch_reduce_global_stats(ptrs_tensor, global_out) - done_event = torch.cuda.Event() - done_event.record() - - torch.cuda.current_stream().wait_event(done_event) - - # 5. Extract finalized global stats and formulate final composite metrics & loss - n_valid_global = global_out[0].clamp(min=1.0) - true_pg = global_out[1] / n_valid_global - ratio_mean = global_out[2] / n_valid_global - ratio_min = global_out[3] - ratio_max = global_out[4] - k3_mean = global_out[5] / n_valid_global - entropy_mean = global_out[6] / n_valid_global - - surrogate = local_surrogate_sum / n_valid_global - - # The loss triggers exactly the same gradients due to surrogate component - loss = true_pg.detach() + surrogate - surrogate.detach() - metrics = torch.stack([ratio_mean, ratio_min, ratio_max, k3_mean, entropy_mean]) - - per_token_logprobs = per_token_logprobs.view_as(labels) - per_token_loss = per_token_pg.view_as(labels) - - return loss, None, per_token_logprobs, per_token_loss, metrics \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/26_moe_token_preprocess_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/26_moe_token_preprocess_cuda.py deleted file mode 100755 index 675715c..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/26_moe_token_preprocess_cuda.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Strategy: -- We bypass the NCCL all_gather_into_tensor and multiple PyTorch reshape/sum ops. -- Kernel 1 (local_reduce): Reduces `expert_mask` locally to calculate `num_local_tokens_per_expert` and writes directly into a symmetric memory buffer, avoiding intermediate allocations. -- Device-side barrier: We use `hdl.barrier()` to asynchronously synchronize peers. -- Kernel 2 (gather_postprocess): A single thread block cooperatively loads all peers' symmetric buffers into shared memory over NVLink. It then computes `input_splits`, `output_splits`, `num_global_tokens_per_local_expert`, and `num_global_sum_tokens_per_local_expert` purely on-device from shared memory. -- Finally, the outputs are returned in the exact format required, with async CPU copies to overlap with subsequent host execution. -""" - -from typing import List, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -template -__global__ void local_reduce_kernel( - const T* __restrict__ mask, - T* __restrict__ symm_buf, - int num_experts, - int N -) { - int expert_idx = blockIdx.x; - if (expert_idx >= num_experts) return; - - float sum = 0.0f; - for (int i = threadIdx.x; i < N; i += blockDim.x) { - sum += static_cast(mask[expert_idx * N + i]); - } - - static __shared__ float shared[32]; - int lane = threadIdx.x % 32; - int warp = threadIdx.x / 32; - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - - if (lane == 0) { - shared[warp] = sum; - } - __syncthreads(); - - if (warp == 0) { - float warp_sum = (lane < (blockDim.x / 32)) ? shared[lane] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset); - } - // Add a system-level memory fence to ensure visibility across NVLink before barrier - if (lane == 0) { - symm_buf[expert_idx] = static_cast(warp_sum); - __threadfence_system(); - } - } -} - -template -__global__ void gather_postprocess_kernel( - const uint64_t* __restrict__ peer_ptrs, - T* __restrict__ global_tokens_local_expert, - T* __restrict__ global_sum_tokens, - float* __restrict__ input_splits, - float* __restrict__ output_splits, - int ep_size, - int num_experts, - int num_local_experts, - int rank -) { - extern __shared__ char smem[]; - T* s_all_bufs = reinterpret_cast(smem); - - int tid = threadIdx.x; - int total_elements = ep_size * num_experts; - - // Cooperatively load all peers' symmetric buffers into shared memory - for (int idx = tid; idx < total_elements; idx += blockDim.x) { - int r = idx / num_experts; - int e = idx % num_experts; - const T* peer_buf = reinterpret_cast(peer_ptrs[r]); - s_all_bufs[r * num_experts + e] = peer_buf[e]; - } - - __syncthreads(); - - // 1. global_tokens_local_expert - int out_elements = ep_size * num_local_experts; - for (int idx = tid; idx < out_elements; idx += blockDim.x) { - int r = idx / num_local_experts; - int i = idx % num_local_experts; - global_tokens_local_expert[idx] = s_all_bufs[r * num_experts + rank * num_local_experts + i]; - } - - // 2. global_sum_tokens - for (int i = tid; i < num_local_experts; i += blockDim.x) { - float sum = 0.0f; - for (int r = 0; r < ep_size; r++) { - sum += static_cast(s_all_bufs[r * num_experts + rank * num_local_experts + i]); - } - global_sum_tokens[i] = static_cast(sum); - } - - // 3. output_splits - for (int r = tid; r < ep_size; r += blockDim.x) { - float sum = 0.0f; - for (int i = 0; i < num_local_experts; i++) { - sum += static_cast(s_all_bufs[r * num_experts + rank * num_local_experts + i]); - } - output_splits[r] = sum; - } - - // 4. input_splits - for (int r = tid; r < ep_size; r += blockDim.x) { - float sum = 0.0f; - for (int i = 0; i < num_local_experts; i++) { - sum += static_cast(s_all_bufs[rank * num_experts + r * num_local_experts + i]); - } - input_splits[r] = sum; - } -} - -void launch_local_reduce( - torch::Tensor mask, - torch::Tensor symm_buf, - int num_experts, - int N -) { - int threads = 256; - int blocks = num_experts; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mask.scalar_type(), "local_reduce", [&] { - local_reduce_kernel<<>>( - mask.data_ptr(), - symm_buf.data_ptr(), - num_experts, - N - ); - }); -} - -void launch_gather_postprocess( - torch::Tensor peer_ptrs, - torch::Tensor global_tokens_local_expert, - torch::Tensor global_sum_tokens, - torch::Tensor input_splits, - torch::Tensor output_splits, - int ep_size, - int num_experts, - int num_local_experts, - int rank -) { - int threads = 256; - int blocks = 1; - int shared_mem_bytes = ep_size * num_experts * global_tokens_local_expert.element_size(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, global_tokens_local_expert.scalar_type(), "gather_postprocess", [&] { - gather_postprocess_kernel<<>>( - reinterpret_cast(peer_ptrs.data_ptr()), - global_tokens_local_expert.data_ptr(), - global_sum_tokens.data_ptr(), - input_splits.data_ptr(), - output_splits.data_ptr(), - ep_size, - num_experts, - num_local_experts, - rank - ); - }); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_local_reduce", &launch_local_reduce, "Local reduction kernel"); - m.def("launch_gather_postprocess", &launch_gather_postprocess, "Gather and postprocess kernel"); -} -''' - -_ext = None - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("moe_token_preprocess_ext", CUDA_SRC) - return _ext - - -_symm_cache = {} - - -def _get_symm_resources(num_experts: int, dtype: torch.dtype, device: torch.device, group: dist.ProcessGroup): - key = (num_experts, dtype, device, group) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(num_experts, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, ptrs_tensor) - _symm_cache[key] = res - return res - - -def solution( - expert_mask: torch.Tensor, - num_experts: int, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[List[int], List[int], torch.Tensor, torch.Tensor]: - group = group or dist.group.WORLD - ep_size = group.size() - rank = dist.get_rank(group) - num_local_experts = num_experts // ep_size - - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - - ext = _get_ext() - - # Cast boolean/integer masks to float32 to support Tensor core floating point math traits - # if standard MoE implementations feed non-FP inputs. - if not expert_mask.is_floating_point(): - expert_mask = expert_mask.to(torch.float32) - - device = expert_mask.device - dtype = expert_mask.dtype - - buf, hdl, ptrs_tensor = _get_symm_resources(num_experts, dtype, device, group) - - expert_mask_c = expert_mask.contiguous() - N = expert_mask_c.numel() // num_experts - - # Kernel 1: Local reduction directly into symmetric memory - ext.launch_local_reduce(expert_mask_c, buf, num_experts, N) - - # Barrier to ensure peer NVLink visibility - hdl.barrier(channel=0) - - # Output allocation - global_tokens_local_expert = torch.empty((ep_size, num_local_experts), dtype=dtype, device=device) - global_sum_tokens = torch.empty((num_local_experts,), dtype=dtype, device=device) - input_splits = torch.empty((ep_size,), dtype=torch.float32, device=device) - output_splits = torch.empty((ep_size,), dtype=torch.float32, device=device) - - # Kernel 2: Gather sizes from peers' symmetric buffers and compute all outputs - ext.launch_gather_postprocess( - ptrs_tensor, - global_tokens_local_expert, - global_sum_tokens, - input_splits, - output_splits, - ep_size, - num_experts, - num_local_experts, - rank - ) - - # Move outputs to CPU to overlap with host-side logic - input_splits_cpu = input_splits.to(torch.int).tolist() - output_splits_cpu = output_splits.to(torch.int).tolist() - out_tokens = global_tokens_local_expert.to(torch.device("cpu"), non_blocking=True) - out_sum = global_sum_tokens.to(torch.device("cpu"), non_blocking=True) - - return input_splits_cpu, output_splits_cpu, out_tokens, out_sum \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/27_moe_all2all_primitive_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/27_moe_all2all_primitive_cuda.py deleted file mode 100755 index de5d69e..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/27_moe_all2all_primitive_cuda.py +++ /dev/null @@ -1,254 +0,0 @@ -from typing import List, Optional, Union - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__global__ void pull_all_to_all_kernel( - const uint64_t* __restrict__ remote_data_ptrs, - const uint64_t* __restrict__ remote_offset_ptrs, - const int64_t* __restrict__ my_out_offsets, - const int64_t* __restrict__ my_output_splits, - T* __restrict__ out_data, - int64_t hidden_dim, - int world_size, - int rank -) { - int p = blockIdx.y; // peer index - if (p >= world_size) return; - - int64_t pull_rows = my_output_splits[p]; - if (pull_rows == 0) return; - - __shared__ int64_t shared_remote_offset; - if (threadIdx.x == 0) { - // Read remote offset from the peer's offset buffer natively via UVA - const int64_t* peer_offset_buf = reinterpret_cast(remote_offset_ptrs[p]); - shared_remote_offset = peer_offset_buf[rank]; - } - __syncthreads(); - - int64_t remote_offset = shared_remote_offset; - int64_t my_out_offset = my_out_offsets[p]; - - const T* peer_data = reinterpret_cast(remote_data_ptrs[p]); - const T* src = peer_data + remote_offset * hidden_dim; - T* dst = out_data + my_out_offset * hidden_dim; - - int64_t total_elements = pull_rows * hidden_dim; - - // Fast path: 128-bit aligned vectorized loads across NVLink - if (reinterpret_cast(src) % 16 == 0 && - reinterpret_cast(dst) % 16 == 0 && - total_elements % 8 == 0) - { - int64_t total_vecs = total_elements / 8; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - const ulonglong2* src_vec = reinterpret_cast(src); - ulonglong2* dst_vec = reinterpret_cast(dst); - - for (; idx < total_vecs; idx += (int64_t)gridDim.x * blockDim.x) { - dst_vec[idx] = src_vec[idx]; - } - } else { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < total_elements; idx += (int64_t)gridDim.x * blockDim.x) { - dst[idx] = src[idx]; - } - } -} - -void launch_pull_all_to_all( - torch::Tensor remote_data_ptrs_tensor, - torch::Tensor remote_offset_ptrs_tensor, - torch::Tensor my_out_offsets_tensor, - torch::Tensor my_output_splits_tensor, - torch::Tensor out_data, - int64_t hidden_dim, - int world_size, - int rank -) { - const uint64_t* remote_data_ptrs = reinterpret_cast(remote_data_ptrs_tensor.data_ptr()); - const uint64_t* remote_offset_ptrs = reinterpret_cast(remote_offset_ptrs_tensor.data_ptr()); - const int64_t* my_out_offsets = my_out_offsets_tensor.data_ptr(); - const int64_t* my_output_splits = my_output_splits_tensor.data_ptr(); - - int threads = 512; - int blocks_x = 32; // Over-subscribe SMs to hide latency - dim3 grid(blocks_x, world_size); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (out_data.dtype() == torch::kBFloat16) { - __nv_bfloat16* out_ptr = reinterpret_cast<__nv_bfloat16*>(out_data.data_ptr()); - pull_all_to_all_kernel<__nv_bfloat16><<>>( - remote_data_ptrs, remote_offset_ptrs, my_out_offsets, my_output_splits, - out_ptr, hidden_dim, world_size, rank - ); - } else if (out_data.dtype() == torch::kFloat16) { - half* out_ptr = reinterpret_cast(out_data.data_ptr()); - pull_all_to_all_kernel<<>>( - remote_data_ptrs, remote_offset_ptrs, my_out_offsets, my_output_splits, - out_ptr, hidden_dim, world_size, rank - ); - } else if (out_data.dtype() == torch::kFloat32) { - float* out_ptr = out_data.data_ptr(); - pull_all_to_all_kernel<<>>( - remote_data_ptrs, remote_offset_ptrs, my_out_offsets, my_output_splits, - out_ptr, hidden_dim, world_size, rank - ); - } else { - TORCH_CHECK(false, "Unsupported dtype, must be bf16, fp16, or fp32"); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_pull_all_to_all", &launch_pull_all_to_all, "UVA Pull All-to-All kernel"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("moe_all2all_pull_ext", CUDA_SRC) - return _ext - -class SymmMemCache: - def __init__(self): - self.max_tokens = 0 - self.hidden_dim = 0 - self.data_buf = None - self.data_hdl = None - self.offset_buf = None - self.offset_hdl = None - self.data_ptrs = None - self.offset_ptrs = None - -_cache = SymmMemCache() - -def _get_symm_buffers(num_tokens: int, hidden_dim: int, device: torch.device, dtype: torch.dtype, world_size: int, group: dist.ProcessGroup): - global _cache - if _cache.data_buf is None: - # Initial allocation - make it huge to avoid costly runtime reallocations/all_reduce barriers. - local_info = torch.tensor([num_tokens, hidden_dim], dtype=torch.int64, device=device) - dist.all_reduce(local_info, op=dist.ReduceOp.MAX, group=group) - global_tokens = local_info[0].item() - global_hidden = local_info[1].item() - - # 8x upper bound margin to absorb MoE imbalance/load spikes. - # Fallback 262144 ensures tiny initial batches don't undershoot later huge inputs. - new_max = max(global_tokens * 8, 262144) - _cache.max_tokens = new_max - _cache.hidden_dim = global_hidden - - _cache.data_buf = symm_mem.empty((new_max, global_hidden), device=device, dtype=dtype) - _cache.data_hdl = symm_mem.rendezvous(_cache.data_buf, group) - _cache.data_ptrs = torch.tensor(_cache.data_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - _cache.offset_buf = symm_mem.empty((world_size,), device=device, dtype=torch.int64) - _cache.offset_hdl = symm_mem.rendezvous(_cache.offset_buf, group) - _cache.offset_ptrs = torch.tensor(_cache.offset_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - elif num_tokens > _cache.max_tokens or hidden_dim != _cache.hidden_dim: - raise RuntimeError( - f"Dynamic reallocation needed but omitted to avoid deadlock. " - f"num_tokens={num_tokens} exceeds max {_cache.max_tokens} or hidden_dim={hidden_dim} changed." - ) - - return _cache - -@torch.no_grad() -def solution( - local_tensor: torch.Tensor, - input_split_sizes: Optional[Union[List[int], torch.Tensor]] = None, - output_split_sizes: Optional[Union[List[int], torch.Tensor]] = None, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - - if world_size == 1: - return local_tensor.contiguous() - - if _ext is None: - if dist.get_rank(group) == 0: - _get_ext() - dist.barrier(group=group) - - rank = dist.get_rank(group) - device = local_tensor.device - dtype = local_tensor.dtype - local_tensor = local_tensor.contiguous() - num_tokens = local_tensor.size(0) - hidden_dim = local_tensor.size(1) - - # 1. Compute split sizes and block offsets purely on the device - if input_split_sizes is None: - split = num_tokens // world_size - input_splits_dev = torch.full((world_size,), split, dtype=torch.int64, device=device) - elif isinstance(input_split_sizes, list): - input_splits_dev = torch.tensor(input_split_sizes, dtype=torch.int64, device=device) - else: - input_splits_dev = input_split_sizes.to(device) - - input_offsets_dev = torch.zeros(world_size, dtype=torch.int64, device=device) - input_offsets_dev[1:] = torch.cumsum(input_splits_dev[:-1], dim=0) - - if output_split_sizes is None: - split = num_tokens // world_size - output_splits_dev = torch.full((world_size,), split, dtype=torch.int64, device=device) - out_size = num_tokens - elif isinstance(output_split_sizes, list): - output_splits_dev = torch.tensor(output_split_sizes, dtype=torch.int64, device=device) - out_size = sum(output_split_sizes) - else: - output_splits_dev = output_split_sizes.to(device) - out_size = int(output_splits_dev.sum().item()) - - output_offsets_dev = torch.zeros(world_size, dtype=torch.int64, device=device) - output_offsets_dev[1:] = torch.cumsum(output_splits_dev[:-1], dim=0) - - output = torch.empty((out_size, hidden_dim), dtype=dtype, device=device) - - if out_size == 0 and num_tokens == 0: - return output - - # 2. Acquire persistent symmetric layout structures - cache = _get_symm_buffers(num_tokens, hidden_dim, device, dtype, world_size, group) - - # 3. Synchronize stream before populating new data to ensure no ongoing trailing reads from peers - cache.data_hdl.barrier(channel=0) - - # 4. Fill local symmetric views - if num_tokens > 0: - cache.data_buf[:num_tokens].copy_(local_tensor) - cache.offset_buf[:world_size].copy_(input_offsets_dev) - - # 5. Synchronize after populating the arrays to assure readiness globally - cache.data_hdl.barrier(channel=0) - - # 6. Execute highly parallel CUDA pull directly bypassing NCCL - if out_size > 0: - _get_ext().launch_pull_all_to_all( - cache.data_ptrs, - cache.offset_ptrs, - output_offsets_dev, - output_splits_dev, - output, - hidden_dim, - world_size, - rank - ) - - return output \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/28_moe_pre_all2all_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/28_moe_pre_all2all_cuda.py deleted file mode 100755 index c301d0a..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/28_moe_pre_all2all_cuda.py +++ /dev/null @@ -1,294 +0,0 @@ -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -template -__global__ void push_kernel_tokens( - const scalar_t* __restrict__ hidden_states, - const int64_t* __restrict__ sorted_indices, - const int* __restrict__ send_offsets, - const int* __restrict__ send_counts, - const int64_t* __restrict__ peer_meta_bufs, - const int64_t* __restrict__ peer_recv_bufs, - int num_experts, - int num_local_experts, - int hidden_dim, - int my_rank, - int total_tokens -) { - int token_idx = blockIdx.x * blockDim.y + threadIdx.y; - if (token_idx >= total_tokens) return; - - // Binary search send_offsets to find the destination expert for this token - int L = 0, R = num_experts - 1; - int E = 0; - while (L <= R) { - int mid = (L + R) / 2; - if (send_offsets[mid] <= token_idx) { - E = mid; - L = mid + 1; - } else { - R = mid - 1; - } - } - - int r = E / num_local_experts; - int e = E % num_local_experts; - - int src_token = sorted_indices[token_idx]; - - // Read remote target offset from the peer's meta_buf - const int* remote_meta = reinterpret_cast(peer_meta_bufs[r]); - int dest_base = remote_meta[my_rank * num_local_experts + e]; - int dest_token = dest_base + (token_idx - send_offsets[E]); - - const scalar_t* src_row = hidden_states + src_token * hidden_dim; - scalar_t* dest_row = reinterpret_cast(peer_recv_bufs[r]) + dest_token * hidden_dim; - - int tid = threadIdx.x; - int stride = blockDim.x; - - int bytes = hidden_dim * sizeof(scalar_t); - // Use 128-bit, 64-bit, or 32-bit vectorized memory accesses natively - if (bytes % 16 == 0) { - const float4* src_vec = reinterpret_cast(src_row); - float4* dest_vec = reinterpret_cast(dest_row); - int vec_dim = bytes / 16; - for (int i = tid; i < vec_dim; i += stride) { - dest_vec[i] = src_vec[i]; - } - } else if (bytes % 8 == 0) { - const float2* src_vec = reinterpret_cast(src_row); - float2* dest_vec = reinterpret_cast(dest_row); - int vec_dim = bytes / 8; - for (int i = tid; i < vec_dim; i += stride) { - dest_vec[i] = src_vec[i]; - } - } else if (bytes % 4 == 0) { - const float* src_vec = reinterpret_cast(src_row); - float* dest_vec = reinterpret_cast(dest_row); - int vec_dim = bytes / 4; - for (int i = tid; i < vec_dim; i += stride) { - dest_vec[i] = src_vec[i]; - } - } else { - for (int i = tid; i < hidden_dim; i += stride) { - dest_row[i] = src_row[i]; - } - } -} - -void launch_push_tokens( - torch::Tensor hidden_states, - torch::Tensor sorted_indices, - torch::Tensor send_offsets, - torch::Tensor send_counts, - torch::Tensor peer_meta_bufs, - torch::Tensor peer_recv_bufs, - int num_experts, - int num_local_experts, - int hidden_dim, - int my_rank, - int total_tokens -) { - if (total_tokens == 0) return; - - // 2D Block mapping: 32 threads over feature dimension, 8 tokens per block - int threads_x = 32; - int threads_y = 8; - dim3 block(threads_x, threads_y); - int grid = (total_tokens + threads_y - 1) / threads_y; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - auto dtype = hidden_states.scalar_type(); - - if (dtype == torch::kBFloat16) { - push_kernel_tokens<__nv_bfloat16><<>>( - reinterpret_cast(hidden_states.data_ptr()), - sorted_indices.data_ptr(), - send_offsets.data_ptr(), - send_counts.data_ptr(), - peer_meta_bufs.data_ptr(), - peer_recv_bufs.data_ptr(), - num_experts, num_local_experts, hidden_dim, my_rank, total_tokens); - } else if (dtype == torch::kFloat16) { - push_kernel_tokens<<>>( - reinterpret_cast(hidden_states.data_ptr()), - sorted_indices.data_ptr(), - send_offsets.data_ptr(), - send_counts.data_ptr(), - peer_meta_bufs.data_ptr(), - peer_recv_bufs.data_ptr(), - num_experts, num_local_experts, hidden_dim, my_rank, total_tokens); - } else if (dtype == torch::kFloat32) { - push_kernel_tokens<<>>( - reinterpret_cast(hidden_states.data_ptr()), - sorted_indices.data_ptr(), - send_offsets.data_ptr(), - send_counts.data_ptr(), - peer_meta_bufs.data_ptr(), - peer_recv_bufs.data_ptr(), - num_experts, num_local_experts, hidden_dim, my_rank, total_tokens); - } else { - TORCH_CHECK(false, "Unsupported dtype for push_tokens kernel."); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push_tokens", &launch_push_tokens, "PUSH tokens directly to symmetric memory final offsets"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("moe_push_tokens_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(max_tokens, hidden_dim, dtype, device, world_size, num_local_experts): - key = (hidden_dim, dtype, device, world_size, num_local_experts) - if key in _symm_cache: - res = _symm_cache[key] - if res['max_tokens'] >= max_tokens: - return res - - # Re-allocate only if capacities exceed bounds - recv_buf = symm_mem.empty((max_tokens, hidden_dim), dtype=dtype, device=device) - hdl_recv = symm_mem.rendezvous(recv_buf, dist.group.WORLD) - - meta_buf = symm_mem.empty((world_size, num_local_experts), dtype=torch.int32, device=device) - hdl_meta = symm_mem.rendezvous(meta_buf, dist.group.WORLD) - - peer_meta_ptrs = torch.tensor(hdl_meta.buffer_ptrs, dtype=torch.int64, device=device) - peer_recv_ptrs = torch.tensor(hdl_recv.buffer_ptrs, dtype=torch.int64, device=device) - - res = { - 'max_tokens': max_tokens, - 'recv_buf': recv_buf, - 'hdl_recv': hdl_recv, - 'meta_buf': meta_buf, - 'hdl_meta': hdl_meta, - 'peer_meta_ptrs': peer_meta_ptrs, - 'peer_recv_ptrs': peer_recv_ptrs - } - _symm_cache[key] = res - return res - - -def solution( - hidden_states: torch.Tensor, - expert_mask: torch.Tensor, - num_experts: int, - input_splits: Union[List[int], torch.Tensor], - output_splits: Union[List[int], torch.Tensor], - num_global_tokens_per_local_expert: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size]: - - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - my_rank = dist.get_rank(group) - - hidden_dim = hidden_states.size(-1) - hidden_states = hidden_states.reshape(-1, hidden_dim) - if not hidden_states.is_contiguous(): - hidden_states = hidden_states.contiguous() - org_hidden_states_shape = hidden_states.shape - num_tokens = hidden_states.size(0) - device = hidden_states.device - num_local_experts = num_experts // world_size - - routing_map = expert_mask.sum(dim=1) - routing_map_bool = routing_map.bool() - - # --------------------------------------------------------- - # Single GPU Short-circuit (No collective overhead) - # --------------------------------------------------------- - if world_size == 1: - token_indices = torch.arange(num_tokens, device=device).unsqueeze(0).expand(num_experts, -1) - sorted_indices = token_indices.masked_select(routing_map_bool) - local_permuted = hidden_states.index_select(0, sorted_indices) - - expected_tokens = sum(input_splits) if isinstance(input_splits, list) else int(input_splits.sum().item()) - actual_tokens = sorted_indices.size(0) - if expected_tokens != actual_tokens: - raise RuntimeError(f"EP split mismatch: input_splits sum ({expected_tokens}) != permuted tokens ({actual_tokens})") - return local_permuted, routing_map, sorted_indices, org_hidden_states_shape - - # Max bound ensuring enough space even during intense routing load variations - max_tokens = world_size * num_tokens - - buf = _get_symm_state(max_tokens, hidden_dim, hidden_states.dtype, device, world_size, num_local_experts) - recv_buf = buf['recv_buf'] - meta_buf = buf['meta_buf'] - hdl_recv = buf['hdl_recv'] - hdl_meta = buf['hdl_meta'] - - # --------------------------------------------------------- - # 1. Receiver Meta Computation & Symmetric Scatter - # Compute the final destination offsets for peer chunks. - # --------------------------------------------------------- - N = num_global_tokens_per_local_expert - expert_sizes = N.sum(dim=0) - expert_base_offsets = expert_sizes.cumsum(dim=0) - expert_sizes - rank_offsets_within_expert = N.cumsum(dim=0) - N - dest_offsets = (expert_base_offsets.unsqueeze(0) + rank_offsets_within_expert).to(torch.int32) - - meta_buf.copy_(dest_offsets) - hdl_meta.barrier(channel=0) # Sync independent metadata fast track - - # --------------------------------------------------------- - # 2. Sender Local Prep - # Overlapping execution locally while peers expose pointers. - # --------------------------------------------------------- - send_counts = routing_map_bool.sum(dim=1, dtype=torch.int32) - send_offsets = send_counts.cumsum(dim=0) - send_counts - - token_indices = torch.arange(num_tokens, device=device).unsqueeze(0).expand(num_experts, -1) - sorted_indices = token_indices.masked_select(routing_map_bool) - - total_send_tokens = sorted_indices.size(0) - expected_tokens = sum(input_splits) if isinstance(input_splits, list) else int(input_splits.sum().item()) - if expected_tokens != total_send_tokens: - raise RuntimeError(f"EP split mismatch: input_splits sum ({expected_tokens}) != permuted tokens ({total_send_tokens})") - - # --------------------------------------------------------- - # 3. Custom Fused Push Scatter Operator - # --------------------------------------------------------- - _get_ext().launch_push_tokens( - hidden_states, - sorted_indices, - send_offsets, - send_counts, - buf['peer_meta_ptrs'], - buf['peer_recv_ptrs'], - num_experts, - num_local_experts, - hidden_dim, - my_rank, - total_send_tokens - ) - - hdl_recv.barrier(channel=0) - - # --------------------------------------------------------- - # 4. Final Output Construction - # --------------------------------------------------------- - total_recv_tokens = int(N.sum().item()) - global_permuted_hidden_states = recv_buf[:total_recv_tokens].clone() - - return global_permuted_hidden_states, routing_map, sorted_indices, org_hidden_states_shape \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/29_moe_post_all2all_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/29_moe_post_all2all_cuda.py deleted file mode 100755 index 4c416c1..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/29_moe_post_all2all_cuda.py +++ /dev/null @@ -1,262 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import List, Optional, Union -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// Direct PUSH: Reads from unsorted local expert_outputs and writes directly to remote recv_buf. -__global__ void push_kernel_vec( - const __nv_bfloat16* __restrict__ expert_outputs, - const int32_t* __restrict__ meta_info, - const uint64_t* __restrict__ recv_buf_ptrs, - int E, - int hidden_dim -) { - int k = blockIdx.y; - if (k >= E) return; - - int src_offset = meta_info[k * 4 + 0]; - int dest_offset = meta_info[k * 4 + 1]; - int size = meta_info[k * 4 + 2]; - int dest_rank = meta_info[k * 4 + 3]; - - if (size == 0) return; - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; - - if (hidden_dim % 8 == 0) { - int total_vecs = (size * hidden_dim) / 8; - const float4* src = reinterpret_cast(expert_outputs + src_offset * hidden_dim); - float4* dest = reinterpret_cast( - reinterpret_cast<__nv_bfloat16*>(recv_buf_ptrs[dest_rank]) + dest_offset * hidden_dim - ); - for (int i = tid; i < total_vecs; i += stride) { - dest[i] = src[i]; - } - } else { - int total_elements = size * hidden_dim; - const __nv_bfloat16* src = expert_outputs + src_offset * hidden_dim; - __nv_bfloat16* dest = reinterpret_cast<__nv_bfloat16*>(recv_buf_ptrs[dest_rank]) + dest_offset * hidden_dim; - for (int i = tid; i < total_elements; i += stride) { - dest[i] = src[i]; - } - } -} - -// Fuses the elementwise multiply with the routing weight and the atomic scatter_add -__global__ void unpermute_fused_kernel( - const __nv_bfloat16* __restrict__ recv_buf, - const __nv_bfloat16* __restrict__ tokens_weight, - const int64_t* __restrict__ permutation_mapping, - __nv_bfloat16* __restrict__ unpermuted_tokens, - int total_received, - int hidden_dim -) { - int token_idx = blockIdx.x; - if (token_idx >= total_received) return; - - int orig_idx = permutation_mapping[token_idx]; - float weight = __bfloat162float(tokens_weight[token_idx]); - - const __nv_bfloat16* src = recv_buf + token_idx * hidden_dim; - __nv_bfloat16* dst = unpermuted_tokens + orig_idx * hidden_dim; - - for (int d = threadIdx.x; d < hidden_dim; d += blockDim.x) { - float val = __bfloat162float(src[d]) * weight; - #if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) - atomicAdd(dst + d, __float2bfloat16(val)); - #endif - } -} - -void launch_push( - torch::Tensor expert_outputs, - torch::Tensor meta_info, - torch::Tensor recv_buf_ptrs, - int E, - int hidden_dim -) { - const __nv_bfloat16* src = reinterpret_cast(expert_outputs.data_ptr()); - const int32_t* meta = meta_info.data_ptr(); - const uint64_t* ptrs = reinterpret_cast(recv_buf_ptrs.data_ptr()); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - dim3 grid(32, E); - dim3 block(256); - push_kernel_vec<<>>(src, meta, ptrs, E, hidden_dim); -} - -void launch_unpermute( - torch::Tensor recv_buf, - torch::Tensor tokens_weight, - torch::Tensor permutation_mapping, - torch::Tensor unpermuted_tokens, - int total_received, - int hidden_dim -) { - if (total_received == 0) return; - - const __nv_bfloat16* src = reinterpret_cast(recv_buf.data_ptr()); - const __nv_bfloat16* weights = reinterpret_cast(tokens_weight.data_ptr()); - const int64_t* mapping = permutation_mapping.data_ptr(); - __nv_bfloat16* dst = reinterpret_cast<__nv_bfloat16*>(unpermuted_tokens.data_ptr()); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - dim3 grid(total_received); - dim3 block(256); - - unpermute_fused_kernel<<>>( - src, weights, mapping, dst, total_received, hidden_dim - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push", &launch_push, "Push chunks to remote symmetric recv_buf"); - m.def("launch_unpermute", &launch_unpermute, "Fused weight unpermute and scatter_add"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("moe_post_all2all_ext", CUDA_SRC) - return _ext - -_comm_stream = None -def _get_comm_stream(): - global _comm_stream - if _comm_stream is None: - _comm_stream = torch.cuda.Stream() - return _comm_stream - -@torch.no_grad() -def solution( - expert_outputs: torch.Tensor, - routing_weights: torch.Tensor, - selected_experts: torch.Tensor, - num_experts: int, - input_splits: Union[List[int], torch.Tensor], - output_splits: Union[List[int], torch.Tensor], - num_global_tokens_per_local_expert: torch.Tensor, - routing_map: torch.Tensor, - local_input_permutation_mapping: torch.Tensor, - org_hidden_states_shape: torch.Size, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - group = group or dist.group.WORLD - W = dist.get_world_size(group) - rank = dist.get_rank(group) - device = expert_outputs.device - - expert_outputs = expert_outputs.contiguous() - if expert_outputs.dtype != torch.bfloat16: - expert_outputs = expert_outputs.to(torch.bfloat16) - - hidden_dim = expert_outputs.size(1) - - input_splits_list = input_splits.tolist() if isinstance(input_splits, torch.Tensor) else input_splits - out_size = sum(input_splits_list) - - # Fast path for single rank (W=1) - if W == 1: - # Standard unpermute block, ignoring P2P logic - unpermuted_tokens = torch.zeros(org_hidden_states_shape, dtype=torch.bfloat16, device=device) - weights_idx = torch.zeros((routing_weights.size(0), num_experts), dtype=routing_weights.dtype, device=device) - weights_idx.scatter_add_(1, selected_experts, routing_weights) - tokens_weight = weights_idx.T.contiguous().masked_select(routing_map.bool()).to(torch.bfloat16) - - # Sort using Python lists - L = num_experts - split_sizes = num_global_tokens_per_local_expert.T.ravel().tolist() if isinstance(num_global_tokens_per_local_expert, torch.Tensor) else num_global_tokens_per_local_expert - unpermute_order = torch.arange(num_experts).reshape(L, -1).T.ravel().tolist() - - chunks = torch.split(expert_outputs, split_sizes, dim=0) - recv_buf = torch.cat([chunks[i] for i in unpermute_order], dim=0) - - _get_ext().launch_unpermute( - recv_buf, tokens_weight, local_input_permutation_mapping.to(torch.int64), - unpermuted_tokens.view(-1, hidden_dim), out_size, hidden_dim - ) - return unpermuted_tokens - - # --- P2P MULTI-GPU PIPELINE --- - - # 1. Swiftly exchange split counts to precisely know remote destination offsets - output_splits_t = output_splits.to(torch.int32) if isinstance(output_splits, torch.Tensor) else torch.tensor(output_splits, dtype=torch.int32, device=device) - gathered_splits = torch.empty(W * W, dtype=torch.int32, device=device) - dist.all_gather_into_tensor(gathered_splits, output_splits_t, group=group) - gathered_splits = gathered_splits.view(W, W) - - # 2. Build explicit map of where every chunk needs to land remotely (compute safely on CPU) - E = num_experts - L = E // W - split_sizes = num_global_tokens_per_local_expert.T.ravel().tolist() if isinstance(num_global_tokens_per_local_expert, torch.Tensor) else num_global_tokens_per_local_expert - - chunk_src_offsets = [0] * E - curr = 0 - for i in range(E): - chunk_src_offsets[i] = curr - curr += split_sizes[i] - - dest_offsets = gathered_splits[:rank, :].sum(dim=0).tolist() - curr_dest_offsets = dest_offsets.copy() - unpermute_order = torch.arange(E).reshape(L, -1).T.ravel().tolist() - - meta_info_cpu = torch.zeros((E, 4), dtype=torch.int32) - for k in range(E): - dest_rank = k // L - orig_idx = unpermute_order[k] - size = split_sizes[orig_idx] - meta_info_cpu[k, 0] = chunk_src_offsets[orig_idx] - meta_info_cpu[k, 1] = curr_dest_offsets[dest_rank] - meta_info_cpu[k, 2] = size - meta_info_cpu[k, 3] = dest_rank - curr_dest_offsets[dest_rank] += size - - meta_info = meta_info_cpu.to(device, non_blocking=True) - - # 3. Setup Symmetric Memory Buffer - recv_buf = symm_mem.empty((out_size, hidden_dim), dtype=torch.bfloat16, device=device) - hdl = symm_mem.rendezvous(recv_buf, group) - hdl.barrier(channel=0) - - # 4. Overlapped Async Network PUSH - # Reads unordered `expert_outputs` and correctly sorts *during* the NVLink PUSH copy. - comm_stream = _get_comm_stream() - comm_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(comm_stream): - recv_buf_ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - _get_ext().launch_push(expert_outputs, meta_info, recv_buf_ptrs, E, hidden_dim) - - # 5. Hide routing math latency behind the PUSH using Default Stream - num_tokens = routing_weights.size(0) - weights_idx = torch.zeros((num_tokens, num_experts), dtype=routing_weights.dtype, device=device) - weights_idx.scatter_add_(1, selected_experts, routing_weights) - tokens_weight = weights_idx.T.contiguous().masked_select(routing_map.bool()).to(torch.bfloat16) - - # 6. Global P2P Finalization - torch.cuda.current_stream().wait_stream(comm_stream) - hdl.barrier(channel=0) - - # 7. Execute Native Fused Unpermute - unpermuted_tokens = torch.zeros(org_hidden_states_shape, dtype=torch.bfloat16, device=device) - _get_ext().launch_unpermute( - recv_buf, - tokens_weight, - local_input_permutation_mapping.to(torch.int64), - unpermuted_tokens.view(-1, hidden_dim), - out_size, - hidden_dim - ) - - return unpermuted_tokens \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/2_allgather_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/2_allgather_cuda.py deleted file mode 100755 index c31e77c..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/2_allgather_cuda.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -Strategy: -1. **Device-side Peer Pulls (UVA)**: Allocate a symmetric memory buffer per rank to hold its input. A custom CUDA kernel uses vectorized reads (up to 128-bit) to pull directly from remote peers' symmetric buffers over NVLink into the local output tensor, sidestepping NCCL overhead entirely. -2. **Compute–Communication Overlap**: The local chunk copy (`tensor` -> `out[rank]`) is scheduled asynchronously on the stream before the inter-GPU synchronization (`hdl.barrier`), allowing it to overlap with peer pulls and barrier waits. -3. **Dynamic Alignment**: The pull kernel dynamically inspects pointer alignment to utilize 128-bit, 64-bit, or 32-bit memory instructions, ensuring max NVLink bandwidth regardless of the arbitrary input shape. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void allgather_pull_kernel( - const uint64_t* __restrict__ peer_ptrs, - void* __restrict__ out, - int64_t bytes_per_rank, - int world_size, - int my_rank -) { - int rank_to_read = blockIdx.y; - if (rank_to_read == my_rank) return; - - const char* src = reinterpret_cast(static_cast(peer_ptrs[rank_to_read])); - char* dst = reinterpret_cast(out) + rank_to_read * bytes_per_rank; - - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - // Dynamically fallback to smaller vectorized loads if shapes enforce unaligned pointers - if (((uintptr_t)src % 16 == 0) && ((uintptr_t)dst % 16 == 0)) { - int64_t numel_v = bytes_per_rank / 16; - const uint4* src_v = reinterpret_cast(src); - uint4* dst_v = reinterpret_cast(dst); - - for (int64_t i = tid; i < numel_v; i += stride) { - dst_v[i] = src_v[i]; - } - - int64_t rem_start = numel_v * 16; - for (int64_t i = rem_start + tid; i < bytes_per_rank; i += stride) { - dst[i] = src[i]; - } - } else if (((uintptr_t)src % 8 == 0) && ((uintptr_t)dst % 8 == 0)) { - int64_t numel_v = bytes_per_rank / 8; - const uint2* src_v = reinterpret_cast(src); - uint2* dst_v = reinterpret_cast(dst); - - for (int64_t i = tid; i < numel_v; i += stride) { - dst_v[i] = src_v[i]; - } - - int64_t rem_start = numel_v * 8; - for (int64_t i = rem_start + tid; i < bytes_per_rank; i += stride) { - dst[i] = src[i]; - } - } else if (((uintptr_t)src % 4 == 0) && ((uintptr_t)dst % 4 == 0)) { - int64_t numel_v = bytes_per_rank / 4; - const uint32_t* src_v = reinterpret_cast(src); - uint32_t* dst_v = reinterpret_cast(dst); - - for (int64_t i = tid; i < numel_v; i += stride) { - dst_v[i] = src_v[i]; - } - - int64_t rem_start = numel_v * 4; - for (int64_t i = rem_start + tid; i < bytes_per_rank; i += stride) { - dst[i] = src[i]; - } - } else { - // Safe scalar fallback - for (int64_t i = tid; i < bytes_per_rank; i += stride) { - dst[i] = src[i]; - } - } -} - -void launch_allgather( - torch::Tensor peer_ptrs_tensor, - torch::Tensor out, - int64_t bytes_per_rank, - int world_size, - int my_rank -) { - const uint64_t* d_ptrs = reinterpret_cast(peer_ptrs_tensor.data_ptr()); - - int threads = 256; - int64_t numel_v = bytes_per_rank / 16; - int blocks_x = std::min((int64_t)256, (numel_v + threads - 1) / threads); - if (blocks_x <= 0) blocks_x = 1; - - // gridDim.y handles each peer rank's memory pull mapping smoothly - dim3 blocks(blocks_x, world_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - allgather_pull_kernel<<>>( - d_ptrs, out.data_ptr(), bytes_per_rank, world_size, my_rank - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_allgather", &launch_allgather, "UVA Pull AllGather"); -} -''' - -_ext = None -_ext_compiled = False - -def _get_ext(): - global _ext, _ext_compiled - if not _ext_compiled: - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - _ext = compile_cuda_extension("allgather_pull_ext", CUDA_SRC) - if dist.is_initialized(): - dist.barrier() - if rank != 0: - _ext = compile_cuda_extension("allgather_pull_ext", CUDA_SRC) - _ext_compiled = True - return _ext - -_symm_cache = {} - -def _get_symm_state(n: int, dtype: torch.dtype, device: torch.device): - key = (n, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(n, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - res = (buf, hdl, ptrs_tensor) - _symm_cache[key] = res - return res - -@torch.no_grad() -def solution(tensor: torch.Tensor) -> torch.Tensor: - if not dist.is_initialized() or dist.get_world_size() == 1: - return tensor.unsqueeze(0).clone() - - tensor = tensor.contiguous() - n = tensor.numel() - world_size = dist.get_world_size() - rank = dist.get_rank() - - _get_ext() - - out_shape = (world_size,) + tensor.shape - out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) - - if n == 0: - return out - - buf, hdl, ptrs_tensor = _get_symm_state(n, tensor.dtype, tensor.device) - - # Start Device-To-Device copy into the local symmetric buffer - buf.copy_(tensor.view(-1)) - - # Overlap local chunk's placement into output while coordinating peers - out[rank].copy_(tensor) - - # Block local streams until all peers' symmetric buffers are fully visible - hdl.barrier(channel=0) - - # Direct UVA pull of remote chunks via NVLink into the final allocation - bytes_per_rank = n * tensor.element_size() - _get_ext().launch_allgather(ptrs_tensor, out, bytes_per_rank, world_size, rank) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/30_moe_epgroupgemm_lora_backward_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/30_moe_epgroupgemm_lora_backward_cuda.py deleted file mode 100755 index 63c3b27..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/30_moe_epgroupgemm_lora_backward_cuda.py +++ /dev/null @@ -1,516 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Blockwise barrier definitions -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void send_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.relaxed.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.sys.relaxed.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 1u); -} - -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 1u); -} - -__device__ void blockwise_barrier_relaxed( - const uint64_t* __restrict__ signal_pad_ptrs, - uint64_t block_id, - int rank, - int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) { - return; - } - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast( - remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast( - local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_relaxed(send_addr); - wait_signal_relaxed(wait_addr); -} - -__device__ void blockwise_barrier_acq_rel( - const uint64_t* __restrict__ signal_pad_ptrs, - uint64_t block_id, - int rank, - int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) { - return; - } - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast( - remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast( - local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_acq_rel(send_addr); - wait_signal_acq_rel(wait_addr); -} - -// --------------------------------------------------------------------------- -// NVSwitch Multimem ALLREDUCE -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void multimem_ld_reduce_bf16x4( - const uint64_t* addr, - uint32_t& r0, - uint32_t& r1, - uint32_t& r2, - uint32_t& r3 -) { - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "l"(addr) - : "memory"); -} - -__device__ __forceinline__ void multimem_st_bf16x4( - const uint64_t* addr, - uint32_t x, - uint32_t y, - uint32_t z, - uint32_t w -) { - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"(addr), "r"(x), "r"(y), "r"(z), "r"(w) - : "memory"); -} - -__global__ void multimem_allreduce_bf16_kernel( - uint64_t multicast_base, - const uint64_t* __restrict__ signal_pad_ptrs, - int64_t numel_128, - int world_size, - int rank, - int block_stride -) { - const uint64_t block_id = static_cast(blockIdx.x); - blockwise_barrier_relaxed(signal_pad_ptrs, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = - (numel_128 + (int64_t)world_size - 1) / (int64_t)world_size; - - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = (int64_t)block_id * (int64_t)block_stride; - block_start < numel_per_rank; - block_start += (int64_t)num_programs * (int64_t)block_stride) - { - const int64_t offsets = block_start + (int64_t)tid; - if (offsets >= numel_per_rank) { - continue; - } - const int64_t idx = (int64_t)rank * numel_per_rank + offsets; - uint64_t* ptrs = - reinterpret_cast(multicast_base) + idx * 2; - uint32_t x, y, z, w; - multimem_ld_reduce_bf16x4(ptrs, x, y, z, w); - multimem_st_bf16x4(ptrs, x, y, z, w); - } - - __syncthreads(); - blockwise_barrier_acq_rel(signal_pad_ptrs, block_id, rank, world_size); -} - -// --------------------------------------------------------------------------- -// Fallback Peer-Pointer ALLREDUCE -// --------------------------------------------------------------------------- - -template -__global__ void allreduce_sum_kernel( - const long long* __restrict__ ptrs, - T* __restrict__ out, - int world_size, - int64_t n -); - -template<> -__global__ void allreduce_sum_kernel( - const long long* __restrict__ ptrs, - at::BFloat16* __restrict__ out, - int world_size, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const __nv_bfloat16* src = (const __nv_bfloat16*)ptrs[r]; - sum += __bfloat162float(src[idx]); - } - out[idx] = __float2bfloat16(sum); - } -} - -template<> -__global__ void allreduce_sum_kernel( - const long long* __restrict__ ptrs, - float* __restrict__ out, - int world_size, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const float* src = (const float*)ptrs[r]; - sum += src[idx]; - } - out[idx] = sum; - } -} - -// --------------------------------------------------------------------------- -// Pack and Unpack Kernels -// --------------------------------------------------------------------------- - -template -__global__ void pack_3_kernel( - const T* __restrict__ in1, int n1, - const T* __restrict__ in2, int n2, - const T* __restrict__ in3, int n3, - T* __restrict__ out -) { - if (blockIdx.y == 0) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n1) out[idx] = in1[idx]; - } else if (blockIdx.y == 1) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n2) out[n1 + idx] = in2[idx]; - } else if (blockIdx.y == 2) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n3) out[n1 + n2 + idx] = in3[idx]; - } -} - -template -__global__ void unpack_3_kernel( - const T* __restrict__ in, - int n1, T* __restrict__ out1, - int n2, T* __restrict__ out2, - int n3, T* __restrict__ out3 -) { - if (blockIdx.y == 0) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n1) out1[idx] = in[idx]; - } else if (blockIdx.y == 1) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n2) out2[idx] = in[n1 + idx]; - } else if (blockIdx.y == 2) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n3) out3[idx] = in[n1 + n2 + idx]; - } -} - -// --------------------------------------------------------------------------- -// Extension Bindings -// --------------------------------------------------------------------------- - -void launch_multimem_allreduce_bf16( - uint64_t multicast_ptr, - torch::Tensor signal_pad_ptrs_tensor, - int64_t numel_128, - int world_size, - int rank, - int num_blocks, - int block_size, - int block_stride -) { - const uint64_t* d_signal = - reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - multimem_allreduce_bf16_kernel<<>>( - multicast_ptr, - d_signal, - numel_128, - world_size, - rank, - block_stride); -} - -void launch_allreduce( - torch::Tensor ptrs_tensor, - torch::Tensor out, - int64_t n, - int dtype_enum -) { - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - - int threads = 512; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - if (blocks == 0) blocks = 1; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (dtype_enum == 0) { - allreduce_sum_kernel<<>>( - d_ptrs, out.data_ptr(), world_size, n); - } else if (dtype_enum == 1) { - allreduce_sum_kernel<<>>( - d_ptrs, out.data_ptr(), world_size, n); - } -} - -void launch_pack_3( - torch::Tensor t1, torch::Tensor t2, torch::Tensor t3, torch::Tensor out) { - int n1 = t1.numel(); - int n2 = t2.numel(); - int n3 = t3.numel(); - int max_n = std::max(n1, std::max(n2, n3)); - int threads = 256; - int blocks_x = (max_n + threads - 1) / threads; - if (blocks_x == 0) blocks_x = 1; - dim3 blocks(blocks_x, 3); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (t1.dtype() == torch::kBFloat16) { - pack_3_kernel<<>>( - t1.data_ptr(), n1, - t2.data_ptr(), n2, - t3.data_ptr(), n3, - out.data_ptr() - ); - } else if (t1.dtype() == torch::kFloat32) { - pack_3_kernel<<>>( - t1.data_ptr(), n1, - t2.data_ptr(), n2, - t3.data_ptr(), n3, - out.data_ptr() - ); - } -} - -void launch_unpack_3( - torch::Tensor in, torch::Tensor t1, torch::Tensor t2, torch::Tensor t3) { - int n1 = t1.numel(); - int n2 = t2.numel(); - int n3 = t3.numel(); - int max_n = std::max(n1, std::max(n2, n3)); - int threads = 256; - int blocks_x = (max_n + threads - 1) / threads; - if (blocks_x == 0) blocks_x = 1; - dim3 blocks(blocks_x, 3); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (t1.dtype() == torch::kBFloat16) { - unpack_3_kernel<<>>( - in.data_ptr(), - n1, t1.data_ptr(), - n2, t2.data_ptr(), - n3, t3.data_ptr() - ); - } else if (t1.dtype() == torch::kFloat32) { - unpack_3_kernel<<>>( - in.data_ptr(), - n1, t1.data_ptr(), - n2, t2.data_ptr(), - n3, t3.data_ptr() - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_allreduce_bf16", &launch_multimem_allreduce_bf16); - m.def("launch_allreduce", &launch_allreduce); - m.def("launch_pack_3", &launch_pack_3); - m.def("launch_unpack_3", &launch_unpack_3); -} -''' - -_ext = None - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("moe_ep_lora_allreduce", CUDA_SRC) - return _ext - - -WARP_SIZE = 32 -MAX_NUM_BLOCKS = 4 -MAX_BLOCK_SIZE = 1024 -BYTES_PER_THREAD = 16 - - -def _multimem_launch_config(numel: int, world_size: int) -> tuple[int, int, int]: - numel_per_thread = BYTES_PER_THREAD // 2 # bf16 elements per thread - num_threads = (numel // numel_per_thread + world_size - 1) // world_size - - if num_threads < MAX_BLOCK_SIZE: - block_size = 32 # Minimum bounds to prevent deadlock on blockwise barrier subsets - while block_size < num_threads: - block_size *= 2 - num_blocks = 1 - else: - block_size = MAX_BLOCK_SIZE - num_blocks = min( - (num_threads + MAX_BLOCK_SIZE - 1) // MAX_BLOCK_SIZE, - MAX_NUM_BLOCKS, - ) - return num_blocks, block_size, block_size - - -_resource_cache = {} - - -def _get_resources(padded_n: int, dtype: torch.dtype, device: torch.device, group: dist.ProcessGroup): - key = (padded_n, dtype, device, id(group)) - if key in _resource_cache: - return _resource_cache[key] - - # Initialize symmetrically mapping tensors with empty defaults; explicit zero guarantees - # out-of-bounds padded values won't corrupt the eventual reduce result. - buf = symm_mem.empty(padded_n, device=device, dtype=dtype) - buf.zero_() - hdl = symm_mem.rendezvous(buf, group=group) - - out_buf = torch.empty(padded_n, device=device, dtype=dtype) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, ptrs_tensor, out_buf) - _resource_cache[key] = res - return res - - -@torch.no_grad() -def solution( - grad_fc1_1_lora_A: torch.Tensor, - grad_fc1_2_lora_A: torch.Tensor, - grad_fc2_lora_B: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - if world_size <= 1: - return grad_fc1_1_lora_A, grad_fc1_2_lora_A, grad_fc2_lora_B - - n1 = grad_fc1_1_lora_A.numel() - n2 = grad_fc1_2_lora_A.numel() - n3 = grad_fc2_lora_B.numel() - total_n = n1 + n2 + n3 - - if total_n == 0: - return grad_fc1_1_lora_A, grad_fc1_2_lora_A, grad_fc2_lora_B - - dtype = grad_fc1_1_lora_A.dtype - device = grad_fc1_1_lora_A.device - - # Pad payload perfectly onto world_size * 16-byte boundaries so the fallback - # memory instructions and multimem accesses do not fault. - chunk_size = world_size * 8 - padded_n = ((total_n + chunk_size - 1) // chunk_size) * chunk_size - - buf, hdl, ptrs_tensor, out_buf = _get_resources(padded_n, dtype, device, group) - - c1 = grad_fc1_1_lora_A.is_contiguous() - c2 = grad_fc1_2_lora_A.is_contiguous() - c3 = grad_fc2_lora_B.is_contiguous() - - t1 = grad_fc1_1_lora_A if c1 else grad_fc1_1_lora_A.contiguous() - t2 = grad_fc1_2_lora_A if c2 else grad_fc1_2_lora_A.contiguous() - t3 = grad_fc2_lora_B if c3 else grad_fc2_lora_B.contiguous() - - ext = _get_ext() - ext.launch_pack_3(t1, t2, t3, buf) - - rank = dist.get_rank(group) - multicast_ptr = getattr(hdl, 'multicast_ptr', 0) - use_multimem = (multicast_ptr != 0 and dtype == torch.bfloat16) - - if use_multimem: - numel_128 = padded_n // 8 - num_blocks, block_size, block_stride = _multimem_launch_config(padded_n, world_size) - - dist.barrier(group=group) - signal_dev = hdl.signal_pad_ptrs_dev - - ext.launch_multimem_allreduce_bf16( - multicast_ptr, - signal_dev, - numel_128, - world_size, - rank, - num_blocks, - block_size, - block_stride, - ) - unpack_buf = buf - else: - hdl.barrier(channel=0) - dtype_enum = 0 if dtype == torch.bfloat16 else 1 - ext.launch_allreduce(ptrs_tensor, out_buf, total_n, dtype_enum) - hdl.barrier(channel=0) - unpack_buf = out_buf - - ext.launch_unpack_3(unpack_buf, t1, t2, t3) - - if not c1: grad_fc1_1_lora_A.copy_(t1) - if not c2: grad_fc1_2_lora_A.copy_(t2) - if not c3: grad_fc2_lora_B.copy_(t3) - - return grad_fc1_1_lora_A, grad_fc1_2_lora_A, grad_fc2_lora_B \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/31_fused_moe_fwd_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/31_fused_moe_fwd_cuda.py deleted file mode 100755 index 52150c8..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/31_fused_moe_fwd_cuda.py +++ /dev/null @@ -1,327 +0,0 @@ -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__global__ void uva_push_kernel( - const T* __restrict__ src_data, - const int32_t* __restrict__ sorted_indices, - const int32_t* __restrict__ counts, - const int32_t* __restrict__ src_offsets, - const int32_t* __restrict__ dest_ranks, - const int32_t* __restrict__ dest_offsets, - const int64_t* __restrict__ dest_buf_ptrs, - int num_chunks, - int hidden_dim -) { - int chunk_idx = blockIdx.y; - if (chunk_idx >= num_chunks) return; - - int count = counts[chunk_idx]; - if (count == 0) return; - - int src_offset = src_offsets[chunk_idx]; - int dest_rank = dest_ranks[chunk_idx]; - int dest_offset = dest_offsets[chunk_idx]; - - T* dest_buf = reinterpret_cast(dest_buf_ptrs[dest_rank]); - - int total_elements = count * hidden_dim; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; - - for (int i = tid; i < total_elements; i += stride) { - int token_idx = i / hidden_dim; - int dim_idx = i % hidden_dim; - - int actual_src_token = sorted_indices ? sorted_indices[src_offset + token_idx] : (src_offset + token_idx); - dest_buf[(dest_offset + token_idx) * hidden_dim + dim_idx] = src_data[actual_src_token * hidden_dim + dim_idx]; - } -} - -void uva_push( - torch::Tensor src_data, - std::optional sorted_indices, - torch::Tensor counts, - torch::Tensor src_offsets, - torch::Tensor dest_ranks, - torch::Tensor dest_offsets, - torch::Tensor dest_buf_ptrs, - int num_chunks, - int hidden_dim -) { - const int32_t* idxs = sorted_indices.has_value() ? sorted_indices.value().data_ptr() : nullptr; - const int32_t* c = counts.data_ptr(); - const int32_t* s_off = src_offsets.data_ptr(); - const int32_t* d_ranks = dest_ranks.data_ptr(); - const int32_t* d_off = dest_offsets.data_ptr(); - const int64_t* d_ptrs = dest_buf_ptrs.data_ptr(); - - int threads = 256; - int blocks_x = 16; - dim3 blocks(blocks_x, num_chunks); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (src_data.dtype() == torch::kBFloat16) { - uva_push_kernel<__nv_bfloat16><<>>( - reinterpret_cast(src_data.data_ptr()), - idxs, c, s_off, d_ranks, d_off, d_ptrs, num_chunks, hidden_dim - ); - } else if (src_data.dtype() == torch::kFloat32) { - uva_push_kernel<<>>( - src_data.data_ptr(), - idxs, c, s_off, d_ranks, d_off, d_ptrs, num_chunks, hidden_dim - ); - } else { - TORCH_CHECK(false, "Unsupported dtype"); - } -} - -__global__ void gather_N_kernel( - int32_t* N_matrix, - const int64_t* prep_ptrs, - int num_experts, - int world_size -) { - int r = blockIdx.x; - int e = threadIdx.x; - if (r < world_size && e < num_experts) { - int32_t* remote = reinterpret_cast(prep_ptrs[r]); - N_matrix[r * num_experts + e] = remote[e]; - } -} - -void gather_N(torch::Tensor N_matrix, torch::Tensor prep_ptrs, int num_experts, int world_size) { - gather_N_kernel<<>>( - N_matrix.data_ptr(), prep_ptrs.data_ptr(), num_experts, world_size - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("uva_push", &uva_push, "UVA Fused Push"); - m.def("gather_N", &gather_N, "UVA SymmMem AllGather"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("uva_moe_ext", CUDA_SRC) - return _ext - -_moe_symm_cache = None -def _get_buffers(max_tokens: int, hidden_dim: int, world_size: int, num_experts: int, device: torch.device, dtype: torch.dtype): - global _moe_symm_cache - key = (max_tokens, hidden_dim, world_size, num_experts, dtype) - if _moe_symm_cache is not None and _moe_symm_cache.get('key') == key: - return _moe_symm_cache - - prep_buf = symm_mem.empty((num_experts,), dtype=torch.int32, device=device) - prep_hdl = symm_mem.rendezvous(prep_buf, dist.group.WORLD) - - fwd_recv = symm_mem.empty((max_tokens, hidden_dim), dtype=dtype, device=device) - fwd_recv_hdl = symm_mem.rendezvous(fwd_recv, dist.group.WORLD) - - post_recv = symm_mem.empty((max_tokens, hidden_dim), dtype=dtype, device=device) - post_recv_hdl = symm_mem.rendezvous(post_recv, dist.group.WORLD) - - bwd_recv = symm_mem.empty((max_tokens, hidden_dim), dtype=dtype, device=device) - bwd_recv_hdl = symm_mem.rendezvous(bwd_recv, dist.group.WORLD) - - bwd_expert = symm_mem.empty((max_tokens, hidden_dim), dtype=dtype, device=device) - bwd_expert_hdl = symm_mem.rendezvous(bwd_expert, dist.group.WORLD) - - def get_ptrs(hdl): - return torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _moe_symm_cache = { - 'key': key, - 'prep': (prep_buf, prep_hdl, get_ptrs(prep_hdl)), - 'fwd': (fwd_recv, fwd_recv_hdl, get_ptrs(fwd_recv_hdl)), - 'post': (post_recv, post_recv_hdl, get_ptrs(post_recv_hdl)), - 'bwd': (bwd_recv, bwd_recv_hdl, get_ptrs(bwd_recv_hdl)), - 'bwd_exp': (bwd_expert, bwd_expert_hdl, get_ptrs(bwd_expert_hdl)), - } - return _moe_symm_cache - -def compute_routing_tables(N_matrix: torch.Tensor, num_experts: int, rank: int, world_size: int): - E_loc = num_experts // world_size - N = N_matrix.cpu().tolist() - - fwd_counts, fwd_src_offsets, fwd_dest_ranks, fwd_dest_offsets = [], [], [], [] - for e_glob in range(num_experts): - j, e_loc = e_glob // E_loc, e_glob % E_loc - base_offset = sum(N[k][j * E_loc + e] for e in range(e_loc) for k in range(world_size)) - dest_offset = base_offset + sum(N[k][e_glob] for k in range(rank)) - src_offset = sum(N[rank][e] for e in range(e_glob)) - - fwd_counts.append(N[rank][e_glob]) - fwd_src_offsets.append(src_offset) - fwd_dest_ranks.append(j) - fwd_dest_offsets.append(dest_offset) - - inv_counts, inv_src_offsets, inv_dest_ranks, inv_dest_offsets = [], [], [], [] - for e_loc in range(E_loc): - e_glob = rank * E_loc + e_loc - base_offset = sum(N[k][rank * E_loc + e] for e in range(e_loc) for k in range(world_size)) - for r in range(world_size): - src_offset = base_offset + sum(N[k][e_glob] for k in range(r)) - dest_offset = sum(N[r][e] for e in range(e_glob)) - - inv_counts.append(N[r][e_glob]) - inv_src_offsets.append(src_offset) - inv_dest_ranks.append(r) - inv_dest_offsets.append(dest_offset) - - total_recv = sum(N[k][rank * E_loc + e] for e in range(E_loc) for k in range(world_size)) - - return ( - torch.tensor(fwd_counts, dtype=torch.int32, device='cuda'), - torch.tensor(fwd_src_offsets, dtype=torch.int32, device='cuda'), - torch.tensor(fwd_dest_ranks, dtype=torch.int32, device='cuda'), - torch.tensor(fwd_dest_offsets, dtype=torch.int32, device='cuda'), - torch.tensor(inv_counts, dtype=torch.int32, device='cuda'), - torch.tensor(inv_src_offsets, dtype=torch.int32, device='cuda'), - torch.tensor(inv_dest_ranks, dtype=torch.int32, device='cuda'), - torch.tensor(inv_dest_offsets, dtype=torch.int32, device='cuda'), - total_recv - ) - -class PreAll2All(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_states, sorted_indices, fwd_tables, inv_tables, fwd_symm, bwd_symm, total_recv, hidden_dim): - fwd_buf, fwd_hdl, fwd_ptrs = fwd_symm - bwd_buf, bwd_hdl, bwd_ptrs = bwd_symm - - ctx.save_for_backward(sorted_indices) - ctx.inv_tables = inv_tables - ctx.bwd_buf, ctx.bwd_hdl, ctx.bwd_ptrs = bwd_buf, bwd_hdl, bwd_ptrs - ctx.hidden_dim, ctx.num_tokens = hidden_dim, hidden_states.size(0) - - counts, src_offsets, dest_ranks, dest_offsets = fwd_tables - if counts.size(0) > 0: - _get_ext().uva_push(hidden_states, sorted_indices, counts, src_offsets, dest_ranks, dest_offsets, fwd_ptrs, counts.size(0), hidden_dim) - fwd_hdl.barrier(channel=0) - return fwd_buf[:total_recv].clone() - - @staticmethod - def backward(ctx, grad_output): - sorted_indices, = ctx.saved_tensors - counts, src_offsets, dest_ranks, dest_offsets = ctx.inv_tables - if counts.size(0) > 0: - _get_ext().uva_push(grad_output.contiguous(), None, counts, src_offsets, dest_ranks, dest_offsets, ctx.bwd_ptrs, counts.size(0), ctx.hidden_dim) - ctx.bwd_hdl.barrier(channel=0) - - grad_hidden_states = torch.zeros(ctx.num_tokens, ctx.hidden_dim, dtype=grad_output.dtype, device=grad_output.device) - grad_hidden_states.index_put_((sorted_indices,), ctx.bwd_buf[:sorted_indices.size(0)], accumulate=True) - return grad_hidden_states, None, None, None, None, None, None, None - -class PostAll2All(torch.autograd.Function): - @staticmethod - def forward(ctx, expert_outputs, inv_tables, fwd_tables, post_symm, bwd_exp_symm, total_sent, hidden_dim): - post_buf, post_hdl, post_ptrs = post_symm - bwd_exp_buf, bwd_exp_hdl, bwd_exp_ptrs = bwd_exp_symm - - ctx.fwd_tables = fwd_tables - ctx.bwd_exp_buf, ctx.bwd_exp_hdl, ctx.bwd_exp_ptrs = bwd_exp_buf, bwd_exp_hdl, bwd_exp_ptrs - ctx.hidden_dim, ctx.total_recv = hidden_dim, expert_outputs.size(0) - - counts, src_offsets, dest_ranks, dest_offsets = inv_tables - if counts.size(0) > 0: - _get_ext().uva_push(expert_outputs.contiguous(), None, counts, src_offsets, dest_ranks, dest_offsets, post_ptrs, counts.size(0), hidden_dim) - post_hdl.barrier(channel=0) - return post_buf[:total_sent].clone() - - @staticmethod - def backward(ctx, grad_output): - counts, src_offsets, dest_ranks, dest_offsets = ctx.fwd_tables - if counts.size(0) > 0: - _get_ext().uva_push(grad_output.contiguous(), None, counts, src_offsets, dest_ranks, dest_offsets, ctx.bwd_exp_ptrs, counts.size(0), ctx.hidden_dim) - ctx.bwd_exp_hdl.barrier(channel=0) - return ctx.bwd_exp_buf[:ctx.total_recv].clone(), None, None, None, None, None, None - -def expert_forward( - x: torch.Tensor, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, -) -> torch.Tensor: - gate = torch.nn.functional.silu(gate_proj(x)) - up = up_proj(x) - return down_proj(gate * up) - -def solution( - hidden_states: torch.Tensor, - gate_weight: torch.Tensor, - gate_bias: Optional[torch.Tensor], - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - num_experts: int, - top_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - _get_ext() - group = group or dist.group.WORLD - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - device = hidden_states.device - hidden_dim = hidden_states.size(-1) - num_tokens = hidden_states.view(-1, hidden_dim).size(0) - dtype = hidden_states.dtype - - max_tokens = world_size * num_tokens * top_k - symm_cache = _get_buffers(max_tokens, hidden_dim, world_size, num_experts, device, dtype) - - # Router - router_logits = torch.nn.functional.linear(hidden_states.view(-1, hidden_dim), gate_weight, gate_bias) - routing_weights, selected_experts = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1) - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0) - - # Preprocess SymmMem AllGather - prep_buf, prep_hdl, prep_ptrs = symm_cache['prep'] - prep_buf.copy_(expert_mask.sum(dim=(1, 2)).to(torch.int32)) - prep_hdl.barrier(channel=0) - - N_matrix = torch.empty((world_size, num_experts), dtype=torch.int32, device=device) - _get_ext().gather_N(N_matrix, prep_ptrs, num_experts, world_size) - - tables = compute_routing_tables(N_matrix, num_experts, rank, world_size) - fwd_tables, inv_tables, total_recv = tables[0:4], tables[4:8], tables[8] - - # Sorting config - routing_map = expert_mask.sum(dim=1).bool() - sorted_indices = torch.arange(num_tokens, device=device, dtype=torch.int32).unsqueeze(0).expand(num_experts, -1).masked_select(routing_map) - - # UVA Token Pre All2All - recv_buf_fwd = PreAll2All.apply( - hidden_states, sorted_indices, fwd_tables, inv_tables, - symm_cache['fwd'], symm_cache['bwd'], total_recv, hidden_dim - ) - - # Expert execution - expert_outputs = expert_forward(recv_buf_fwd, gate_proj, up_proj, down_proj) - - # UVA Tokens Post All2All - post_recv_buf = PostAll2All.apply( - expert_outputs, inv_tables, fwd_tables, - symm_cache['post'], symm_cache['bwd_exp'], sorted_indices.size(0), hidden_dim - ) - - # Unpermute - tokens_weight = routing_weights.T.contiguous().masked_select(routing_map) - tokens = post_recv_buf * tokens_weight.unsqueeze(-1) - unpermuted_tokens = torch.zeros_like(hidden_states) - unpermuted_tokens.index_put_((sorted_indices,), tokens, accumulate=True) - - return unpermuted_tokens \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/32_fused_moe_fwd_lora_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/32_fused_moe_fwd_lora_cuda.py deleted file mode 100755 index 81210af..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/32_fused_moe_fwd_lora_cuda.py +++ /dev/null @@ -1,379 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import List, Optional, Tuple, Union - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void all_gather_counts_kernel( - const int* __restrict__ local_counts, - const long long* __restrict__ peer_ptrs, - int rank, - int num_experts, - int world_size -) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid < num_experts) { - int val = local_counts[tid]; - for (int p = 0; p < world_size; p++) { - int* peer_counts = (int*)peer_ptrs[p]; - peer_counts[rank * num_experts + tid] = val; - } - } -} - -template -__global__ void dispatch_tokens_kernel( - const scalar_t* __restrict__ hidden_states, - const long long* __restrict__ peer_tokens_ptrs, - const int* __restrict__ owner, - const int* __restrict__ dest_offset, - int hidden_dim, - int top_k -) { - int token_idx = blockIdx.x; - int k_idx = blockIdx.y; - - int target_rank = owner[token_idx * top_k + k_idx]; - int target_offset = dest_offset[token_idx * top_k + k_idx]; - scalar_t* dest_ptr = (scalar_t*)peer_tokens_ptrs[target_rank]; - - for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - dest_ptr[target_offset * hidden_dim + i] = hidden_states[token_idx * hidden_dim + i]; - } -} - -template <> -__global__ void dispatch_tokens_kernel<__nv_bfloat16>( - const __nv_bfloat16* __restrict__ hidden_states, - const long long* __restrict__ peer_tokens_ptrs, - const int* __restrict__ owner, - const int* __restrict__ dest_offset, - int hidden_dim, - int top_k -) { - int token_idx = blockIdx.x; - int k_idx = blockIdx.y; - - int target_rank = owner[token_idx * top_k + k_idx]; - int target_offset = dest_offset[token_idx * top_k + k_idx]; - __nv_bfloat16* dest_ptr = (__nv_bfloat16*)peer_tokens_ptrs[target_rank]; - - for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - dest_ptr[target_offset * hidden_dim + i] = hidden_states[token_idx * hidden_dim + i]; - } -} - -template -__global__ void pull_tokens_kernel( - const long long* __restrict__ peer_expert_out_ptrs, - const int* __restrict__ owner, - const int* __restrict__ dest_offset, - const float* __restrict__ routing_weights, - scalar_t* __restrict__ out, - int hidden_dim, - int top_k -) { - int token_idx = blockIdx.x; - - for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - float accum = 0.0f; - for (int k = 0; k < top_k; k++) { - int target_rank = owner[token_idx * top_k + k]; - int target_offset = dest_offset[token_idx * top_k + k]; - float weight = routing_weights[token_idx * top_k + k]; - - const float* src_ptr = (const float*)peer_expert_out_ptrs[target_rank]; - float val = src_ptr[target_offset * hidden_dim + i]; - accum += val * weight; - } - out[token_idx * hidden_dim + i] = accum; - } -} - -template <> -__global__ void pull_tokens_kernel<__nv_bfloat16>( - const long long* __restrict__ peer_expert_out_ptrs, - const int* __restrict__ owner, - const int* __restrict__ dest_offset, - const float* __restrict__ routing_weights, - __nv_bfloat16* __restrict__ out, - int hidden_dim, - int top_k -) { - int token_idx = blockIdx.x; - - for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - float accum = 0.0f; - for (int k = 0; k < top_k; k++) { - int target_rank = owner[token_idx * top_k + k]; - int target_offset = dest_offset[token_idx * top_k + k]; - float weight = routing_weights[token_idx * top_k + k]; - - const __nv_bfloat16* src_ptr = (const __nv_bfloat16*)peer_expert_out_ptrs[target_rank]; - float val = __bfloat162float(src_ptr[target_offset * hidden_dim + i]); - accum += val * weight; - } - out[token_idx * hidden_dim + i] = __float2bfloat16(accum); - } -} - -void launch_all_gather_counts( - torch::Tensor local_counts, - torch::Tensor peer_ptrs, - int rank, - int num_experts, - int world_size -) { - int threads = std::min(num_experts, 1024); - int blocks = (num_experts + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - all_gather_counts_kernel<<>>( - local_counts.data_ptr(), - (const long long*)peer_ptrs.data_ptr(), - rank, - num_experts, - world_size - ); -} - -void launch_dispatch_tokens( - torch::Tensor hidden_states, - torch::Tensor peer_tokens_ptrs, - torch::Tensor owner, - torch::Tensor dest_offset, - int top_k -) { - int num_tokens = hidden_states.size(0); - int hidden_dim = hidden_states.size(1); - - dim3 grid(num_tokens, top_k); - int threads = std::min(hidden_dim, 1024); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (hidden_states.dtype() == torch::kBFloat16) { - dispatch_tokens_kernel<__nv_bfloat16><<>>( - (__nv_bfloat16*)hidden_states.data_ptr(), - (const long long*)peer_tokens_ptrs.data_ptr(), - owner.data_ptr(), - dest_offset.data_ptr(), - hidden_dim, - top_k - ); - } else { - dispatch_tokens_kernel<<>>( - hidden_states.data_ptr(), - (const long long*)peer_tokens_ptrs.data_ptr(), - owner.data_ptr(), - dest_offset.data_ptr(), - hidden_dim, - top_k - ); - } -} - -void launch_pull_tokens( - torch::Tensor peer_expert_out_ptrs, - torch::Tensor owner, - torch::Tensor dest_offset, - torch::Tensor routing_weights, - torch::Tensor out, - int top_k -) { - int num_tokens = out.size(0); - int hidden_dim = out.size(1); - - int threads = std::min(hidden_dim, 1024); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (out.dtype() == torch::kBFloat16) { - pull_tokens_kernel<__nv_bfloat16><<>>( - (const long long*)peer_expert_out_ptrs.data_ptr(), - owner.data_ptr(), - dest_offset.data_ptr(), - routing_weights.data_ptr(), - (__nv_bfloat16*)out.data_ptr(), - hidden_dim, - top_k - ); - } else { - pull_tokens_kernel<<>>( - (const long long*)peer_expert_out_ptrs.data_ptr(), - owner.data_ptr(), - dest_offset.data_ptr(), - routing_weights.data_ptr(), - out.data_ptr(), - hidden_dim, - top_k - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_all_gather_counts", &launch_all_gather_counts); - m.def("launch_dispatch_tokens", &launch_dispatch_tokens); - m.def("launch_pull_tokens", &launch_pull_tokens); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_moe_lora_uva", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(world_size, num_experts, max_tokens_per_rank, hidden_dim, device, dtype): - key = (world_size, num_experts, max_tokens_per_rank, hidden_dim, device, dtype) - if key in _symm_cache: - return _symm_cache[key] - - counts_buf = symm_mem.empty((world_size, num_experts), dtype=torch.int32, device=device) - counts_hdl = symm_mem.rendezvous(counts_buf, dist.group.WORLD) - counts_ptrs = torch.tensor(counts_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - tokens_buf = symm_mem.empty((max_tokens_per_rank, hidden_dim), dtype=dtype, device=device) - tokens_hdl = symm_mem.rendezvous(tokens_buf, dist.group.WORLD) - tokens_ptrs = torch.tensor(tokens_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - expert_out_buf = symm_mem.empty((max_tokens_per_rank, hidden_dim), dtype=dtype, device=device) - expert_out_hdl = symm_mem.rendezvous(expert_out_buf, dist.group.WORLD) - expert_out_ptrs = torch.tensor(expert_out_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - res = (counts_buf, counts_hdl, counts_ptrs, - tokens_buf, tokens_hdl, tokens_ptrs, - expert_out_buf, expert_out_hdl, expert_out_ptrs) - _symm_cache[key] = res - return res - -def expert_forward_lora( - x: torch.Tensor, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - lora_gate_A: torch.Tensor, - lora_gate_B: torch.Tensor, - lora_up_A: torch.Tensor, - lora_up_B: torch.Tensor, - lora_down_A: torch.Tensor, - lora_down_B: torch.Tensor, -) -> torch.Tensor: - F = torch.nn.functional - xa_g = F.linear(x, lora_gate_A) - gate_x = gate_proj(x) + F.linear(xa_g, lora_gate_B) - gate = F.silu(gate_x) - xa_u = F.linear(x, lora_up_A) - up = up_proj(x) + F.linear(xa_u, lora_up_B) - y = gate * up - xa_d = F.linear(y, lora_down_A) - return down_proj(y) + F.linear(xa_d, lora_down_B) - -@torch.no_grad() -def solution( - hidden_states: torch.Tensor, - gate_weight: torch.Tensor, - gate_bias: Optional[torch.Tensor], - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - lora_gate_A: torch.Tensor, - lora_gate_B: torch.Tensor, - lora_up_A: torch.Tensor, - lora_up_B: torch.Tensor, - lora_down_A: torch.Tensor, - lora_down_B: torch.Tensor, - num_experts: int, - top_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - device = hidden_states.device - dtype = hidden_states.dtype - ext = _get_ext() - - hidden_dim = hidden_states.size(-1) - org_shape = hidden_states.shape - hidden_states_flat = hidden_states.reshape(-1, hidden_dim).contiguous() - num_tokens = hidden_states_flat.size(0) - - num_local_experts = num_experts // world_size - max_tokens_per_rank = max(num_tokens * top_k * world_size, 65536) - - # Initialize P2P routing symmetric allocations - (counts_buf, counts_hdl, counts_ptrs, - tokens_buf, tokens_hdl, tokens_ptrs, - expert_out_buf, expert_out_hdl, expert_out_ptrs) = _get_symm_state( - world_size, num_experts, max_tokens_per_rank, hidden_dim, device, dtype - ) - - # 1. Routing - router_logits = torch.nn.functional.linear(hidden_states_flat, gate_weight, gate_bias) - routing_weights, selected_experts = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1) - - # 2. Histogram & Global Scatter (Custom AllGather) - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).sum(dim=1) - local_counts = expert_mask.sum(dim=0).to(torch.int32) - local_idx_matrix = expert_mask.cumsum(dim=0, dtype=torch.int32) - 1 - - counts_hdl.barrier(channel=0) - ext.launch_all_gather_counts(local_counts, counts_ptrs, rank, num_experts, world_size) - counts_hdl.barrier(channel=0) - - # 3. Offsets Setup - my_experts_counts = counts_buf[:, rank * num_local_experts : (rank + 1) * num_local_experts] - total_tokens_per_my_expert = my_experts_counts.sum(dim=0) - total_my_tokens = total_tokens_per_my_expert.sum().item() - - reshaped_total = counts_buf.sum(dim=0).view(world_size, num_local_experts) - expert_base_global = torch.cat([ - torch.zeros((world_size, 1), dtype=torch.int32, device=device), - reshaped_total[:, :-1].cumsum(dim=1) - ], dim=1) - - sender_offset = torch.cat([ - torch.zeros((1, num_experts), dtype=torch.int32, device=device), - counts_buf[:-1, :].cumsum(dim=0) - ], dim=0) - - owner = (selected_experts // num_local_experts).to(torch.int32) - le = (selected_experts % num_local_experts).to(torch.int32) - - expert_base_for_selected = expert_base_global[owner, le] - sender_offset_for_selected = sender_offset[rank, selected_experts] - local_idx_for_selected = local_idx_matrix.gather(1, selected_experts) - dest_offset = expert_base_for_selected + sender_offset_for_selected + local_idx_for_selected - - # 4. UAV P2P Direct Dispatch - tokens_hdl.barrier(channel=0) - ext.launch_dispatch_tokens(hidden_states_flat, tokens_ptrs, owner, dest_offset, top_k) - tokens_hdl.barrier(channel=0) - - # 5. Shared Expert Computations (Fused LoRA block) - expert_out_hdl.barrier(channel=0) - if total_my_tokens > 0: - expert_out = expert_forward_lora( - tokens_buf[:total_my_tokens], - gate_proj, up_proj, down_proj, - lora_gate_A, lora_gate_B, lora_up_A, lora_up_B, lora_down_A, lora_down_B - ) - expert_out_buf[:total_my_tokens].copy_(expert_out) - expert_out_hdl.barrier(channel=0) - - # 6. P2P Direct Pull & Accumulate Result - out_flat = torch.empty_like(hidden_states_flat) - ext.launch_pull_tokens(expert_out_ptrs, owner, dest_offset, routing_weights.float().contiguous(), out_flat, top_k) - expert_out_hdl.barrier(channel=0) - - return out_flat.reshape(org_shape) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/34_ulysses_all_to_all_tensor_primitive_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/34_ulysses_all_to_all_tensor_primitive_cuda.py deleted file mode 100755 index 9ef7b5e..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/34_ulysses_all_to_all_tensor_primitive_cuda.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Optimized all_to_all_tensor for sequence parallelism (Ulysses). - -Strategy: -- Device-side Communication: Uses `torch.distributed._symmetric_memory` to allocate - persistent symmetric input buffers on each rank. -- Compute-Communication Fusion: Instead of allocating intermediate lists of chunk tensors - and launching multi-step NCCL collectives, we launch a single custom CUDA P2P kernel. -- Pull-based P2P over NVLink: The kernel allows each rank to read its required chunks - directly from the symmetric input buffers of all peers. -- Fast Indexing: Multidimensional tensor coordinates are collapsed on the host into a - minimal set of outer loops, leaving the largest possible contiguous innermost dimension - mapped directly to thread blocks for perfectly coalesced memory accesses. -""" - -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -#define MAX_WORLD_SIZE 32 - -template -struct PeerPtrs { - const scalar_t* ptrs[MAX_WORLD_SIZE]; -}; - -struct ShapeStrides { - int64_t shape[8]; - int64_t stride_in[8]; - int64_t stride_out[8]; -}; - -template -__global__ void all_to_all_pull_kernel( - PeerPtrs peers, - scalar_t* __restrict__ out_ptr, - int rank, - int world_size, - int64_t numel_chunk, - int64_t inner_size, - ShapeStrides ss, - int64_t orig_stride_in_scatter, - int64_t orig_stride_out_gather, - int64_t c_sc, - int64_t S_ga -) { - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int p = blockIdx.y; - - if (tid < numel_chunk) { - int64_t outer_idx = tid / inner_size; - int64_t inner_idx = tid % inner_size; - - int64_t temp = outer_idx; - int64_t offset_in = rank * c_sc * orig_stride_in_scatter + inner_idx; - int64_t offset_out = p * S_ga * orig_stride_out_gather + inner_idx; - - #pragma unroll - for (int d = 7; d >= 0; --d) { - int64_t size = ss.shape[d]; - if (size > 1) { - int64_t coord = temp % size; - temp = temp / size; - offset_in += coord * ss.stride_in[d]; - offset_out += coord * ss.stride_out[d]; - } - } - - out_ptr[offset_out] = peers.ptrs[p][offset_in]; - } -} - -void launch_all_to_all_pull( - std::vector peer_ptrs_vec, - torch::Tensor out_tensor, - int rank, - int world_size, - int64_t numel_chunk, - int64_t inner_size, - std::vector outer_shape, - std::vector outer_stride_in, - std::vector outer_stride_out, - int64_t orig_stride_in_scatter, - int64_t orig_stride_out_gather, - int64_t c_sc, - int64_t S_ga -) { - TORCH_CHECK(world_size <= MAX_WORLD_SIZE, "world_size exceeds MAX_WORLD_SIZE"); - - ShapeStrides ss; - for (int i = 0; i < 8; ++i) { - ss.shape[i] = outer_shape[i]; - ss.stride_in[i] = outer_stride_in[i]; - ss.stride_out[i] = outer_stride_out[i]; - } - - int threads = 256; - int blocks_x = (numel_chunk + threads - 1) / threads; - dim3 blocks(blocks_x, world_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (out_tensor.dtype() == torch::kBFloat16) { - PeerPtrs<__nv_bfloat16> peers; - for (int i = 0; i < world_size; ++i) { - peers.ptrs[i] = reinterpret_cast(peer_ptrs_vec[i]); - } - __nv_bfloat16* out_ptr = reinterpret_cast<__nv_bfloat16*>(out_tensor.data_ptr()); - all_to_all_pull_kernel<__nv_bfloat16><<>>( - peers, out_ptr, rank, world_size, numel_chunk, inner_size, ss, - orig_stride_in_scatter, orig_stride_out_gather, c_sc, S_ga - ); - } else if (out_tensor.dtype() == torch::kFloat32) { - PeerPtrs peers; - for (int i = 0; i < world_size; ++i) { - peers.ptrs[i] = reinterpret_cast(peer_ptrs_vec[i]); - } - float* out_ptr = reinterpret_cast(out_tensor.data_ptr()); - all_to_all_pull_kernel<<>>( - peers, out_ptr, rank, world_size, numel_chunk, inner_size, ss, - orig_stride_in_scatter, orig_stride_out_gather, c_sc, S_ga - ); - } else { - TORCH_CHECK(false, "Unsupported dtype: only bfloat16 and float32 are supported."); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_all_to_all_pull", &launch_all_to_all_pull, "Ulysses all-to-all pull kernel"); -} -''' - -_ext = None -_compiled = False - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ulysses_all_to_all_pull", CUDA_SRC) - return _ext - - -def _ensure_compiled(group: dist.ProcessGroup): - global _compiled - if not _compiled: - rank = dist.get_rank(group) - if rank == 0: - _get_ext() - dist.barrier(group) - if rank != 0: - _get_ext() - _compiled = True - - -_symm_cache = {} - - -def _get_symm_state(shape_in: tuple, dtype: torch.dtype, device: torch.device, group: dist.ProcessGroup): - key = (shape_in, dtype, device, id(group)) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(shape_in, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - peer_ptrs = [int(p) for p in hdl.buffer_ptrs] - - _symm_cache[key] = (buf, hdl, peer_ptrs) - return buf, hdl, peer_ptrs - - -_arg_cache = {} - - -def _get_kernel_args(shape_in: tuple, scatter_dim: int, gather_dim: int, world_size: int): - key = (shape_in, scatter_dim, gather_dim, world_size) - if key in _arg_cache: - return _arg_cache[key] - - shape_out = list(shape_in) - shape_out[scatter_dim] = shape_in[scatter_dim] // world_size - shape_out[gather_dim] = shape_in[gather_dim] * world_size - - def get_strides(shape): - strides = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - strides[i] = strides[i+1] * shape[i+1] - return strides - - stride_in = get_strides(shape_in) - stride_out = get_strides(shape_out) - - chunk_shape = list(shape_in) - chunk_shape[scatter_dim] = shape_in[scatter_dim] // world_size - - new_chunk = [chunk_shape[-1]] - new_stride_in = [stride_in[-1]] - new_stride_out = [stride_out[-1]] - - for d in range(len(chunk_shape)-2, -1, -1): - # Collapse contiguous dimensions avoiding div/mod overhead in kernel - if (stride_in[d] == new_stride_in[0] * new_chunk[0] and - stride_out[d] == new_stride_out[0] * new_chunk[0]): - new_chunk[0] *= chunk_shape[d] - else: - new_chunk.insert(0, chunk_shape[d]) - new_stride_in.insert(0, stride_in[d]) - new_stride_out.insert(0, stride_out[d]) - - inner_size = new_chunk[-1] - - outer_shape = new_chunk[:-1] - outer_stride_in = new_stride_in[:-1] - outer_stride_out = new_stride_out[:-1] - - if len(outer_shape) > 8: - raise ValueError("Too many tensor dimensions after collapsing.") - - while len(outer_shape) < 8: - outer_shape.insert(0, 1) - outer_stride_in.insert(0, 0) - outer_stride_out.insert(0, 0) - - numel_chunk = 1 - for s in chunk_shape: - numel_chunk *= s - - orig_stride_in_scatter = stride_in[scatter_dim] - orig_stride_out_gather = stride_out[gather_dim] - c_sc = chunk_shape[scatter_dim] - S_ga = chunk_shape[gather_dim] - - res = ( - int(numel_chunk), int(inner_size), - [int(x) for x in outer_shape], - [int(x) for x in outer_stride_in], - [int(x) for x in outer_stride_out], - int(orig_stride_in_scatter), int(orig_stride_out_gather), int(c_sc), int(S_ga) - ) - _arg_cache[key] = res - return res - - -@torch.no_grad() -def solution( - x: torch.Tensor, - scatter_dim: int, - gather_dim: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - if world_size == 1: - return x.contiguous() - - x = x.contiguous() - - shape_out = list(x.shape) - shape_out[scatter_dim] = x.shape[scatter_dim] // world_size - shape_out[gather_dim] = x.shape[gather_dim] * world_size - - if x.numel() == 0: - return torch.empty(shape_out, dtype=x.dtype, device=x.device) - - _ensure_compiled(group) - rank = dist.get_rank(group) - - buf, hdl, peer_ptrs = _get_symm_state( - tuple(x.shape), x.dtype, x.device, group - ) - - args = _get_kernel_args( - tuple(x.shape), scatter_dim, gather_dim, world_size - ) - - out_tensor = torch.empty(shape_out, dtype=x.dtype, device=x.device) - - # Wait for peers to finish reading from the symmetric buffer of the previous iteration - hdl.barrier(channel=0) - - # Push local chunk to the symmetric buffer for peers to read - buf.copy_(x) - - # Wait for peers to finish writing to their symmetric buffers - hdl.barrier(channel=1) - - # Launch direct fused P2P pulling operations - _get_ext().launch_all_to_all_pull( - peer_ptrs, - out_tensor, - rank, - world_size, - *args - ) - - return out_tensor \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/35_ulysses_all_gather_into_tensor_primitive_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/35_ulysses_all_gather_into_tensor_primitive_cuda.py deleted file mode 100755 index 04d3b72..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/35_ulysses_all_gather_into_tensor_primitive_cuda.py +++ /dev/null @@ -1,259 +0,0 @@ -# Strategy: -# 1. Device-side communication: Instead of opaque NCCL rings, we allocate symmetric -# memory for the output tensor and use Hopper's NVLink hardware multicast -# (`multimem.st`) to PUSH each rank's local shard directly into the correct slice -# of all W peers' output buffers simultaneously via a single custom kernel. -# 2. Optimal bandwidth: The 1-to-W hardware broadcast reduces memory reads on the -# sender by Wx compared to standard P2P push, perfectly saturating NVLink. -# 3. Double-buffering & Overlap: We rotate through a pool of symmetric memory buffers. -# This structure safely hides stream synchronization latency, requires only one -# device-side barrier per invocation, and allows overlap with unrelated downstream ops. - -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Hardware Multicast Push (Hopper multimem.st) -// Broadcasts 16-byte chunks (e.g. 8x bfloat16) to all ranks simultaneously. -// --------------------------------------------------------------------------- -__global__ void multimem_push_16B( - const uint4* __restrict__ local_x, - uint64_t multicast_ptr, - int64_t numel_16b, - int64_t offset_16b -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < numel_16b; idx += gridDim.x * blockDim.x) { - uint4 val = local_x[idx]; - uint64_t dst = multicast_ptr + (offset_16b + idx) * 16; - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - :: "l"(dst), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) - : "memory" - ); - } -} - -// --------------------------------------------------------------------------- -// Fallback: Standard P2P Push (For older architectures or missing NVSwitch) -// --------------------------------------------------------------------------- -__global__ void p2p_push_16B_kernel( - const uint4* __restrict__ local_x, - const uint64_t* __restrict__ dst_ptrs, - int world_size, - int64_t numel_16b, - int64_t offset_16b -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < numel_16b; idx += gridDim.x * blockDim.x) { - uint4 val = local_x[idx]; - #pragma unroll 8 - for (int r = 0; r < world_size; ++r) { - uint4* dst = (uint4*)(dst_ptrs[r] + (offset_16b + idx) * 16); - *dst = val; - } - } -} - -// --------------------------------------------------------------------------- -// Unaligned Fallback (2-byte chunks for bfloat16 / float16 elements) -// --------------------------------------------------------------------------- -__global__ void p2p_push_2B_kernel( - const uint16_t* __restrict__ local_x, - const uint64_t* __restrict__ dst_ptrs, - int world_size, - int64_t numel_2b, - int64_t offset_2b -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < numel_2b; idx += gridDim.x * blockDim.x) { - uint16_t val = local_x[idx]; - #pragma unroll 8 - for (int r = 0; r < world_size; ++r) { - uint16_t* dst = (uint16_t*)(dst_ptrs[r] + (offset_2b + idx) * 2); - *dst = val; - } - } -} - -void launch_multimem_push( - torch::Tensor local_x, - uint64_t multicast_ptr, - torch::Tensor ptrs_tensor, - int world_size, - int rank, - int64_t numel_bytes -) { - if (numel_bytes == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - - // Ensure memory pointer and byte sizes are fully 16-byte aligned. - bool is_16b_aligned = (reinterpret_cast(local_x.data_ptr()) % 16 == 0) && - (numel_bytes % 16 == 0); - - if (is_16b_aligned) { - int64_t numel_16b = numel_bytes / 16; - int64_t offset_16b = (rank * numel_bytes) / 16; - int64_t max_blocks = 65535; - int64_t blocks = std::min(max_blocks, (numel_16b + threads - 1) / threads); - multimem_push_16B<<>>( - (const uint4*)local_x.data_ptr(), - multicast_ptr, - numel_16b, - offset_16b - ); - } else { - int64_t numel_2b = numel_bytes / 2; - int64_t offset_2b = (rank * numel_bytes) / 2; - int64_t max_blocks = 65535; - int64_t blocks = std::min(max_blocks, (numel_2b + threads - 1) / threads); - p2p_push_2B_kernel<<>>( - (const uint16_t*)local_x.data_ptr(), - (const uint64_t*)ptrs_tensor.data_ptr(), - world_size, - numel_2b, - offset_2b - ); - } -} - -void launch_p2p_push( - torch::Tensor local_x, - torch::Tensor ptrs_tensor, - int world_size, - int rank, - int64_t numel_bytes -) { - if (numel_bytes == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - - bool is_16b_aligned = (reinterpret_cast(local_x.data_ptr()) % 16 == 0) && - (numel_bytes % 16 == 0); - - if (is_16b_aligned) { - int64_t numel_16b = numel_bytes / 16; - int64_t offset_16b = (rank * numel_bytes) / 16; - int64_t max_blocks = 65535; - int64_t blocks = std::min(max_blocks, (numel_16b + threads - 1) / threads); - p2p_push_16B_kernel<<>>( - (const uint4*)local_x.data_ptr(), - (const uint64_t*)ptrs_tensor.data_ptr(), - world_size, - numel_16b, - offset_16b - ); - } else { - int64_t numel_2b = numel_bytes / 2; - int64_t offset_2b = (rank * numel_bytes) / 2; - int64_t max_blocks = 65535; - int64_t blocks = std::min(max_blocks, (numel_2b + threads - 1) / threads); - p2p_push_2B_kernel<<>>( - (const uint16_t*)local_x.data_ptr(), - (const uint64_t*)ptrs_tensor.data_ptr(), - world_size, - numel_2b, - offset_2b - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_push", &launch_multimem_push, "Multicast push onto symmetric pointers"); - m.def("launch_p2p_push", &launch_p2p_push, "Peer-to-peer push fallback"); -} -''' - -_ext = None -_ext_compiled = False - - -def _get_ext_safe(rank: int, group: dist.ProcessGroup): - global _ext_compiled, _ext - if not _ext_compiled: - # Protect compilation lock races across ranks by staggering the entry - if rank == 0: - _ext = compile_cuda_extension("ulysses_allgather_ext", CUDA_SRC) - dist.barrier(group) - if rank != 0: - _ext = compile_cuda_extension("ulysses_allgather_ext", CUDA_SRC) - _ext_compiled = True - return _ext - - -_symm_cache = {} - - -def _get_symm_output(shape, dtype, device, group): - """ - Allocates and caches a pool of 2 symmetric memory buffers for double buffering. - This safely guarantees isolation between iterations in the pipeline schedule. - """ - key = (tuple(shape), dtype, device) - if key not in _symm_cache: - pool = [] - for _ in range(2): - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - pool.append((buf, hdl, ptrs)) - _symm_cache[key] = { - "pool": pool, - "idx": 0 - } - - cache = _symm_cache[key] - idx = cache["idx"] - buf, hdl, ptrs = cache["pool"][idx] - # Advance to the next buffer for the following invocation - cache["idx"] = (idx + 1) % 2 - return buf, hdl, ptrs - - -@torch.no_grad() -def solution( - x: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - if world_size == 1: - return x.contiguous() - - x = x.contiguous() - rank = dist.get_rank(group) - - dim_size = list(x.size()) - # Expand 0-th dimension to hold the entire gathered size - dim_size[0] = dim_size[0] * world_size - - ext = _get_ext_safe(rank, group) - buf, hdl, ptrs_tensor = _get_symm_output(dim_size, x.dtype, x.device, group) - - numel_bytes = x.numel() * x.element_size() - multicast_ptr = int(hdl.multicast_ptr) if hasattr(hdl, 'multicast_ptr') else 0 - - # Push local shard directly to the appropriate slice in all ranks' symmetric outputs. - if multicast_ptr != 0: - ext.launch_multimem_push(x, multicast_ptr, ptrs_tensor, world_size, rank, numel_bytes) - else: - ext.launch_p2p_push(x, ptrs_tensor, world_size, rank, numel_bytes) - - # Queue a device-side stream barrier enforcing complete delivery visibility before clone - hdl.barrier(channel=0) - - # Return an independent instance isolated from future rotations of symmetric buffer - return buf.clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/36_ulysses_all_gather_variable_primitive_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/36_ulysses_all_gather_variable_primitive_cuda.py deleted file mode 100755 index 4b40de4..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/36_ulysses_all_gather_variable_primitive_cuda.py +++ /dev/null @@ -1,312 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -#define MAX_WS 32 - -struct KernelArgs { - int64_t L_array[MAX_WS]; - int64_t dst_offset[MAX_WS]; - int64_t total_prefix[MAX_WS]; -}; - -template -__global__ void ulysses_allgather_kernel( - const int64_t* __restrict__ data_ptrs, - KernelArgs args, - T* __restrict__ out, - int64_t sum_BC, - int64_t total_elements, - int world_size -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (int64_t i = idx; i < total_elements; i += (int64_t)gridDim.x * blockDim.x) { - int j = 0; - // Small loop to resolve the target rank. total_prefix tracks the flattened boundaries. - while (j < world_size - 1 && i >= args.total_prefix[j + 1]) { - j++; - } - - int64_t local_i = i - args.total_prefix[j]; - int64_t L_j = args.L_array[j]; - - int64_t a = local_i / L_j; - int64_t k = local_i % L_j; - - int64_t src_idx = a * L_j + k; - int64_t dst_idx = a * sum_BC + args.dst_offset[j] + k; - - const T* src = reinterpret_cast(data_ptrs[j]); - out[dst_idx] = src[src_idx]; - } -} - -__global__ void gather_shapes_kernel( - const int64_t* __restrict__ shape_ptrs, - int64_t* __restrict__ gathered_shapes, - int world_size -) { - int rank = blockIdx.x; - int idx = threadIdx.x; - if (rank < world_size && idx < 32) { - const int64_t* src = reinterpret_cast(shape_ptrs[rank]); - gathered_shapes[rank * 32 + idx] = src[idx]; - } -} - -void launch_gather_shapes( - torch::Tensor shape_ptrs, - torch::Tensor gathered_shapes, - int world_size -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - gather_shapes_kernel<<>>( - shape_ptrs.data_ptr(), - gathered_shapes.data_ptr(), - world_size - ); -} - -void launch_ulysses_allgather( - torch::Tensor data_ptrs, - std::vector L_array, - std::vector dst_offset, - std::vector total_prefix, - torch::Tensor out, - int64_t sum_BC, - int64_t total_elements, - int world_size, - int vector_bytes -) { - KernelArgs args; - for (int i = 0; i < world_size; ++i) { - args.L_array[i] = L_array[i]; - args.dst_offset[i] = dst_offset[i]; - args.total_prefix[i] = total_prefix[i]; - } - - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const int64_t* d_ptrs = data_ptrs.data_ptr(); - - // Dynamically dispatch strictly on maximum achievable alignment bandwidth - if (vector_bytes == 16) { - ulysses_allgather_kernel<<>>( - d_ptrs, args, reinterpret_cast(out.data_ptr()), sum_BC, total_elements, world_size); - } else if (vector_bytes == 8) { - ulysses_allgather_kernel<<>>( - d_ptrs, args, reinterpret_cast(out.data_ptr()), sum_BC, total_elements, world_size); - } else if (vector_bytes == 4) { - ulysses_allgather_kernel<<>>( - d_ptrs, args, reinterpret_cast(out.data_ptr()), sum_BC, total_elements, world_size); - } else if (vector_bytes == 2) { - ulysses_allgather_kernel<<>>( - d_ptrs, args, reinterpret_cast(out.data_ptr()), sum_BC, total_elements, world_size); - } else { - ulysses_allgather_kernel<<>>( - d_ptrs, args, reinterpret_cast(out.data_ptr()), sum_BC, total_elements, world_size); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_gather_shapes", &launch_gather_shapes, "Gather shape info via UVA"); - m.def("launch_ulysses_allgather", &launch_ulysses_allgather, "Ulysses variable allgather custom kernel"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ulysses_allgather_var_ext", CUDA_SRC) - return _ext - - -class SymmCache: - def __init__(self, world_size: int, device: torch.device, dtype: torch.dtype, group: dist.ProcessGroup): - self.world_size = world_size - self.device = device - self.dtype = dtype - self.group = group - - # 32-element buffer allows exchanging up to ~30D tensor shapes. - self.shape_buf = symm_mem.empty(32, dtype=torch.int64, device=device) - self.shape_hdl = symm_mem.rendezvous(self.shape_buf, group) - self.shape_ptrs_dev = torch.tensor(self.shape_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - self.gathered_shapes_dev = torch.empty((world_size, 32), dtype=torch.int64, device=device) - self.gathered_shapes_host = torch.empty((world_size, 32), dtype=torch.int64, pin_memory=True) - self.local_shape_host = torch.empty(32, dtype=torch.int64, pin_memory=True) - - # 1024 elements default fallback size; lazily expands when size spikes. - self.data_capacities = [1024] * world_size - self.data_buf = symm_mem.empty(1024, dtype=dtype, device=device) - self.data_hdl = symm_mem.rendezvous(self.data_buf, group) - self.data_ptrs_dev = torch.tensor(self.data_hdl.buffer_ptrs, dtype=torch.int64, device=device) - -_cache_dict = {} - -def _get_cache(group: dist.ProcessGroup, device: torch.device, dtype: torch.dtype) -> SymmCache: - key = (group, dtype) - if key not in _cache_dict: - _cache_dict[key] = SymmCache(dist.get_world_size(group), device, dtype, group) - return _cache_dict[key] - - -@torch.no_grad() -def solution( - x: torch.Tensor, - gather_dim: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - if world_size == 1: - return x.contiguous() - - device = x.device - dtype = x.dtype - x = x.contiguous() - x_dim = x.dim() - - # Ensure correct boundary for reverse-indexing formats - gather_dim = gather_dim % x_dim - assert x_dim <= 30, "Tensor dimensions exceed shape buffer capacity limits" - - cache = _get_cache(group, device, dtype) - rank = dist.get_rank(group) - - # 1. Start shape exchange without synchronizing host - cache.local_shape_host[0] = x_dim - for i, s in enumerate(x.shape): - cache.local_shape_host[i+1] = s - cache.local_shape_host[31] = x.numel() - - cache.shape_buf.copy_(cache.local_shape_host, non_blocking=True) - - # 2. Overlap payload copy natively via async queue while shape barrier resolves - optimistic_copy_done = False - if x.numel() <= cache.data_buf.numel(): - cache.data_buf[:x.numel()].copy_(x.view(-1), non_blocking=True) - optimistic_copy_done = True - - cache.shape_hdl.barrier(channel=0) - - # 3. Harvest device configurations over peer UVA - ext = _get_ext() - ext.launch_gather_shapes(cache.shape_ptrs_dev, cache.gathered_shapes_dev, cache.world_size) - cache.gathered_shapes_host.copy_(cache.gathered_shapes_dev, non_blocking=True) - torch.cuda.current_stream().synchronize() - - # 4. Resolve exact concatenated target configuration and routing - sum_B = 0 - B_array = [] - max_capacity_needed = [0] * cache.world_size - - for i in range(cache.world_size): - B_i = cache.gathered_shapes_host[i, 1 + gather_dim].item() - B_array.append(B_i) - sum_B += B_i - max_capacity_needed[i] = cache.gathered_shapes_host[i, 31].item() - - needs_realloc = False - for i in range(cache.world_size): - if max_capacity_needed[i] > cache.data_capacities[i]: - needs_realloc = True - cache.data_capacities[i] = int(max_capacity_needed[i] * 1.2) # Maintain stable symmetric arrays - - # Re-rendezvous path natively isolated only for rare size-spike spikes - if needs_realloc: - reallocated = False - if cache.data_capacities[rank] > cache.data_buf.numel(): - cache.data_buf = symm_mem.empty(cache.data_capacities[rank], dtype=dtype, device=device) - reallocated = True - - cache.data_hdl = symm_mem.rendezvous(cache.data_buf, group) - cache.data_ptrs_dev = torch.tensor(cache.data_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - if not optimistic_copy_done or reallocated: - cache.data_buf[:x.numel()].copy_(x.view(-1), non_blocking=True) - - cache.data_hdl.barrier(channel=0) - - # 5. Output structure formulation - out_shape = list(x.shape) - out_shape[gather_dim] = sum_B - out = torch.empty(out_shape, dtype=dtype, device=device) - - A = 1 - for s in out_shape[:gather_dim]: - A *= s - C = 1 - for s in out_shape[gather_dim+1:]: - C *= s - - B_prefix = [0] * cache.world_size - total_prefix = [0] * cache.world_size - L_array = [0] * cache.world_size - dst_offset = [0] * cache.world_size - - prefix_b = 0 - prefix_total = 0 - for i in range(cache.world_size): - B_prefix[i] = prefix_b - total_prefix[i] = prefix_total - - L_i = B_array[i] * C - L_array[i] = L_i - dst_offset[i] = prefix_b * C - - prefix_b += B_array[i] - prefix_total += A * L_i - - total_elements = prefix_total - if total_elements == 0: - return out - - # 6. Automatic alignment reduction scaling factor validation - element_size = x.element_size() - max_vf = 16 // element_size - vfs = [max_vf] - while vfs[-1] > 1: - vfs.append(vfs[-1] // 2) - - VF = 1 - for vf in vfs: - if all((b * C) % vf == 0 for b in B_array): - VF = vf - break - - L_array_vf = [l // VF for l in L_array] - dst_offset_vf = [d // VF for d in dst_offset] - total_prefix_vf = [t // VF for t in total_prefix] - sum_BC_vf = (sum_B * C) // VF - total_elements_vf = total_elements // VF - - ext.launch_ulysses_allgather( - cache.data_ptrs_dev, - L_array_vf, - dst_offset_vf, - total_prefix_vf, - out, - sum_BC_vf, - total_elements_vf, - cache.world_size, - VF * element_size - ) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/37_ulysses_gather_seq_scatter_heads_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/37_ulysses_gather_seq_scatter_heads_cuda.py deleted file mode 100755 index a8b7c8f..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/37_ulysses_gather_seq_scatter_heads_cuda.py +++ /dev/null @@ -1,245 +0,0 @@ -""" -Strategy: -1. Replaced NCCL `all_to_all_single` and PyTorch `cat`/`reshape` ops with a direct NVLink PULL kernel using Symmetric Memory. -2. Flattened the complex N-dimensional data routing into a 5D logical tensor offset calculation directly inside the CUDA kernel, eliminating multi-step memory traffic (no intermediate splits, transposes, or concats). -3. The sender simply copies its local data to its symmetric buffer (1 contiguous write), and the receiver directly pulls its required scattered slices into the correct gathered layout (1 NVLink read, 1 contiguous write). This achieves fewer memory ops than the reference all-to-all. -4. Used `uint4` (128-bit) vectorized loads/stores on the inner-most dimension (typically `head_dim` size, highly divisible by 8) for maximum P2P bandwidth utilization. -5. Employs double-buffering for symmetric memory allocations to eliminate read-after-write hazards across consecutive calls without blocking the host stream. -""" - -import math -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from torch.distributed import ProcessGroup - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__global__ void ulysses_pull_kernel( - const uint64_t* __restrict__ peer_ptrs, - void* __restrict__ out_ptr, - int64_t A, int64_t B, int64_t C, int64_t D, int64_t E_vec, - int P, int my_rank, bool scatter_first, int64_t N_out_vec -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N_out_vec) return; - - int64_t e = idx % E_vec; - int64_t tmp = idx / E_vec; - - int64_t rank_j; - int64_t in_idx; - - if (scatter_first) { - // B is scatter (head_dim), D is gather (seq_dim) - // Output shape: [A, B/P, C, D*P, E_vec] - int64_t d_out = tmp % (D * P); - tmp = tmp / (D * P); - int64_t c = tmp % C; - tmp = tmp / C; - int64_t b_out = tmp % (B / P); - int64_t a = tmp / (B / P); - - rank_j = d_out / D; - int64_t d_in = d_out % D; - int64_t b_in = my_rank * (B / P) + b_out; - - in_idx = (((a * B + b_in) * C + c) * D + d_in) * E_vec + e; - } else { - // B is gather (seq_dim), D is scatter (head_dim) - // Output shape: [A, B*P, C, D/P, E_vec] - int64_t d_out = tmp % (D / P); - tmp = tmp / (D / P); - int64_t c = tmp % C; - tmp = tmp / C; - int64_t b_out = tmp % (B * P); - int64_t a = tmp / (B * P); - - rank_j = b_out / B; - int64_t b_in = b_out % B; - int64_t d_in = my_rank * (D / P) + d_out; - - in_idx = (((a * B + b_in) * C + c) * D + d_in) * E_vec + e; - } - - if constexpr (VEC_SIZE == 8) { - const uint4* src = reinterpret_cast(peer_ptrs[rank_j]); - uint4* out = reinterpret_cast(out_ptr); - out[idx] = src[in_idx]; - } else if constexpr (VEC_SIZE == 4) { - const uint2* src = reinterpret_cast(peer_ptrs[rank_j]); - uint2* out = reinterpret_cast(out_ptr); - out[idx] = src[in_idx]; - } else if constexpr (VEC_SIZE == 2) { - const uint32_t* src = reinterpret_cast(peer_ptrs[rank_j]); - uint32_t* out = reinterpret_cast(out_ptr); - out[idx] = src[in_idx]; - } else { - const uint16_t* src = reinterpret_cast(peer_ptrs[rank_j]); - uint16_t* out = reinterpret_cast(out_ptr); - out[idx] = src[in_idx]; - } -} - -void launch_ulysses_pull( - torch::Tensor peer_ptrs, - torch::Tensor out, - int64_t A, int64_t B, int64_t C, int64_t D, int64_t E, - int P, int my_rank, bool scatter_first -) { - int64_t N_out = A * B * C * D * E; - - int vec_size = 1; - if (E % 8 == 0) vec_size = 8; - else if (E % 4 == 0) vec_size = 4; - else if (E % 2 == 0) vec_size = 2; - - int64_t E_vec = E / vec_size; - int64_t N_out_vec = N_out / vec_size; - - int threads = 256; - int blocks = (N_out_vec + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const uint64_t* ptrs = reinterpret_cast(peer_ptrs.data_ptr()); - void* out_ptr = out.data_ptr(); - - if (vec_size == 8) { - ulysses_pull_kernel<8><<>>( - ptrs, out_ptr, A, B, C, D, E_vec, P, my_rank, scatter_first, N_out_vec); - } else if (vec_size == 4) { - ulysses_pull_kernel<4><<>>( - ptrs, out_ptr, A, B, C, D, E_vec, P, my_rank, scatter_first, N_out_vec); - } else if (vec_size == 2) { - ulysses_pull_kernel<2><<>>( - ptrs, out_ptr, A, B, C, D, E_vec, P, my_rank, scatter_first, N_out_vec); - } else { - ulysses_pull_kernel<1><<>>( - ptrs, out_ptr, A, B, C, D, E_vec, P, my_rank, scatter_first, N_out_vec); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_ulysses_pull", &launch_ulysses_pull, "Ulysses NVLink Pull Kernel"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ulysses_pull_ext", CUDA_SRC) - return _ext - -_step_counter = {} -_symm_cache = {} - -def _power_of_2(n): - if n <= 0: return 0 - return 1 << (n - 1).bit_length() - -def _get_symm_buffer(numel: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): - global _step_counter - group_id = id(group) - step = _step_counter.get(group_id, 0) - _step_counter[group_id] = step + 1 - buf_idx = step % 2 # Double-buffering prevents Read-After-Write hazards without host blocking - - best_key = None - for k in _symm_cache: - k_numel, k_dtype, k_device, k_group, k_idx = k - if k_dtype == dtype and k_device == device and k_group == group and k_idx == buf_idx: - if k_numel >= numel: - if best_key is None or k_numel < best_key[0]: - best_key = k - - if best_key is not None: - buf, hdl, ptrs = _symm_cache[best_key] - return buf[:numel].view(-1), hdl, ptrs - - alloc_numel = max(_power_of_2(numel), 1024 * 1024) - buf = symm_mem.empty(alloc_numel, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _symm_cache[(alloc_numel, dtype, device, group, buf_idx)] = (buf, hdl, ptrs) - return buf[:numel].view(-1), hdl, ptrs - - -@torch.no_grad() -def solution( - x: torch.Tensor, - seq_dim: int, - head_dim: int, - group: Optional[ProcessGroup] = None, - unpadded_dim_size: int = 0, -) -> torch.Tensor: - if group is None or not dist.is_initialized(): - return x - - sp_world = dist.get_world_size(group) - my_rank = dist.get_rank(group) - - if my_rank == 0: - _get_ext() - dist.barrier(group=group) - ext = _get_ext() - - scatter_dim = head_dim - gather_dim = seq_dim - dims = list(x.shape) - - # Pre-calculate 5D flattening constants based on dim order - if scatter_dim < gather_dim: - scatter_first = True - A = math.prod(dims[:scatter_dim]) - B = dims[scatter_dim] - C = math.prod(dims[scatter_dim+1:gather_dim]) - D = dims[gather_dim] - E = math.prod(dims[gather_dim+1:]) - else: - scatter_first = False - A = math.prod(dims[:gather_dim]) - B = dims[gather_dim] - C = math.prod(dims[gather_dim+1:scatter_dim]) - D = dims[scatter_dim] - E = math.prod(dims[scatter_dim+1:]) - - # Prepare symmetric buffer - x_contig = x.contiguous() - numel = x_contig.numel() - symm_buf, hdl, ptrs = _get_symm_buffer(numel, x_contig.dtype, x_contig.device, group) - - # Local contiguous write followed by symmetric memory stream barrier - symm_buf.copy_(x_contig.view(-1)) - hdl.barrier(channel=0) - - # Allocate local output tensor explicitly - out_shape = list(x.shape) - out_shape[seq_dim] *= sp_world - out_shape[head_dim] //= sp_world - out = torch.empty(out_shape, dtype=x.dtype, device=x.device) - - # Dispatch custom PULL NVLink kernel - ext.launch_ulysses_pull( - ptrs, out, A, B, C, D, E, sp_world, my_rank, scatter_first - ) - - # Clean unpadding natively (acts on memory views seamlessly) - if unpadded_dim_size and unpadded_dim_size % sp_world != 0: - padding_size = out.size(seq_dim) - unpadded_dim_size - slc = [slice(None)] * out.dim() - slc[seq_dim] = slice(0, -padding_size) - out = out[tuple(slc)] - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/38_ulysses_gather_heads_scatter_seq_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/38_ulysses_gather_heads_scatter_seq_cuda.py deleted file mode 100755 index 9695283..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/38_ulysses_gather_heads_scatter_seq_cuda.py +++ /dev/null @@ -1,245 +0,0 @@ -""" -Strategy: -1. **Device-Side Communication**: We replace NCCL's host-driven `all_to_all` and PyTorch's multiple `split`/`cat`/`reshape` operations with a single, fused custom CUDA pull kernel. -2. **Symmetric Memory (UVA)**: The local input sequence is staged into a cached `symm_mem` buffer. All ranks then use UVA peer pointers to read their required chunks directly from the peers' symmetric buffers. -3. **Compute-Communication Overlap & Fusion**: Sequence padding is fused directly into the staging copy step, avoiding extra allocations. The pull kernel resolves the scatter-gather multi-dimensional routing on-the-fly and writes the final contiguous output. -4. **Bandwidth Optimization**: The kernel detects the innermost contiguous dimension (`E`) and automatically vectorizes loads/stores (up to 128-bit) to saturate NVLink bandwidth. -""" - -from typing import Optional - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -import torch.distributed._symmetric_memory as symm_mem - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -template -__global__ void ulysses_pull_kernel( - const __nv_bfloat16* const* __restrict__ peer_ptrs, - __nv_bfloat16* __restrict__ out, - int rank, - int world_size, - int64_t A, - int64_t B, - int64_t C, - int64_t D, - int64_t E, - bool seq_first, - int64_t numel_vec -) { - int64_t vec_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (vec_idx >= numel_vec) return; - - int64_t idx = vec_idx * VEC; - - int64_t a, b_out, c, d_out, e; - int64_t temp = idx; - - e = temp % E; temp /= E; - - int64_t D_out = seq_first ? (D * world_size) : (D / world_size); - d_out = temp % D_out; temp /= D_out; - - c = temp % C; temp /= C; - - int64_t B_out = seq_first ? (B / world_size) : (B * world_size); - b_out = temp % B_out; temp /= B_out; - - a = temp; - - int p; - int64_t b_in, d_in; - - if (seq_first) { - // B is seq_dim, D is head_dim - p = d_out / D; - d_in = d_out % D; - b_in = rank * B_out + b_out; - } else { - // B is head_dim, D is seq_dim - p = b_out / B; - b_in = b_out % B; - d_in = rank * D_out + d_out; - } - - int64_t in_idx = a * (B * C * D * E) + b_in * (C * D * E) + c * (D * E) + d_in * E + e; - const __nv_bfloat16* src_ptr = peer_ptrs[p]; - - if constexpr (VEC == 8) { - *reinterpret_cast(&out[idx]) = *reinterpret_cast(&src_ptr[in_idx]); - } else if constexpr (VEC == 4) { - *reinterpret_cast(&out[idx]) = *reinterpret_cast(&src_ptr[in_idx]); - } else if constexpr (VEC == 2) { - *reinterpret_cast(&out[idx]) = *reinterpret_cast(&src_ptr[in_idx]); - } else { - out[idx] = src_ptr[in_idx]; - } -} - -void launch_ulysses_pull( - torch::Tensor peer_ptrs_tensor, - torch::Tensor out, - int rank, - int world_size, - int64_t A, - int64_t B, - int64_t C, - int64_t D, - int64_t E, - bool seq_first, - int64_t numel -) { - const __nv_bfloat16* const* peer_ptrs = (const __nv_bfloat16* const*)peer_ptrs_tensor.data_ptr(); - __nv_bfloat16* out_ptr = reinterpret_cast<__nv_bfloat16*>(out.data_ptr()); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - - if (E % 8 == 0) { - int blocks = (numel / 8 + threads - 1) / threads; - ulysses_pull_kernel<8><<>>( - peer_ptrs, out_ptr, rank, world_size, A, B, C, D, E, seq_first, numel / 8); - } else if (E % 4 == 0) { - int blocks = (numel / 4 + threads - 1) / threads; - ulysses_pull_kernel<4><<>>( - peer_ptrs, out_ptr, rank, world_size, A, B, C, D, E, seq_first, numel / 4); - } else if (E % 2 == 0) { - int blocks = (numel / 2 + threads - 1) / threads; - ulysses_pull_kernel<2><<>>( - peer_ptrs, out_ptr, rank, world_size, A, B, C, D, E, seq_first, numel / 2); - } else { - int blocks = (numel + threads - 1) / threads; - ulysses_pull_kernel<1><<>>( - peer_ptrs, out_ptr, rank, world_size, A, B, C, D, E, seq_first, numel); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_ulysses_pull", &launch_ulysses_pull, "Ulysses gather-scatter pull kernel"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - from utils.cuda_helpers import compile_cuda_extension - _ext = compile_cuda_extension("ulysses_gather_scatter_bf16", CUDA_SRC) - return _ext - - -_symm_cache = {} -def _get_symm_state(shape_tuple, dtype, device, group): - key = (shape_tuple, dtype, device, group) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(shape_tuple, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group=group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - res = (buf, hdl, ptrs_tensor) - _symm_cache[key] = res - return res - - -@torch.no_grad() -def solution( - x: torch.Tensor, - seq_dim: int, - head_dim: int, - group: Optional[ProcessGroup] = None, -) -> torch.Tensor: - if group is None or dist.get_world_size(group) <= 1: - return x - - seq_dim = seq_dim % x.ndim - head_dim = head_dim % x.ndim - sp_world = dist.get_world_size(group) - dim_size = x.size(seq_dim) - - # Fallback to stock PyTorch for unsupported datatypes - if x.dtype != torch.bfloat16: - if dim_size % sp_world != 0: - padding_size = sp_world - (dim_size % sp_world) - pad_shape = list(x.shape) - pad_shape[seq_dim] = padding_size - pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device) - x = torch.cat([x, pad], dim=seq_dim) - input_list = [t.contiguous() for t in torch.tensor_split(x, sp_world, seq_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(sp_world)] - dist.all_to_all(output_list, input_list, group=group) - return torch.cat(output_list, dim=head_dim).contiguous() - - # JIT Compile - if _ext is None: - if dist.get_rank(group) == 0: - _get_ext() - dist.barrier(group) - _get_ext() - - # 1. Determine padded shape - padded_shape = list(x.shape) - needs_padding = (dim_size % sp_world != 0) - if needs_padding: - padded_shape[seq_dim] += sp_world - (dim_size % sp_world) - - # 2. Grab symm_mem staging buffers from cache - buf, hdl, ptrs_tensor = _get_symm_state(tuple(padded_shape), x.dtype, x.device, group) - - # 3. Synchronize stream before overwriting staging memory - hdl.barrier(channel=0) - - # 4. Copy current input to symm_mem buffer (fuse padding logic here) - if needs_padding: - slices = [slice(None)] * x.ndim - slices[seq_dim] = slice(0, dim_size) - buf[tuple(slices)].copy_(x) - - slices_pad = [slice(None)] * x.ndim - slices_pad[seq_dim] = slice(dim_size, None) - buf[tuple(slices_pad)].zero_() - else: - buf.copy_(x) - - # 5. Synchronize stream to ensure all peers have finished writing - hdl.barrier(channel=1) - - # 6. Allocate independent contiguous output - out_shape = list(padded_shape) - out_shape[seq_dim] //= sp_world - out_shape[head_dim] *= sp_world - out = torch.empty(out_shape, dtype=x.dtype, device=x.device) - - # 7. Collapse shapes down to 5D for exact mapping - A, B, C, D, E = 1, 1, 1, 1, 1 - dim1 = min(seq_dim, head_dim) - dim2 = max(seq_dim, head_dim) - - for i in range(dim1): A *= padded_shape[i] - B = padded_shape[dim1] - for i in range(dim1 + 1, dim2): C *= padded_shape[i] - D = padded_shape[dim2] - for i in range(dim2 + 1, len(padded_shape)): E *= padded_shape[i] - - seq_first = (seq_dim < head_dim) - - # 8. Pull from peer UVA pointers directly to local output - _get_ext().launch_ulysses_pull( - ptrs_tensor, - out, - dist.get_rank(group), - sp_world, - A, B, C, D, E, - seq_first, - out.numel() - ) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/39_ulysses_gather_seq_scatter_heads_qkv_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/39_ulysses_gather_seq_scatter_heads_qkv_cuda.py deleted file mode 100755 index 6cf9d37..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/39_ulysses_gather_seq_scatter_heads_qkv_cuda.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Strategy: -This solution bypasses NCCL to provide an optimal NVLink P2P sequence parallel all-to-all schedule. We exploit symmetric memory to directly map the scattered-head and gathered-sequence chunking onto a single fused CUDA kernel. By computing the multi-dimensional stride mapping purely mathematically within the kernel, each device pulls exactly its required `(seq, head)` sub-chunks directly from peers' HBM without any intermediate slicing or concatenations. This maximizes bidirectional NVLink bandwidth, eliminates PyTorch op overhead, and fuses what would normally be several reshapes and chunking operations into one single-launch vectorized pull. We also implement the reverse mapping for a zero-overhead backward pass. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional, Any -from torch.distributed import ProcessGroup -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -template -__global__ void all2all_pull_kernel( - const uint64_t* __restrict__ symm_ptrs, - T* __restrict__ out, - int64_t W, - int64_t me, - int64_t SM, - int64_t h_vecs, - int64_t num_segments, - int64_t total_vecs, - bool is_backward -) { - int64_t v_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (v_idx >= total_vecs) return; - - int64_t chunk_idx = v_idx / h_vecs; - int64_t v_offset = v_idx % h_vecs; - - int64_t j = chunk_idx / num_segments; - int64_t segment_idx = chunk_idx % num_segments; - - int64_t a = segment_idx / SM; - int64_t rem = segment_idx % SM; - - int64_t dest_chunk_idx; - int64_t src_chunk_idx; - - if (!is_backward) { - dest_chunk_idx = (a * W + j) * SM + rem; - src_chunk_idx = segment_idx * W + me; - } else { - dest_chunk_idx = segment_idx * W + j; - src_chunk_idx = (a * W + me) * SM + rem; - } - - int64_t dest_idx = dest_chunk_idx * h_vecs + v_offset; - int64_t src_idx = src_chunk_idx * h_vecs + v_offset; - - using VecType = typename std::aligned_storage::type; - - const T* src_ptr = reinterpret_cast(symm_ptrs[j]); - const VecType* src_vec = reinterpret_cast(src_ptr); - VecType* out_vec = reinterpret_cast(out); - - out_vec[dest_idx] = src_vec[src_idx]; -} - -void launch_all2all_pull( - torch::Tensor symm_ptrs_tensor, - torch::Tensor out, - int64_t W, - int64_t me, - int64_t SM, - int64_t h, - int64_t num_segments, - bool is_backward -) { - size_t el_size = out.element_size(); - int64_t vec_elements = 1; - uintptr_t ptr_val = reinterpret_cast(out.data_ptr()); - - if (h % (16 / el_size) == 0 && ptr_val % 16 == 0) { - vec_elements = 16 / el_size; - } else if (h % (8 / el_size) == 0 && ptr_val % 8 == 0) { - vec_elements = 8 / el_size; - } else if (h % (4 / el_size) == 0 && ptr_val % 4 == 0) { - vec_elements = 4 / el_size; - } - - int64_t h_vecs = h / vec_elements; - int64_t total_vecs = num_segments * W * h_vecs; - - int64_t threads = 256; - int blocks = (int)((total_vecs + threads - 1) / threads); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const uint64_t* symm_ptrs = reinterpret_cast(symm_ptrs_tensor.data_ptr()); - - #define DISPATCH_KERNEL(T, V) \ - all2all_pull_kernel<<>>( \ - symm_ptrs, out.data_ptr(), W, me, SM, h_vecs, num_segments, total_vecs, is_backward) - - if (out.dtype() == torch::kBFloat16) { - if (vec_elements == 8) { DISPATCH_KERNEL(at::BFloat16, 8); } - else if (vec_elements == 4) { DISPATCH_KERNEL(at::BFloat16, 4); } - else if (vec_elements == 2) { DISPATCH_KERNEL(at::BFloat16, 2); } - else { DISPATCH_KERNEL(at::BFloat16, 1); } - } else if (out.dtype() == torch::kFloat16) { - if (vec_elements == 8) { DISPATCH_KERNEL(at::Half, 8); } - else if (vec_elements == 4) { DISPATCH_KERNEL(at::Half, 4); } - else if (vec_elements == 2) { DISPATCH_KERNEL(at::Half, 2); } - else { DISPATCH_KERNEL(at::Half, 1); } - } else if (out.dtype() == torch::kFloat32) { - if (vec_elements == 4) { DISPATCH_KERNEL(float, 4); } - else if (vec_elements == 2) { DISPATCH_KERNEL(float, 2); } - else { DISPATCH_KERNEL(float, 1); } - } - - #undef DISPATCH_KERNEL - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_all2all_pull", &launch_all2all_pull, "Symmetric Memory NVLink All-to-All Pull"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ulysses_fused_qkv_all2all_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(shape, dtype, device, group): - key = (tuple(shape), dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(shape, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _symm_cache[key] = (buf, hdl, ptrs_tensor) - return _symm_cache[key] - -class FusedQKVAllToAllSymm(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv_tensor, seq_dim, group, unpadded_dim_size, restore_shape): - W = dist.get_world_size(group) - ctx.seq_dim = seq_dim - ctx.group = group - ctx.unpadded_dim_size = unpadded_dim_size - ctx.restore_shape = restore_shape - ctx.orig_shape = qkv_tensor.shape - ctx.W = W - - if W == 1: - return qkv_tensor - - orig_shape = qkv_tensor.shape - qkv_proj_dim = orig_shape[-1] - - bef_all2all_shape = list(orig_shape) - bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3] - - qkv_tensor = qkv_tensor.contiguous() - - buf, hdl, ptrs_tensor = _get_symm_state(bef_all2all_shape, qkv_tensor.dtype, qkv_tensor.device, group) - - buf.view(-1).copy_(qkv_tensor.view(-1)) - hdl.barrier(channel=0) - - A = 1 - for i in range(seq_dim): - A *= bef_all2all_shape[i] - S = bef_all2all_shape[seq_dim] - M = 1 - for i in range(seq_dim + 1, len(bef_all2all_shape) - 1): - M *= bef_all2all_shape[i] - H = bef_all2all_shape[-1] - - h = H // W - num_segments = A * S * M - - gathered_shape = list(bef_all2all_shape) - gathered_shape[seq_dim] = W * S - gathered_shape[-1] = h - - out = torch.empty(gathered_shape, dtype=qkv_tensor.dtype, device=qkv_tensor.device) - me = dist.get_rank(group) - - _get_ext().launch_all2all_pull(ptrs_tensor, out, W, me, S * M, h, num_segments, False) - - hdl.barrier(channel=0) - - ctx.S_M = S * M - ctx.h = h - ctx.num_segments = num_segments - ctx.bef_all2all_shape = bef_all2all_shape - ctx.gathered_shape = gathered_shape - - final_out = out - - if restore_shape: - out_shape = list(orig_shape) - out_shape[seq_dim] *= W - out_shape[-1] = qkv_proj_dim // W - final_out = final_out.view(out_shape) - - if unpadded_dim_size and unpadded_dim_size % W != 0: - padding_size = final_out.size(seq_dim) - unpadded_dim_size - slc = [slice(None)] * final_out.dim() - slc[seq_dim] = slice(0, -padding_size) - final_out = final_out[tuple(slc)] - - return final_out - - @staticmethod - def backward(ctx, grad_output): - W = ctx.W - if W == 1: - return grad_output, None, None, None, None - - grad_output = grad_output.contiguous() - - if ctx.unpadded_dim_size and ctx.unpadded_dim_size % W != 0: - padding_size = ctx.gathered_shape[ctx.seq_dim] - ctx.unpadded_dim_size - shape = list(grad_output.shape) - shape[ctx.seq_dim] = padding_size - pad = torch.zeros(shape, dtype=grad_output.dtype, device=grad_output.device) - grad_output = torch.cat([grad_output, pad], dim=ctx.seq_dim) - - grad_output = grad_output.view(ctx.gathered_shape) - - buf, hdl, ptrs_tensor = _get_symm_state(ctx.gathered_shape, grad_output.dtype, grad_output.device, ctx.group) - buf.view(-1).copy_(grad_output.view(-1)) - hdl.barrier(channel=0) - - grad_input = torch.empty(ctx.bef_all2all_shape, dtype=grad_output.dtype, device=grad_output.device) - me = dist.get_rank(ctx.group) - - _get_ext().launch_all2all_pull(ptrs_tensor, grad_input, W, me, ctx.S_M, ctx.h, ctx.num_segments, True) - - hdl.barrier(channel=0) - - grad_input = grad_input.view(ctx.orig_shape) - return grad_input, None, None, None, None - - -def solution( - qkv_tensor: torch.Tensor, - seq_dim: int, - group: Optional[ProcessGroup] = None, - unpadded_dim_size: Optional[int] = None, - restore_shape: bool = True, -) -> torch.Tensor: - """ - Per-rank inputs: - qkv_tensor: fused QKV [..., qkv_proj_dim]; last dim divisible by 3 and world_size. - seq_dim: sequence dimension (gather dim for all_to_all). - group: SP process group (default world). - unpadded_dim_size: if set and not divisible by world_size, unpad output. - restore_shape: if True, output shape matches input ndim with seq_dim and last dim resized. - - Returns (per rank): - output: tensor after fused QKV all_to_all (and optional reshape/unpad). - """ - group = group or (dist.group.WORLD if dist.is_initialized() else None) - if not group: - return qkv_tensor - - return FusedQKVAllToAllSymm.apply( - qkv_tensor, - seq_dim, - group, - unpadded_dim_size or 0, - restore_shape - ) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/3_broadcast_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/3_broadcast_cuda.py deleted file mode 100755 index c520c39..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/3_broadcast_cuda.py +++ /dev/null @@ -1,181 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void broadcast_multimem_kernel_padded( - const void* __restrict__ src_ptr, - uint64_t multicast_base, - int64_t n_padded_bytes -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - int64_t n_vec = n_padded_bytes / 16; - const uint4* src_vec = reinterpret_cast(src_ptr); - - for (int64_t i = idx; i < n_vec; i += stride) { - uint4 val = src_vec[i]; - uint64_t* mc_addr = reinterpret_cast(multicast_base) + i * 2; - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"(mc_addr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) - : "memory"); - } -} - -__global__ void pull_broadcast_kernel_padded( - const void* __restrict__ src_ptr, - void* __restrict__ dst_ptr, - int64_t n_padded_bytes -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - int64_t n_vec = n_padded_bytes / 16; - const uint4* src_vec = reinterpret_cast(src_ptr); - uint4* dst_vec = reinterpret_cast(dst_ptr); - - for (int64_t i = idx; i < n_vec; i += stride) { - dst_vec[i] = src_vec[i]; - } -} - -void launch_broadcast_multimem( - int64_t src_ptr, - int64_t multicast_ptr, - int64_t n_padded_bytes -) { - int threads = 512; - int blocks = std::min((int)((n_padded_bytes / 16 + threads - 1) / threads), 1024); - if (blocks == 0) blocks = 1; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - broadcast_multimem_kernel_padded<<>>( - reinterpret_cast(static_cast(src_ptr)), - static_cast(multicast_ptr), - n_padded_bytes - ); -} - -void launch_pull_broadcast( - int64_t src_ptr, - int64_t dst_ptr, - int64_t n_padded_bytes -) { - int threads = 512; - int blocks = std::min((int)((n_padded_bytes / 16 + threads - 1) / threads), 1024); - if (blocks == 0) blocks = 1; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - pull_broadcast_kernel_padded<<>>( - reinterpret_cast(static_cast(src_ptr)), - reinterpret_cast(static_cast(dst_ptr)), - n_padded_bytes - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_broadcast_multimem", &launch_broadcast_multimem, "Multimem broadcast kernel"); - m.def("launch_pull_broadcast", &launch_pull_broadcast, "Pull broadcast kernel"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("broadcast_multimem_ext", CUDA_SRC) - return _ext - - -_symm_cache = {} - -def _get_symm_state(n_bytes: int, dtype: torch.dtype, device: torch.device): - key = (n_bytes, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty((n_bytes,), dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - _symm_cache[key] = (buf, hdl) - return buf, hdl - - -@torch.no_grad() -def solution(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: - """ - Symmetric memory Multimem broadcast. Replaces NCCL broadcast with custom - multimem.st PTX broadcast (Hopper) or fast UVA pull kernel. - """ - if not dist.is_initialized(): - return tensor.clone() - - rank = dist.get_rank() - n_bytes = tensor.numel() * tensor.element_size() - - if n_bytes == 0: - return tensor.clone() if rank == src else torch.empty_like(tensor) - - # Pad buffer size to next 16-byte multiple to allow 100% vectorized 128-bit memory ops - padded_bytes = (n_bytes + 15) // 16 * 16 - - # JIT compile safely - if rank == 0: - _get_ext() - dist.barrier() - _get_ext() - - buf, hdl = _get_symm_state(padded_bytes, torch.uint8, tensor.device) - - # 1. Sync: ensure previous operations on the cached `buf` are globally done. - hdl.barrier(channel=0) - - # 2. Source rank stages the local payload into its symmetric, 16-byte aligned buffer. - if rank == src: - buf_view = buf[:n_bytes].view(tensor.dtype).view(tensor.shape) - buf_view.copy_(tensor) - - # 3. Broadcast data directly over symmetric mappings. - if hdl.multicast_ptr: - if rank == src: - _get_ext().launch_broadcast_multimem( - buf.data_ptr(), - hdl.multicast_ptr, - padded_bytes - ) - # Device sync: Ensure NVSwitch multimem stores land globally across ranks. - hdl.barrier(channel=0) - else: - # Fallback device sync: Ensure src's initial staging memory-copy is globally visible. - hdl.barrier(channel=0) - - if rank != src: - src_buf_ptr = int(hdl.buffer_ptrs[src]) - _get_ext().launch_pull_broadcast( - src_buf_ptr, - buf.data_ptr(), - padded_bytes - ) - - # Device sync: Ensure pull kernels complete on receivers. - hdl.barrier(channel=0) - - # 4. Expose the populated data out. - if rank == src: - out = tensor.clone() - else: - out = torch.empty_like(tensor) - out_view = buf[:n_bytes].view(tensor.dtype).view(tensor.shape) - out.copy_(out_view) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/40_ulysses_attention_e2e_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/40_ulysses_attention_e2e_cuda.py deleted file mode 100755 index 70420b1..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/40_ulysses_attention_e2e_cuda.py +++ /dev/null @@ -1,338 +0,0 @@ -import torch -import torch.nn.functional as F -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// 1. QKV Scatter Kernel -// Reads local chunked QKV [B, chunk_len, 3, num_heads, head_dim] -// Writes directly to peers' gathered buffer [3, B, S_full, num_heads_local, head_dim] -// --------------------------------------------------------------------------- -__global__ void qkv_alltoall_kernel_flat( - const uint4* __restrict__ qkv, - const uint64_t* __restrict__ dest_ptrs, - int B, int chunk_len, int num_heads, int head_dim_vec, - int rank, int world_size, int S_local, int start_s, - int64_t total_vecs -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_vecs) return; - - int d_v = idx % head_dim_vec; - int64_t tmp = idx / head_dim_vec; - int h = tmp % num_heads; - tmp /= num_heads; - int qkv_idx = tmp % 3; - tmp /= 3; - int s = tmp % chunk_len; - int b = tmp / chunk_len; - - int num_heads_local = num_heads / world_size; - int p = h / num_heads_local; - int h_dst = h % num_heads_local; - int S_full = S_local * world_size; - int s_dst = rank * S_local + start_s + s; - - // dest shape: [3, B, S_full, num_heads_local, head_dim_vec] - int64_t dest_offset = ((((int64_t)(qkv_idx * B + b) * S_full + s_dst) * num_heads_local) + h_dst) * head_dim_vec + d_v; - - uint4* peer_dest = (uint4*)dest_ptrs[p]; - peer_dest[dest_offset] = qkv[idx]; -} - -// --------------------------------------------------------------------------- -// 2. Attention Output Scatter Kernel -// Reads local attention output [B, S_full, num_heads_local, head_dim] -// Writes directly to peers' buffer [B, S_local, num_heads, head_dim] -// --------------------------------------------------------------------------- -__global__ void attn_out_alltoall_kernel_flat( - const uint4* __restrict__ attn_out, - const uint64_t* __restrict__ dest_ptrs, - int B, int S_full, int num_heads_local, int head_dim_vec, - int rank, int world_size, - int S_local, int start_s_dst, int chunk_len, int64_t total_vecs -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_vecs) return; - - int d_v = idx % head_dim_vec; - int64_t tmp = idx / head_dim_vec; - int h_local = tmp % num_heads_local; - tmp /= num_heads_local; - int s_in_chunk = tmp % chunk_len; - tmp /= chunk_len; - int p = tmp % world_size; - int b = tmp / world_size; - - int s_src = p * S_local + start_s_dst + s_in_chunk; - int64_t src_offset = (((int64_t)(b * S_full + s_src) * num_heads_local) + h_local) * head_dim_vec + d_v; - - int s_dst = start_s_dst + s_in_chunk; - int h_dst = rank * num_heads_local + h_local; - int num_heads = num_heads_local * world_size; - - // dest shape: [B, S_local, num_heads, head_dim_vec] - int64_t dest_offset = (((int64_t)(b * S_local + s_dst) * num_heads) + h_dst) * head_dim_vec + d_v; - - uint4* peer_dest = (uint4*)dest_ptrs[p]; - peer_dest[dest_offset] = attn_out[src_offset]; -} - -// --------------------------------------------------------------------------- -// 3. Device-Side Synchronization Kernels -// --------------------------------------------------------------------------- -__global__ void signal_peers_kernel_relaxed( - const uint64_t* __restrict__ signal_ptrs, - int rank, int c, int world_size -) { - int p = threadIdx.x; - if (p < world_size) { - volatile int* peer_signal = (volatile int*)signal_ptrs[p]; - peer_signal[c * world_size + rank] = 1; - } -} - -__global__ void wait_signal_kernel_relaxed( - volatile int* __restrict__ my_signal, - int c, int world_size -) { - int p = threadIdx.x; - if (p < world_size) { - while (my_signal[c * world_size + p] == 0) { - // Spin waiting for peer 'p' to flag completion for chunk 'c' - } - } -} - -// --------------------------------------------------------------------------- -// C++ Bindings -// --------------------------------------------------------------------------- -void launch_qkv_alltoall( - torch::Tensor qkv, - torch::Tensor dest_ptrs, - int B, int chunk_len, int num_heads, int head_dim, - int rank, int world_size, int S_local, int start_s -) { - TORCH_CHECK(head_dim % 8 == 0, "head_dim must be a multiple of 8 for BF16 vectorization"); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int head_dim_vec = head_dim / 8; - int64_t total_vecs = (int64_t)B * chunk_len * 3 * num_heads * head_dim_vec; - int threads = 256; - int blocks = (total_vecs + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - qkv_alltoall_kernel_flat<<>>( - (const uint4*)qkv.data_ptr(), - (const uint64_t*)dest_ptrs.data_ptr(), - B, chunk_len, num_heads, head_dim_vec, - rank, world_size, S_local, start_s, - total_vecs - ); -} - -void launch_attn_out_alltoall( - torch::Tensor attn_out, - torch::Tensor dest_ptrs, - int B, int S_full, int num_heads_local, int head_dim, - int rank, int world_size, int S_local, int start_s_dst, int chunk_len -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int head_dim_vec = head_dim / 8; - int64_t total_vecs = (int64_t)B * world_size * chunk_len * num_heads_local * head_dim_vec; - int threads = 256; - int blocks = (total_vecs + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - attn_out_alltoall_kernel_flat<<>>( - (const uint4*)attn_out.data_ptr(), - (const uint64_t*)dest_ptrs.data_ptr(), - B, S_full, num_heads_local, head_dim_vec, - rank, world_size, S_local, start_s_dst, chunk_len, - total_vecs - ); -} - -void launch_signal_peers(torch::Tensor signal_ptrs, int rank, int c, int world_size) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - signal_peers_kernel_relaxed<<<1, 32, 0, stream>>>((const uint64_t*)signal_ptrs.data_ptr(), rank, c, world_size); -} - -void launch_wait_signal(torch::Tensor my_signal, int c, int world_size) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - wait_signal_kernel_relaxed<<<1, 32, 0, stream>>>((volatile int*)my_signal.data_ptr(), c, world_size); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_qkv_alltoall", &launch_qkv_alltoall); - m.def("launch_attn_out_alltoall", &launch_attn_out_alltoall); - m.def("launch_signal_peers", &launch_signal_peers); - m.def("launch_wait_signal", &launch_wait_signal); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ulysses_attn_overlap_ext", CUDA_SRC) - return _ext - -_resource_cache = {} -def _get_resources(B, S_local, S_full, num_heads, num_heads_local, head_dim, world_size, dtype, device): - key = (B, S_local, S_full, num_heads, num_heads_local, head_dim, world_size, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - qkv_gathered = symm_mem.empty((3, B, S_full, num_heads_local, head_dim), dtype=dtype, device=device) - attn_gathered = symm_mem.empty((B, S_local, num_heads, head_dim), dtype=dtype, device=device) - signal_pad = symm_mem.empty((4, world_size), dtype=torch.int32, device=device) # Supports up to 4 chunks - - hdl_qkv = symm_mem.rendezvous(qkv_gathered, dist.group.WORLD) - hdl_attn = symm_mem.rendezvous(attn_gathered, dist.group.WORLD) - hdl_signal = symm_mem.rendezvous(signal_pad, dist.group.WORLD) - - dest_ptrs_qkv = torch.tensor(hdl_qkv.buffer_ptrs, dtype=torch.int64, device=device) - dest_ptrs_attn = torch.tensor(hdl_attn.buffer_ptrs, dtype=torch.int64, device=device) - dest_ptrs_signal = torch.tensor(hdl_signal.buffer_ptrs, dtype=torch.int64, device=device) - - res = (qkv_gathered, attn_gathered, signal_pad, dest_ptrs_qkv, dest_ptrs_attn, dest_ptrs_signal) - _resource_cache[key] = res - return res - -_streams = None -def _get_streams(n): - global _streams - if _streams is None: - _streams = [torch.cuda.Stream() for _ in range(4)] - return _streams[:n] - -def _local_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: float, - causal: bool = False, -) -> torch.Tensor: - """Minimal scaled dot-product attention logic. Kept identical for exact reference parity.""" - scores = torch.matmul(q, k.transpose(-2, -1)) * scale - if causal and q.size(1) > 1: - S = scores.size(-1) - causal_mask = torch.triu( - torch.ones(S, S, device=scores.device, dtype=torch.bool), - diagonal=1, - ) - scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) - attn = F.softmax(scores, dim=-1) - return torch.matmul(attn, v) - -@torch.no_grad() -def solution( - hidden_states: torch.Tensor, - w_qkv: torch.Tensor, - w_o: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, - num_heads: int = 8, - causal: bool = False, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - - B, S_local, H = hidden_states.shape - head_dim = (w_qkv.shape[0] // 3) // num_heads - - if world_size == 1: - qkv = F.linear(hidden_states, w_qkv) - qkv = qkv.view(B, S_local, 3, num_heads, head_dim) - q, k, v = qkv.unbind(2) - scale = head_dim**-0.5 - attn_out = _local_attention(q, k, v, scale, causal=causal) - out = attn_out.reshape(B, S_local, -1) - return F.linear(out, w_o) - - ext = _get_ext() - rank = dist.get_rank(group) - num_heads_local = num_heads // world_size - S_full = S_local * world_size - - # Establish UVA pointers and unified target buffers - qkv_gathered, attn_gathered, signal_pad, dest_ptrs_qkv, dest_ptrs_attn, dest_ptrs_signal = _get_resources( - B, S_local, S_full, num_heads, num_heads_local, head_dim, world_size, hidden_states.dtype, hidden_states.device - ) - - num_chunks = 2 if S_local >= 2 else 1 - chunk_size = S_local // num_chunks - chunks = [] - for i in range(num_chunks): - start = i * chunk_size - end = S_local if i == num_chunks - 1 else (i + 1) * chunk_size - chunks.append((start, end - start)) - - streams = _get_streams(num_chunks) - current_stream = torch.cuda.current_stream() - - for s in streams: - s.wait_stream(current_stream) - - # 1. Pipeline QKV Matmul and P2P Scatter kernel - for c, (start_s, chunk_len) in enumerate(chunks): - with torch.cuda.stream(streams[c % len(streams)]): - hs_chunk = hidden_states[:, start_s:start_s+chunk_len, :] - qkv_chunk = F.linear(hs_chunk, w_qkv) - qkv_chunk = qkv_chunk.view(B, chunk_len, 3, num_heads, head_dim) - ext.launch_qkv_alltoall( - qkv_chunk, dest_ptrs_qkv, - B, chunk_len, num_heads, head_dim, - rank, world_size, S_local, start_s - ) - - for s in streams: - current_stream.wait_stream(s) - - # Await global reception of all query/key/value slices prior to attention compute - dist.barrier(group=group) - - # 2. Complete Local Attention computation locally - q = qkv_gathered[0] - k = qkv_gathered[1] - v = qkv_gathered[2] - - scale = head_dim**-0.5 - attn_out = _local_attention(q, k, v, scale, causal=causal) - - # Reset signal pads before triggering the final P2P stage - signal_pad.zero_() - dist.barrier(group=group) - - out = torch.empty(B, S_local, w_o.shape[0], device=hidden_states.device, dtype=hidden_states.dtype) - - for s in streams: - s.wait_stream(current_stream) - - # 3. Pipeline Attn Scatter and Final Projection using Device Spinlocks - for c, (start_s, chunk_len) in enumerate(chunks): - with torch.cuda.stream(streams[c % len(streams)]): - ext.launch_attn_out_alltoall( - attn_out, dest_ptrs_attn, - B, S_full, num_heads_local, head_dim, - rank, world_size, S_local, start_s, chunk_len - ) - ext.launch_signal_peers(dest_ptrs_signal, rank, c, world_size) - ext.launch_wait_signal(signal_pad, c, world_size) - - attn_gathered_chunk = attn_gathered[:, start_s:start_s+chunk_len, :, :].reshape(B, chunk_len, -1) - out_chunk = F.linear(attn_gathered_chunk, w_o) - out[:, start_s:start_s+chunk_len, :] = out_chunk - - for s in streams: - current_stream.wait_stream(s) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/41_ddp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/41_ddp_cuda.py deleted file mode 100755 index d887d9c..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/41_ddp_cuda.py +++ /dev/null @@ -1,273 +0,0 @@ -""" -Strategy: -1. **Persistent Symmetric Memory**: We allocate a single symmetric memory buffer (`symm_mem`) on each rank that holds the flattened parameters, Adam moments (`exp_avg`, `exp_avg_sq`), and gradients. -2. **Zero-Copy Parameter Broadcast**: On step 1, Rank 0 writes its initial state to the symmetric buffer, and peers use a custom P2P pull kernel to fetch it. On subsequent steps, we detect if the input tensors are already views of our persistent buffer. If so, we *completely bypass* the broadcast, achieving zero-overhead persistence. -3. **Fused All-Reduce and Adam**: Instead of executing `dist.all_reduce` followed by stock PyTorch Adam operations, we launch a single custom CUDA kernel. It performs an all-to-all P2P read of the gradients directly from peers' symmetric buffers, averages them, and computes the Adam step immediately. This maximizes memory bandwidth by fusing cross-device communication and element-wise computation into a single pass. -4. **Minimal Stock PyTorch**: By keeping the authoritative state continuously in device memory and using our custom fused kernel, we eliminate all opaque collectives and intermediate tensor allocations on the performance-critical path. -""" - -import math -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -struct CudaTypeTraits; - -template <> -struct CudaTypeTraits { - static __device__ __forceinline__ float to_float(float x) { return x; } - static __device__ __forceinline__ float from_float(float x) { return x; } -}; - -template <> -struct CudaTypeTraits<__nv_bfloat16> { - static __device__ __forceinline__ float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } - static __device__ __forceinline__ __nv_bfloat16 from_float(float x) { return __float2bfloat16(x); } -}; - -template -__global__ void pull_broadcast_kernel( - const T* __restrict__ src_buf, - T* __restrict__ local_buf, - int64_t total_elements -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < total_elements; idx += (int64_t)gridDim.x * blockDim.x) { - local_buf[idx] = src_buf[idx]; - } -} - -template -__global__ void fused_allreduce_adam_kernel( - const long long* __restrict__ peer_ptrs, - T* __restrict__ local_buf, - int world_size, - int64_t n, - float lr, - float beta1, - float beta2, - float eps, - float bc1, - float bc2 -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t m_offset = n; - int64_t v_offset = 2 * n; - int64_t g_offset = 3 * n; - - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum_g = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const T* peer_buf = (const T*)peer_ptrs[r]; - sum_g += CudaTypeTraits::to_float(peer_buf[g_offset + idx]); - } - float g = sum_g / world_size; - - float p = CudaTypeTraits::to_float(local_buf[idx]); - float m = CudaTypeTraits::to_float(local_buf[m_offset + idx]); - float v = CudaTypeTraits::to_float(local_buf[v_offset + idx]); - - m = m * beta1 + g * (1.0f - beta1); - v = v * beta2 + g * g * (1.0f - beta2); - - float m_hat = m / bc1; - float v_hat = v / bc2; - float denom = sqrtf(v_hat) + eps; - - p = p - lr * (m_hat / denom); - - local_buf[idx] = CudaTypeTraits::from_float(p); - local_buf[m_offset + idx] = CudaTypeTraits::from_float(m); - local_buf[v_offset + idx] = CudaTypeTraits::from_float(v); - } -} - -void pull_broadcast( - int64_t remote_ptr, - torch::Tensor local_buf, - int64_t total_elements, - int dtype_enum -) { - int threads = 512; - int blocks = (total_elements + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (dtype_enum == 0) { - const __nv_bfloat16* src = reinterpret_cast(remote_ptr); - __nv_bfloat16* dst = reinterpret_cast<__nv_bfloat16*>(local_buf.data_ptr()); - pull_broadcast_kernel<__nv_bfloat16><<>>(src, dst, total_elements); - } else { - const float* src = reinterpret_cast(remote_ptr); - float* dst = local_buf.data_ptr(); - pull_broadcast_kernel<<>>(src, dst, total_elements); - } -} - -void fused_allreduce_adam( - torch::Tensor ptrs_tensor, - torch::Tensor local_buf, - int world_size, - int64_t n, - float lr, - float beta1, - float beta2, - float eps, - float bc1, - float bc2, - int dtype_enum -) { - int threads = 512; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const long long* peer_ptrs = (const long long*)ptrs_tensor.data_ptr(); - - if (dtype_enum == 0) { - __nv_bfloat16* local = reinterpret_cast<__nv_bfloat16*>(local_buf.data_ptr()); - fused_allreduce_adam_kernel<__nv_bfloat16><<>>( - peer_ptrs, local, world_size, n, lr, beta1, beta2, eps, bc1, bc2 - ); - } else { - float* local = local_buf.data_ptr(); - fused_allreduce_adam_kernel<<>>( - peer_ptrs, local, world_size, n, lr, beta1, beta2, eps, bc1, bc2 - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("pull_broadcast", &pull_broadcast, "Pull broadcast kernel"); - m.def("fused_allreduce_adam", &fused_allreduce_adam, "Fused allreduce and adam kernel"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_ddp_adam_ext", CUDA_SRC) - return _ext - - -_cache = {} - -def get_symm_state(n: int, dtype: torch.dtype, device: torch.device): - key = (n, dtype, device) - if key in _cache: - return _cache[key] - - # Allocate a single symmetric buffer containing: [params, exp_avg, exp_avg_sq, grads] - total_elements = 4 * n - buf = symm_mem.empty(total_elements, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _cache[key] = (buf, hdl, ptrs_tensor) - return _cache[key] - - -def solution( - X_local: Tensor, y_local: Tensor, - W1: Tensor, b1: Tensor, W2: Tensor, b2: Tensor, - exp_avg_W1: Tensor, exp_avg_b1: Tensor, exp_avg_W2: Tensor, exp_avg_b2: Tensor, - exp_avg_sq_W1: Tensor, exp_avg_sq_b1: Tensor, exp_avg_sq_W2: Tensor, exp_avg_sq_b2: Tensor, - lr: float, beta1: float, beta2: float, eps: float, step: int, -) -> tuple[Tensor, ...]: - - world_size = dist.get_world_size() - rank = dist.get_rank() - - params_in = [W1, b1, W2, b2] - exp_avg_in = [exp_avg_W1, exp_avg_b1, exp_avg_W2, exp_avg_b2] - exp_avg_sq_in = [exp_avg_sq_W1, exp_avg_sq_b1, exp_avg_sq_W2, exp_avg_sq_b2] - - n = sum(t.numel() for t in params_in) - dtype = W1.dtype - device = torch.cuda.current_device() - - # Initialize extension safely on rank 0 first - if rank == 0: - _get_ext() - dist.barrier() - - buf, hdl, ptrs_tensor = get_symm_state(n, dtype, device) - dtype_enum = 0 if dtype == torch.bfloat16 else 1 - - # If the caller passed the exact views we returned previously, state is already sync'd - is_cached = (params_in[0].data_ptr() == buf.data_ptr()) - - if not is_cached: - # Full sync initialization needed - hdl.barrier(channel=0) - if rank == 0: - flat_all = _flatten_dense_tensors(params_in + exp_avg_in + exp_avg_sq_in) - buf[:3*n].copy_(flat_all) - - hdl.barrier(channel=0) - if rank != 0: - remote_ptr = int(hdl.buffer_ptrs[0]) - _get_ext().pull_broadcast(remote_ptr, buf, 3 * n, dtype_enum) - - hdl.barrier(channel=0) - - # Create views to flush any legacy gradient state from PyTorch's AD engine - params, exp_avg, exp_avg_sq = [], [], [] - - offset = 0 - for t in params_in: - params.append(buf[offset : offset + t.numel()].view(t.shape).detach().requires_grad_(True)) - offset += t.numel() - - offset = n - for t in exp_avg_in: - exp_avg.append(buf[offset : offset + t.numel()].view(t.shape)) - offset += t.numel() - - offset = 2 * n - for t in exp_avg_sq_in: - exp_avg_sq.append(buf[offset : offset + t.numel()].view(t.shape)) - offset += t.numel() - - # Standard PyTorch local computation - h = F.relu(F.linear(X_local, params[0], params[1])) - out = F.linear(h, params[2], params[3]) - loss = F.mse_loss(out, y_local) - loss.backward() - - # Flat-copy gradients to the symmetric memory gradients buffer block - grads = [p.grad for p in params] - flat_grad = _flatten_dense_tensors(grads) - buf[3*n:].copy_(flat_grad) - - hdl.barrier(channel=0) - - bc1 = 1.0 - math.pow(beta1, step) - bc2 = 1.0 - math.pow(beta2, step) - - # Peer-pointers fused all-reduce (SUM) and Adam update directly on parameters - _get_ext().fused_allreduce_adam( - ptrs_tensor, buf, world_size, n, lr, beta1, beta2, eps, bc1, bc2, dtype_enum - ) - - hdl.barrier(channel=0) - - # Returns the updated views perfectly matching expected reference outputs - return tuple(params + exp_avg + exp_avg_sq) - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/42_zero1_optimizer_shard_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/42_zero1_optimizer_shard_cuda.py deleted file mode 100755 index dc387c1..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/42_zero1_optimizer_shard_cuda.py +++ /dev/null @@ -1,422 +0,0 @@ -""" -Strategy: -1. **Symmetric Memory & Persistent Buffers**: We allocate `flat_p_buf` (parameters) and `flat_g_buf` (gradients) in PyTorch symmetric memory, caching them across steps to eliminate PyTorch's native memory reallocation and process group overheads. -2. **Fused Reduce-Scatter + Adam + All-Gather**: Instead of a full `dist.all_reduce` followed by slicing and a final `all_gather`, we completely fuse the communication and optimizer step. Each rank loads only its partition of gradients. -3. **NVSwitch Multimem Hardware Acceleration**: If operating in BF16, the kernel uses `multimem.ld_reduce` to instantly sum gradient partitions across the NVLink switch directly into registers, executes the Adam step, and immediately broadcasts the updated weights to all peers simultaneously via `multimem.st`. -4. **UVA Fallback for Remainder/Dtypes**: Handles sizes non-divisible by 8 or FP32 inputs via direct device-to-device peer memory accesses. -5. **Device-Side Sync**: Synchronization between forward/backward passes and the fused optimizer is handled using fast device-side barriers (`hdl.barrier()`). -""" - -from __future__ import annotations - -import math - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__device__ __forceinline__ float2 do_adam_bf16x2_tmpl( - uint32_t g_sum_u32, uint32_t w_u32, - moment_t* m_ptr, moment_t* v_ptr, - float scale_g, float lr, float beta1, float beta2, float eps, - float bc1, float bc2 -) { - float2 g = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&g_sum_u32)); - float2 w = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&w_u32)); - - g.x *= scale_g; - g.y *= scale_g; - - float m0 = (float)m_ptr[0]; - float m1 = (float)m_ptr[1]; - float v0 = (float)v_ptr[0]; - float v1 = (float)v_ptr[1]; - - m0 = beta1 * m0 + (1.0f - beta1) * g.x; - m1 = beta1 * m1 + (1.0f - beta1) * g.y; - - v0 = beta2 * v0 + (1.0f - beta2) * g.x * g.x; - v1 = beta2 * v1 + (1.0f - beta2) * g.y * g.y; - - m_ptr[0] = (moment_t)m0; - m_ptr[1] = (moment_t)m1; - v_ptr[0] = (moment_t)v0; - v_ptr[1] = (moment_t)v1; - - float m_hat0 = m0 / bc1; - float m_hat1 = m1 / bc1; - - float v_hat0 = v0 / bc2; - float v_hat1 = v1 / bc2; - - w.x -= lr * m_hat0 / (sqrtf(v_hat0) + eps); - w.y -= lr * m_hat1 / (sqrtf(v_hat1) + eps); - - return w; -} - -__device__ __forceinline__ uint32_t pack_bf16x2(float2 w) { -#if __CUDA_ARCH__ >= 800 - __nv_bfloat162 res = __floats2bfloat162_rn(w.x, w.y); - return *reinterpret_cast(&res); -#else - return 0; -#endif -} - -template -__global__ void fused_zero1_multimem_bf16_kernel( - uint64_t g_multicast_base, - uint64_t p_multicast_base, - const __nv_bfloat16* __restrict__ local_w, - moment_t* __restrict__ m_part, - moment_t* __restrict__ v_part, - int64_t part_start, - int64_t part_size, - int world_size, - float lr, float beta1, float beta2, float eps, - float bc1, float bc2 -) { - int64_t idx_8 = ((int64_t)blockIdx.x * blockDim.x + threadIdx.x) * 8; - float scale_g = 1.0f / (float)world_size; - - if (idx_8 + 7 < part_size) { - int64_t global_idx = part_start + idx_8; - - uint64_t* g_ptr = reinterpret_cast(g_multicast_base + global_idx * 2); - uint64_t* p_ptr = reinterpret_cast(p_multicast_base + global_idx * 2); - - uint32_t gx, gy, gz, gw; - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(gx), "=r"(gy), "=r"(gz), "=r"(gw) - : "l"(g_ptr) - : "memory"); - - const uint32_t* local_w_u32 = reinterpret_cast(&local_w[global_idx]); - uint32_t wx = local_w_u32[0]; - uint32_t wy = local_w_u32[1]; - uint32_t wz = local_w_u32[2]; - uint32_t ww = local_w_u32[3]; - - float2 w01 = do_adam_bf16x2_tmpl(gx, wx, &m_part[idx_8], &v_part[idx_8], scale_g, lr, beta1, beta2, eps, bc1, bc2); - float2 w23 = do_adam_bf16x2_tmpl(gy, wy, &m_part[idx_8+2], &v_part[idx_8+2], scale_g, lr, beta1, beta2, eps, bc1, bc2); - float2 w45 = do_adam_bf16x2_tmpl(gz, wz, &m_part[idx_8+4], &v_part[idx_8+4], scale_g, lr, beta1, beta2, eps, bc1, bc2); - float2 w67 = do_adam_bf16x2_tmpl(gw, ww, &m_part[idx_8+6], &v_part[idx_8+6], scale_g, lr, beta1, beta2, eps, bc1, bc2); - - uint32_t out_x = pack_bf16x2(w01); - uint32_t out_y = pack_bf16x2(w23); - uint32_t out_z = pack_bf16x2(w45); - uint32_t out_w = pack_bf16x2(w67); - - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"(p_ptr), "r"(out_x), "r"(out_y), "r"(out_z), "r"(out_w) - : "memory"); - } -} - -template -__global__ void fused_zero1_uva_kernel( - const long long* __restrict__ g_ptrs, - const long long* __restrict__ p_ptrs, - weight_t* __restrict__ local_w, - moment_t* __restrict__ m_part, - moment_t* __restrict__ v_part, - int64_t part_start, - int64_t part_size, - int64_t m_part_offset, - int world_size, - float lr, float beta1, float beta2, float eps, - float bc1, float bc2 -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx < part_size) { - int64_t global_idx = part_start + idx; - int64_t local_idx = m_part_offset + idx; - - float g_sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const grad_t* peer_g = (const grad_t*)g_ptrs[r]; - g_sum += (float)peer_g[global_idx]; - } - g_sum /= world_size; - - float m = (float)m_part[local_idx]; - float v = (float)v_part[local_idx]; - - m = beta1 * m + (1.0f - beta1) * g_sum; - v = beta2 * v + (1.0f - beta2) * g_sum * g_sum; - - m_part[local_idx] = (moment_t)m; - v_part[local_idx] = (moment_t)v; - - float m_hat = m / bc1; - float v_hat = v / bc2; - - float w = (float)local_w[global_idx]; - w -= lr * m_hat / (sqrtf(v_hat) + eps); - - weight_t new_w = (weight_t)w; - - #pragma unroll - for (int r = 0; r < world_size; ++r) { - weight_t* peer_p = (weight_t*)p_ptrs[r]; - peer_p[global_idx] = new_w; - } - } -} - -__global__ void broadcast_uva_kernel_bf16( - __nv_bfloat16* local_p, const __nv_bfloat16* src_p, int64_t numel -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < numel; idx += (int64_t)blockDim.x * gridDim.x) { - local_p[idx] = src_p[idx]; - } -} -__global__ void broadcast_uva_kernel_f32( - float* local_p, const float* src_p, int64_t numel -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < numel; idx += (int64_t)blockDim.x * gridDim.x) { - local_p[idx] = src_p[idx]; - } -} - -void launch_fused_multimem_bf16( - uint64_t g_multicast, uint64_t p_multicast, - torch::Tensor local_w, torch::Tensor m_part, torch::Tensor v_part, - int64_t part_start, int64_t part_size, int world_size, - float lr, float beta1, float beta2, float eps, float bc1, float bc2 -) { - int threads = 256; - int blocks = (part_size / 8 + threads - 1) / threads; - if (blocks == 0) return; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (m_part.scalar_type() == at::ScalarType::Float) { - fused_zero1_multimem_bf16_kernel<<>>( - g_multicast, p_multicast, - reinterpret_cast(local_w.data_ptr()), - m_part.data_ptr(), - v_part.data_ptr(), - part_start, part_size, world_size, - lr, beta1, beta2, eps, bc1, bc2 - ); - } else { - fused_zero1_multimem_bf16_kernel<__nv_bfloat16><<>>( - g_multicast, p_multicast, - reinterpret_cast(local_w.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(m_part.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(v_part.data_ptr()), - part_start, part_size, world_size, - lr, beta1, beta2, eps, bc1, bc2 - ); - } -} - -template -void dispatch_uva_kernel( - torch::Tensor g_ptrs, torch::Tensor p_ptrs, - torch::Tensor local_w, torch::Tensor m_part, torch::Tensor v_part, - int64_t part_start, int64_t part_size, int64_t m_part_offset, int world_size, - float lr, float beta1, float beta2, float eps, float bc1, float bc2, cudaStream_t stream -) { - int threads = 256; - int blocks = (part_size + threads - 1) / threads; - if (blocks == 0) return; - - const long long* d_g_ptrs = (const long long*)g_ptrs.data_ptr(); - const long long* d_p_ptrs = (const long long*)p_ptrs.data_ptr(); - - fused_zero1_uva_kernel<<>>( - d_g_ptrs, d_p_ptrs, - reinterpret_cast(local_w.data_ptr()), - reinterpret_cast(m_part.data_ptr()), - reinterpret_cast(v_part.data_ptr()), - part_start, part_size, m_part_offset, world_size, - lr, beta1, beta2, eps, bc1, bc2 - ); -} - -void launch_fused_uva( - torch::Tensor g_ptrs, torch::Tensor p_ptrs, - torch::Tensor local_w, torch::Tensor m_part, torch::Tensor v_part, - int64_t part_start, int64_t part_size, int64_t m_part_offset, int world_size, - float lr, float beta1, float beta2, float eps, float bc1, float bc2 -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (local_w.scalar_type() == at::ScalarType::BFloat16) { - if (m_part.scalar_type() == at::ScalarType::Float) { - dispatch_uva_kernel<__nv_bfloat16, float>(g_ptrs, p_ptrs, local_w, m_part, v_part, part_start, part_size, m_part_offset, world_size, lr, beta1, beta2, eps, bc1, bc2, stream); - } else { - dispatch_uva_kernel<__nv_bfloat16, __nv_bfloat16>(g_ptrs, p_ptrs, local_w, m_part, v_part, part_start, part_size, m_part_offset, world_size, lr, beta1, beta2, eps, bc1, bc2, stream); - } - } else { - if (m_part.scalar_type() == at::ScalarType::Float) { - dispatch_uva_kernel(g_ptrs, p_ptrs, local_w, m_part, v_part, part_start, part_size, m_part_offset, world_size, lr, beta1, beta2, eps, bc1, bc2, stream); - } - } -} - -void launch_uva_broadcast(torch::Tensor local_p, int64_t src_ptr, int64_t numel) { - int threads = 256; - int blocks = (numel + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (local_p.scalar_type() == at::ScalarType::BFloat16) { - broadcast_uva_kernel_bf16<<>>( - reinterpret_cast<__nv_bfloat16*>(local_p.data_ptr()), - reinterpret_cast(src_ptr), - numel - ); - } else { - broadcast_uva_kernel_f32<<>>( - local_p.data_ptr(), - reinterpret_cast(src_ptr), - numel - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_fused_multimem_bf16", &launch_fused_multimem_bf16, "Fused Multimem BF16 Kernel"); - m.def("launch_fused_uva", &launch_fused_uva, "Fused UVA Kernel"); - m.def("launch_uva_broadcast", &launch_uva_broadcast, "UVA Broadcast Kernel"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_zero1_opt_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(n: int, dtype: torch.dtype, device: torch.device): - key = (n, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(n, device=device, dtype=dtype) - p_hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - p_ptrs = torch.tensor(p_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - g_buf = symm_mem.empty(n, device=device, dtype=dtype) - g_hdl = symm_mem.rendezvous(g_buf, dist.group.WORLD) - g_ptrs = torch.tensor(g_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, p_hdl, g_buf, g_hdl, p_ptrs, g_ptrs) - _symm_cache[key] = res - return res - -def solution( - X_local: Tensor, - y_local: Tensor, - W1: Tensor, - b1: Tensor, - W2: Tensor, - b2: Tensor, - exp_avg_part: Tensor, - exp_avg_sq_part: Tensor, - lr: float, - beta1: float, - beta2: float, - eps: float, - step: int, -) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - assert dist.is_initialized(), "torch.distributed must be initialized" - - rank = dist.get_rank() - world_size = dist.get_world_size() - - templates = [W1, b1, W2, b2] - flat_p = _flatten_dense_tensors(templates) - numel = flat_p.numel() - part = exp_avg_part.numel() - start = rank * part - assert numel == part * world_size - - buf, p_hdl, g_buf, g_hdl, p_ptrs, g_ptrs = _get_symm_state(numel, flat_p.dtype, flat_p.device) - - # 1. Sync & ensure weights match Rank 0 (replaces initial dist.broadcast) - p_hdl.barrier(channel=0) - buf.copy_(flat_p) - p_hdl.barrier(channel=1) - if rank != 0: - _get_ext().launch_uva_broadcast(buf, p_ptrs[0].item(), numel) - p_hdl.barrier(channel=2) - - # 2. Forward / Backward Pass - param_views = _unflatten_dense_tensors(buf, templates) - params = [t.detach().requires_grad_(True) for t in param_views] - - h = F.relu(F.linear(X_local, params[0], params[1])) - out = F.linear(h, params[2], params[3]) - loss = F.mse_loss(out, y_local) - loss.backward() - - # 3. Flatten Grads into Symm Mem - flat_g = _flatten_dense_tensors([p.grad for p in params]) - g_hdl.barrier(channel=0) - g_buf.copy_(flat_g) - g_hdl.barrier(channel=1) - - # 4. Fused Reduce-Scatter + Adam + All-Gather - m_part = exp_avg_part.clone() - v_part = exp_avg_sq_part.clone() - - bc1 = 1.0 - math.pow(beta1, step) - bc2 = 1.0 - math.pow(beta2, step) - - use_multimem = (buf.dtype == torch.bfloat16) and getattr(g_hdl, 'multicast_ptr', 0) != 0 and getattr(p_hdl, 'multicast_ptr', 0) != 0 - - if use_multimem: - numel_8 = part // 8 - if numel_8 > 0: - _get_ext().launch_fused_multimem_bf16( - int(g_hdl.multicast_ptr), int(p_hdl.multicast_ptr), - buf, m_part, v_part, - start, numel_8 * 8, world_size, - lr, beta1, beta2, eps, bc1, bc2 - ) - remainder = part % 8 - if remainder > 0: - _get_ext().launch_fused_uva( - g_ptrs, p_ptrs, - buf, m_part, v_part, - start + numel_8 * 8, remainder, numel_8 * 8, world_size, - lr, beta1, beta2, eps, bc1, bc2 - ) - else: - _get_ext().launch_fused_uva( - g_ptrs, p_ptrs, - buf, m_part, v_part, - start, part, 0, world_size, - lr, beta1, beta2, eps, bc1, bc2 - ) - - p_hdl.barrier(channel=3) - - out_params = _unflatten_dense_tensors(buf, templates) - return (*out_params, m_part, v_part) - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/43_zero2_optimizer_shard_grad_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/43_zero2_optimizer_shard_grad_cuda.py deleted file mode 100755 index aa18d9d..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/43_zero2_optimizer_shard_grad_cuda.py +++ /dev/null @@ -1,391 +0,0 @@ -import math -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem -from torch import Tensor -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Initial P2P Broadcast -// --------------------------------------------------------------------------- - -__global__ void p2p_copy_kernel(const __nv_bfloat16* src, __nv_bfloat16* dst, int64_t n) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += gridDim.x * blockDim.x) { - dst[idx] = src[idx]; - } -} - -void p2p_copy(int64_t src_ptr, torch::Tensor dst, int64_t n) { - const __nv_bfloat16* src = reinterpret_cast(src_ptr); - __nv_bfloat16* d = reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()); - int threads = 512; - int blocks = std::min(65535, (n + threads - 1) / threads); - p2p_copy_kernel<<>>(src, d, n); -} - -// --------------------------------------------------------------------------- -// Math and Multimem Intrinsics -// --------------------------------------------------------------------------- - -__device__ __forceinline__ float2 unpack_bf16x2(uint32_t v) { - __nv_bfloat162 tmp = *reinterpret_cast<__nv_bfloat162*>(&v); - return __bfloat1622float2(tmp); -} - -__device__ __forceinline__ uint32_t pack_bf16x2(float2 v) { - __nv_bfloat162 tmp = __float22bfloat162_rn(v); - return *reinterpret_cast(&tmp); -} - -template -__device__ __forceinline__ float load_state(StateT* ptr, int64_t idx); - -template <> -__device__ __forceinline__ float load_state(float* ptr, int64_t idx) { - return ptr[idx]; -} - -template <> -__device__ __forceinline__ float load_state<__nv_bfloat16>(__nv_bfloat16* ptr, int64_t idx) { - return __bfloat162float(ptr[idx]); -} - -template -__device__ __forceinline__ void store_state(StateT* ptr, int64_t idx, float val); - -template <> -__device__ __forceinline__ void store_state(float* ptr, int64_t idx, float val) { - ptr[idx] = val; -} - -template <> -__device__ __forceinline__ void store_state<__nv_bfloat16>(__nv_bfloat16* ptr, int64_t idx, float val) { - ptr[idx] = __float2bfloat16(val); -} - -template -__device__ __forceinline__ void process_8_elements_generic( - uint32_t g_val, uint32_t w_val, - float* m_vals, float* v_vals, - uint32_t& w_out, - float inv_world_size, float lr, float beta1, float beta2, float eps, float bc1, float bc2 -) { - float2 g = unpack_bf16x2(g_val); - g.x *= inv_world_size; - g.y *= inv_world_size; - - float2 w = unpack_bf16x2(w_val); - - // Element 1 - float m_x = m_vals[0] * beta1 + g.x * (1.0f - beta1); - float v_x = v_vals[0] * beta2 + g.x * g.x * (1.0f - beta2); - m_vals[0] = m_x; - v_vals[0] = v_x; - w.x += (m_x / bc1) / (sqrtf(v_x / bc2) + eps) * (-lr); - - // Element 2 - float m_y = m_vals[1] * beta1 + g.y * (1.0f - beta1); - float v_y = v_vals[1] * beta2 + g.y * g.y * (1.0f - beta2); - m_vals[1] = m_y; - v_vals[1] = v_y; - w.y += (m_y / bc1) / (sqrtf(v_y / bc2) + eps) * (-lr); - - w_out = pack_bf16x2(w); -} - -__device__ __forceinline__ void multimem_ld_reduce_bf16x4( - const uint64_t* addr, - uint32_t& r0, uint32_t& r1, uint32_t& r2, uint32_t& r3 -) { - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "l"(addr) - : "memory"); -} - -__device__ __forceinline__ void multimem_st_bf16x4( - const uint64_t* addr, - uint32_t x, uint32_t y, uint32_t z, uint32_t w -) { - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"(addr), "r"(x), "r"(y), "r"(z), "r"(w) - : "memory"); -} - -// --------------------------------------------------------------------------- -// Fused kernels -// --------------------------------------------------------------------------- - -template -__global__ void fused_multimem_kernel( - uint64_t multicast_grad_ptr, - uint64_t multicast_weight_ptr, - const __nv_bfloat16* __restrict__ local_w_part, - StateT* __restrict__ m_part, - StateT* __restrict__ v_part, - int64_t part_128, - float lr, float beta1, float beta2, float eps, float bc1, float bc2, float inv_world_size, - int rank -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = gridDim.x * blockDim.x; - - for (int64_t i = idx; i < part_128; i += stride) { - int64_t global_idx = rank * part_128 + i; - uint64_t* g_ptr = reinterpret_cast(multicast_grad_ptr) + global_idx * 2; - uint64_t* w_ptr = reinterpret_cast(multicast_weight_ptr) + global_idx * 2; - - uint32_t g0, g1, g2, g3; - multimem_ld_reduce_bf16x4(g_ptr, g0, g1, g2, g3); - - const uint32_t* local_w = reinterpret_cast(local_w_part) + i * 4; - uint32_t w0 = local_w[0], w1 = local_w[1], w2 = local_w[2], w3 = local_w[3]; - - int64_t m_offset = i * 8; - float m_vals[8], v_vals[8]; - #pragma unroll - for (int j = 0; j < 8; ++j) { - m_vals[j] = load_state(m_part, m_offset + j); - v_vals[j] = load_state(v_part, m_offset + j); - } - - uint32_t w0_out, w1_out, w2_out, w3_out; - process_8_elements_generic(g0, w0, &m_vals[0], &v_vals[0], w0_out, inv_world_size, lr, beta1, beta2, eps, bc1, bc2); - process_8_elements_generic(g1, w1, &m_vals[2], &v_vals[2], w1_out, inv_world_size, lr, beta1, beta2, eps, bc1, bc2); - process_8_elements_generic(g2, w2, &m_vals[4], &v_vals[4], w2_out, inv_world_size, lr, beta1, beta2, eps, bc1, bc2); - process_8_elements_generic(g3, w3, &m_vals[6], &v_vals[6], w3_out, inv_world_size, lr, beta1, beta2, eps, bc1, bc2); - - #pragma unroll - for (int j = 0; j < 8; ++j) { - store_state(m_part, m_offset + j, m_vals[j]); - store_state(v_part, m_offset + j, v_vals[j]); - } - - multimem_st_bf16x4(w_ptr, w0_out, w1_out, w2_out, w3_out); - } -} - -template -__global__ void p2p_fused_kernel( - const uint64_t* __restrict__ peer_grad_ptrs, - const uint64_t* __restrict__ peer_weight_ptrs, - __nv_bfloat16* __restrict__ local_w_part, - StateT* __restrict__ m_part, - StateT* __restrict__ v_part, - int64_t part, - float lr, float beta1, float beta2, float eps, float bc1, float bc2, float inv_world_size, - int rank, int world_size -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = gridDim.x * blockDim.x; - - for (int64_t i = idx; i < part; i += stride) { - int64_t global_idx = rank * part + i; - - float sum_g = 0.0f; - for (int r = 0; r < world_size; ++r) { - __nv_bfloat16* peer_g = reinterpret_cast<__nv_bfloat16*>(peer_grad_ptrs[r]); - sum_g += __bfloat162float(peer_g[global_idx]); - } - float g = sum_g * inv_world_size; - - float m = load_state(m_part, i); - float v = load_state(v_part, i); - - m = m * beta1 + g * (1.0f - beta1); - v = v * beta2 + g * g * (1.0f - beta2); - - store_state(m_part, i, m); - store_state(v_part, i, v); - - float m_hat = m / bc1; - float v_hat = v / bc2; - - float w = __bfloat162float(local_w_part[i]); - w += m_hat / (sqrtf(v_hat) + eps) * (-lr); - __nv_bfloat16 new_w = __float2bfloat16(w); - - for (int r = 0; r < world_size; ++r) { - __nv_bfloat16* peer_w = reinterpret_cast<__nv_bfloat16*>(peer_weight_ptrs[r]); - peer_w[global_idx] = new_w; - } - } -} - -void fused_step( - int64_t multicast_grad_ptr, - int64_t multicast_weight_ptr, - torch::Tensor grad_ptrs, - torch::Tensor weight_ptrs, - torch::Tensor weight_buf, - torch::Tensor m_part, - torch::Tensor v_part, - int64_t part, - float lr, float beta1, float beta2, float eps, float bc1, float bc2, float inv_world_size, - int rank, int world_size -) { - bool use_multimem = (multicast_grad_ptr != 0) && (multicast_weight_ptr != 0) && (part % 8 == 0); - - int threads = 512; - int blocks = std::min(65535, (int)((part + threads - 1) / threads)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const uint64_t* d_g_ptrs = reinterpret_cast(grad_ptrs.data_ptr()); - const uint64_t* d_w_ptrs = reinterpret_cast(weight_ptrs.data_ptr()); - __nv_bfloat16* local_w_part = reinterpret_cast<__nv_bfloat16*>(weight_buf.data_ptr()) + rank * part; - - if (m_part.dtype() == torch::kFloat32) { - float* m = m_part.data_ptr(); - float* v = v_part.data_ptr(); - if (use_multimem) { - int64_t part_128 = part / 8; - int blocks128 = std::min(65535, (int)((part_128 + threads - 1) / threads)); - fused_multimem_kernel<<>>( - multicast_grad_ptr, multicast_weight_ptr, local_w_part, m, v, part_128, - lr, beta1, beta2, eps, bc1, bc2, inv_world_size, rank - ); - } else { - p2p_fused_kernel<<>>( - d_g_ptrs, d_w_ptrs, local_w_part, m, v, part, - lr, beta1, beta2, eps, bc1, bc2, inv_world_size, rank, world_size - ); - } - } else if (m_part.dtype() == torch::kBFloat16) { - __nv_bfloat16* m = reinterpret_cast<__nv_bfloat16*>(m_part.data_ptr()); - __nv_bfloat16* v = reinterpret_cast<__nv_bfloat16*>(v_part.data_ptr()); - if (use_multimem) { - int64_t part_128 = part / 8; - int blocks128 = std::min(65535, (int)((part_128 + threads - 1) / threads)); - fused_multimem_kernel<__nv_bfloat16><<>>( - multicast_grad_ptr, multicast_weight_ptr, local_w_part, m, v, part_128, - lr, beta1, beta2, eps, bc1, bc2, inv_world_size, rank - ); - } else { - p2p_fused_kernel<__nv_bfloat16><<>>( - d_g_ptrs, d_w_ptrs, local_w_part, m, v, part, - lr, beta1, beta2, eps, bc1, bc2, inv_world_size, rank, world_size - ); - } - } else { - TORCH_CHECK(false, "Unsupported dtype for m_part/v_part"); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("p2p_copy", &p2p_copy, "P2P symmetric memory copy"); - m.def("fused_step", &fused_step, "Fused Multimem Reduce-Scatter, Adam, All-Gather"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("zero2_fused_opt_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_resources(n: int, dtype: torch.dtype, device: torch.device): - key = (n, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - grad_buf = symm_mem.empty(n, dtype=dtype, device=device) - hdl_g = symm_mem.rendezvous(grad_buf, dist.group.WORLD) - grad_ptrs = torch.tensor(hdl_g.buffer_ptrs, dtype=torch.int64, device=device) - - weight_buf = symm_mem.empty(n, dtype=dtype, device=device) - hdl_w = symm_mem.rendezvous(weight_buf, dist.group.WORLD) - weight_ptrs = torch.tensor(hdl_w.buffer_ptrs, dtype=torch.int64, device=device) - - res = (grad_buf, hdl_g, grad_ptrs, weight_buf, hdl_w, weight_ptrs) - _symm_cache[key] = res - return res - -def solution( - X_local: Tensor, y_local: Tensor, - W1: Tensor, b1: Tensor, W2: Tensor, b2: Tensor, - exp_avg_part: Tensor, exp_avg_sq_part: Tensor, - lr: float, beta1: float, beta2: float, eps: float, step: int, -) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - world_size = dist.get_world_size() - rank = dist.get_rank() - - templates = [W1, b1, W2, b2] - flat_p = _flatten_dense_tensors(templates) - assert flat_p.dtype == torch.bfloat16, "Kernel is highly optimized for BF16 weights and gradients" - - n = flat_p.numel() - part = exp_avg_part.numel() - grad_buf, hdl_g, grad_ptrs, weight_buf, hdl_w, weight_ptrs = _get_symm_resources(n, flat_p.dtype, flat_p.device) - - # Fast initial broadcast to sync weights via P2P - if rank == 0: - weight_buf.copy_(flat_p) - hdl_w.barrier(channel=0) - if rank != 0: - _get_ext().p2p_copy(weight_ptrs[0].item(), weight_buf, n) - hdl_w.barrier(channel=1) - - # Establish PyTorch parameters backed directly by symmetric memory - param_views = _unflatten_dense_tensors(weight_buf, templates) - params = [t.detach().requires_grad_(True) for t in param_views] - - # Forward / Backward pass - h = F.relu(F.linear(X_local, params[0], params[1])) - out = F.linear(h, params[2], params[3]) - loss = F.mse_loss(out, y_local) - loss.backward() - - # Flatten gradient mappings seamlessly to our symmetric buffer - flat_g = _flatten_dense_tensors([p.grad for p in params]).contiguous() - grad_buf.copy_(flat_g) - hdl_g.barrier(channel=0) - - # Prepare local partition for optimizer step - m_part = exp_avg_part.clone() - v_part = exp_avg_sq_part.clone() - bc1 = 1.0 - math.pow(beta1, step) - bc2 = 1.0 - math.pow(beta2, step) - inv_world_size = 1.0 / world_size - - multicast_grad_ptr = int(hdl_g.multicast_ptr) if hdl_g.has_multicast else 0 - multicast_weight_ptr = int(hdl_w.multicast_ptr) if hdl_w.has_multicast else 0 - - # Dispatched fused kernel: Hardware Reduce-Scatter -> Adam -> Hardware Broadcast Update - _get_ext().fused_step( - multicast_grad_ptr, - multicast_weight_ptr, - grad_ptrs, - weight_ptrs, - weight_buf, - m_part, - v_part, - part, - lr, beta1, beta2, eps, bc1, bc2, inv_world_size, - rank, world_size - ) - - # Barrier ensures execution finishes cleanly before tensors are copied off the persistent buffer - hdl_w.barrier(channel=2) - - out_flat_p = weight_buf.clone() - out_params = _unflatten_dense_tensors(out_flat_p, templates) - - return (*out_params, m_part, v_part) - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/44_fused_adam_grad_unshard_allgather_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/44_fused_adam_grad_unshard_allgather_cuda.py deleted file mode 100755 index fec2ae5..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/44_fused_adam_grad_unshard_allgather_cuda.py +++ /dev/null @@ -1,302 +0,0 @@ -""" -Strategy: -1. Fused Adam + Chunked Pull-based AllGather: We use a single custom CUDA kernel to overlap the parameter update with AllGather communication, maximizing GPU utilization. -2. Push-based Synchronization: Each block computes its local Adam shard and immediately pushes a completion step-counter flag to all peers via symmetric memory. -3. Fast Spin-Wait: Peers wait on their *local* symmetric flag buffer, minimizing NVLink polling traffic. -4. Symmetric Exchange Buffer: We allocate a shared symmetric exchange buffer of size `P` to hold the updated local shard, which peers then pull into their final output tensor. This keeps memory overhead to a bare minimum `O(P)` instead of caching `O(world_size * P)`. -5. Device-side Barrier: We use `symm_mem` channel barriers at the end of the step to safely allow buffer reuse across sequential optimizer calls without blocking the CPU. -""" - -import math -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from torch import Tensor - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -template -struct AdamMath; - -template <> -struct AdamMath { - static __device__ __forceinline__ float to_float(float x) { return x; } - static __device__ __forceinline__ float from_float(float x) { return x; } -}; - -template <> -struct AdamMath<__nv_bfloat16> { - static __device__ __forceinline__ float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } - static __device__ __forceinline__ __nv_bfloat16 from_float(float x) { return __float2bfloat16(x); } -}; - -template -__global__ void fused_adam_allgather_kernel( - const scalar_t* __restrict__ g, - scalar_t* __restrict__ m, - scalar_t* __restrict__ v, - scalar_t* __restrict__ w, - scalar_t* __restrict__ local_gathered, - scalar_t* __restrict__ my_exchange, - const uint64_t* __restrict__ flag_ptrs, - const uint64_t* __restrict__ exchange_ptrs, - float beta1, - float beta2, - float lr, - float eps, - float bc1, - float bc2, - int64_t P, - int W, - int B, - int my_rank, - int step -) { - int b_global = blockIdx.x; - int r = b_global / B; - int local_b = b_global % B; - - int64_t align_elements = 16 / sizeof(scalar_t); - int64_t chunk_size = (P + B - 1) / B; - chunk_size = ((chunk_size + align_elements - 1) / align_elements) * align_elements; - - int64_t start = local_b * chunk_size; - int64_t end = start + chunk_size; - if (end > P) end = P; - - if (r == my_rank) { - // 1. Compute Adam on our rank's shard - if (start < P) { - for (int64_t i = start + threadIdx.x; i < end; i += blockDim.x) { - float gi = AdamMath::to_float(g[i]); - float mi = AdamMath::to_float(m[i]); - float vi = AdamMath::to_float(v[i]); - float wi = AdamMath::to_float(w[i]); - - mi = mi * beta1 + gi * (1.0f - beta1); - vi = vi * beta2 + gi * gi * (1.0f - beta2); - - float m_hat = mi / bc1; - float v_hat = vi / bc2; - - wi += m_hat / (sqrtf(v_hat) + eps) * (-lr); - - scalar_t out_val = AdamMath::from_float(wi); - m[i] = AdamMath::from_float(mi); - v[i] = AdamMath::from_float(vi); - w[i] = out_val; - - local_gathered[r * P + i] = out_val; - my_exchange[i] = out_val; - } - } - - __syncthreads(); - // 2. Push signal to all peers that this chunk is ready - if (threadIdx.x == 0) { - __threadfence_system(); - for (int p = 0; p < W; ++p) { - volatile int* peer_flag_ptr = reinterpret_cast(flag_ptrs[p]); - peer_flag_ptr[my_rank * B + local_b] = step; - } - } - } else { - // 1. Spin-wait on LOCAL flag memory for the peer to finish - if (threadIdx.x == 0) { - int* my_flag_ptr = reinterpret_cast(flag_ptrs[my_rank]); - volatile int* wait_flag = (volatile int*)(&my_flag_ptr[r * B + local_b]); - while (*wait_flag < step) { -#if __CUDA_ARCH__ >= 700 - asm volatile("nanosleep.u32 20;" ::: "memory"); -#endif - } - __threadfence_system(); - } - __syncthreads(); - - // 2. Pull data from peer's exchange buffer via UVA - if (start < P) { - const scalar_t* peer_exchange = reinterpret_cast(exchange_ptrs[r]); - int64_t n = end - start; - const scalar_t* src_ptr = peer_exchange + start; - scalar_t* dst_ptr = local_gathered + r * P + start; - - int64_t i = threadIdx.x; - if (((uintptr_t)src_ptr % 16 == 0) && ((uintptr_t)dst_ptr % 16 == 0)) { - int64_t n_vec = n / align_elements; - const ulong2* src_vec = reinterpret_cast(src_ptr); - ulong2* dst_vec = reinterpret_cast(dst_ptr); - for (int64_t vi = threadIdx.x; vi < n_vec; vi += blockDim.x) { - dst_vec[vi] = src_vec[vi]; - } - i = n_vec * align_elements + threadIdx.x; - } - for (; i < n; i += blockDim.x) { - dst_ptr[i] = src_ptr[i]; - } - } - } -} - -void launch_bf16( - torch::Tensor g, torch::Tensor m, torch::Tensor v, torch::Tensor w, - torch::Tensor local_gathered, torch::Tensor my_exchange, - torch::Tensor flag_ptrs, torch::Tensor exchange_ptrs, - float beta1, float beta2, float lr, float eps, float bc1, float bc2, - int64_t P, int W, int B, int my_rank, int step, - int blocks, int threads -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - fused_adam_allgather_kernel<__nv_bfloat16><<>>( - reinterpret_cast(g.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(m.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(v.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(w.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(local_gathered.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(my_exchange.data_ptr()), - reinterpret_cast(flag_ptrs.data_ptr()), - reinterpret_cast(exchange_ptrs.data_ptr()), - beta1, beta2, lr, eps, bc1, bc2, - P, W, B, my_rank, step - ); -} - -void launch_fp32( - torch::Tensor g, torch::Tensor m, torch::Tensor v, torch::Tensor w, - torch::Tensor local_gathered, torch::Tensor my_exchange, - torch::Tensor flag_ptrs, torch::Tensor exchange_ptrs, - float beta1, float beta2, float lr, float eps, float bc1, float bc2, - int64_t P, int W, int B, int my_rank, int step, - int blocks, int threads -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - fused_adam_allgather_kernel<<>>( - g.data_ptr(), m.data_ptr(), v.data_ptr(), w.data_ptr(), - local_gathered.data_ptr(), my_exchange.data_ptr(), - reinterpret_cast(flag_ptrs.data_ptr()), - reinterpret_cast(exchange_ptrs.data_ptr()), - beta1, beta2, lr, eps, bc1, bc2, - P, W, B, my_rank, step - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_bf16", &launch_bf16, "Fused Adam + AllGather for bfloat16"); - m.def("launch_fp32", &launch_fp32, "Fused Adam + AllGather for float32"); -} -''' - -_ext = None -_ext_compiled = False - -def _ensure_ext(): - global _ext, _ext_compiled - if not _ext_compiled: - if dist.get_rank() == 0: - _ext = compile_cuda_extension("fused_adam_allgather_ext", CUDA_SRC) - dist.barrier() - if dist.get_rank() != 0: - _ext = compile_cuda_extension("fused_adam_allgather_ext", CUDA_SRC) - _ext_compiled = True - return _ext - -_exchange_state = None -_sync_step = 1 - -def _get_exchange_state(P: int, dtype: torch.dtype, device: torch.device, W: int): - global _exchange_state - - if _exchange_state is None or _exchange_state['P'] < P or _exchange_state['dtype'] != dtype: - new_P = max(P, _exchange_state['P'] if _exchange_state else 0) - - buf = symm_mem.empty(new_P, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - flags = symm_mem.empty(W * 128, dtype=torch.int32, device=device) - flags.zero_() - flags_hdl = symm_mem.rendezvous(flags, dist.group.WORLD) - flag_ptrs = torch.tensor(flags_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _exchange_state = { - 'P': new_P, 'dtype': dtype, - 'buf': buf, 'hdl': hdl, 'ptrs': ptrs, - 'flags': flags, 'flags_hdl': flags_hdl, 'flag_ptrs': flag_ptrs - } - - return _exchange_state - - -@torch.no_grad() -def solution( - grad_shard: Tensor, - master_shard: Tensor, - exp_avg: Tensor, - exp_avg_sq: Tensor, - lr: float, - beta1: float, - beta2: float, - eps: float, - step: int, -) -> Tensor: - global _sync_step - - W = dist.get_world_size() - rank = dist.get_rank() - P = grad_shard.numel() - device = grad_shard.device - dtype = master_shard.dtype - - ext = _ensure_ext() - state = _get_exchange_state(P, dtype, device, W) - - g = grad_shard.contiguous() - m = exp_avg.contiguous() - v = exp_avg_sq.contiguous() - w = master_shard.contiguous() - - gathered = torch.empty(W * P, dtype=dtype, device=device) - - bc1 = 1.0 - math.pow(beta1, step) - bc2 = 1.0 - math.pow(beta2, step) - - B = 128 - threads = 256 - blocks = W * B - - if dtype == torch.bfloat16: - ext.launch_bf16( - g, m, v, w, - gathered, state['buf'], state['flag_ptrs'], state['ptrs'], - beta1, beta2, lr, eps, bc1, bc2, - P, W, B, rank, _sync_step, - blocks, threads - ) - elif dtype == torch.float32: - ext.launch_fp32( - g, m, v, w, - gathered, state['buf'], state['flag_ptrs'], state['ptrs'], - beta1, beta2, lr, eps, bc1, bc2, - P, W, B, rank, _sync_step, - blocks, threads - ) - else: - raise ValueError(f"Unsupported dtype: {dtype}") - - _sync_step += 1 - - # Fast device-side barrier prevents proceeding CPU streams from - # enqueuing kernels that might overwrite the reused `exchange_buf` - # before all peer pulls have finished asynchronously. - state['hdl'].barrier(channel=0) - - return gathered - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/45_quantized_grad_allreduce_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/45_quantized_grad_allreduce_cuda.py deleted file mode 100755 index 4e44e5e..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/45_quantized_grad_allreduce_cuda.py +++ /dev/null @@ -1,290 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include -#include - -struct PeerPtrs { - const float* ptrs[8]; -}; - -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) : "l"(addr) : "memory" - ); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) : "l"(addr) : "memory" - ); - } while (tmp != 1u); -} - -template -struct CudaTypeTraits; - -template<> -struct CudaTypeTraits<__nv_bfloat16> { - static __device__ __forceinline__ float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } - static __device__ __forceinline__ __nv_bfloat16 from_float(float x) { return __float2bfloat16(x); } -}; - -template<> -struct CudaTypeTraits { - static __device__ __forceinline__ float to_float(float x) { return x; } - static __device__ __forceinline__ float from_float(float x) { return x; } -}; - -template -__global__ void fused_quant_dequant_reduce_kernel( - const T* __restrict__ input, - float* __restrict__ symm_buf, - PeerPtrs peer_ptrs, - T* __restrict__ out, - int64_t n, - int64_t block_size, - const uint64_t* __restrict__ signal_pad_ptrs, - int world_size, - int rank, - float inv_world_size -) { - int64_t nb = (n + block_size - 1) / block_size; - int bid = blockIdx.x; - - // Persistent threadblock loop over chunks - for (int64_t chunk_idx = bid; chunk_idx < nb; chunk_idx += gridDim.x) { - int64_t start_idx = chunk_idx * block_size; - int64_t end_idx = start_idx + block_size; - if (end_idx > n) end_idx = n; - - // --- 1. Compute Max for Chunk --- - float local_max = 0.0f; - for (int64_t i = start_idx + threadIdx.x; i < end_idx; i += blockDim.x) { - float val = CudaTypeTraits::to_float(input[i]); - val = fabsf(val); - if (val > local_max) local_max = val; - } - - unsigned int mask = 0xffffffff; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - local_max = fmaxf(local_max, __shfl_down_sync(mask, local_max, offset)); - } - - __shared__ float warp_max[32]; - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - - if (lane_id == 0) warp_max[warp_id] = local_max; - __syncthreads(); - - float block_max = 0.0f; - if (warp_id == 0) { - block_max = (lane_id < (blockDim.x / 32)) ? warp_max[lane_id] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - block_max = fmaxf(block_max, __shfl_down_sync(mask, block_max, offset)); - } - if (lane_id == 0) { - if (block_max < 1e-8f) block_max = 1e-8f; - warp_max[0] = block_max / 127.0f; - } - } - __syncthreads(); - - // --- 2. Quantize & Dequantize --- - float scale = warp_max[0]; - float inv_scale = 1.0f / scale; - - for (int64_t i = start_idx + threadIdx.x; i < end_idx; i += blockDim.x) { - float val = CudaTypeTraits::to_float(input[i]); - float q = roundf(val * inv_scale); - if (q > 127.0f) q = 127.0f; - if (q < -127.0f) q = -127.0f; - symm_buf[i] = q * scale; - } - __syncthreads(); - - // --- 3. Chunk-level Device Barrier --- - // channel_id avoids overlapping channel 0 which is used for global PyTorch barriers. - uint64_t channel_id = 1 + (chunk_idx % 65535); - if (threadIdx.x < world_size) { - unsigned int flat_tid = threadIdx.x; - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - - uint32_t* send_addr = reinterpret_cast( - remote_base + channel_id * (uint64_t)world_size * 4 + (uint64_t)rank * 4); - uint32_t* wait_addr = reinterpret_cast( - local_base + channel_id * (uint64_t)world_size * 4 + (uint64_t)flat_tid * 4); - - send_signal_acq_rel(send_addr); - wait_signal_acq_rel(wait_addr); - } - __syncthreads(); - - // --- 4. Peer-to-Peer FP32 Reduce --- - for (int64_t i = start_idx + threadIdx.x; i < end_idx; i += blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < 8; ++r) { - if (r < world_size) { - sum += peer_ptrs.ptrs[r][i]; - } - } - sum *= inv_world_size; - out[i] = CudaTypeTraits::from_float(sum); - } - __syncthreads(); - } -} - -void launch_fused_quant_reduce( - torch::Tensor input, - torch::Tensor symm_buf, - std::vector ptrs, - torch::Tensor out, - int64_t block_size, - torch::Tensor signal_pad_ptrs_tensor, - int world_size, - int rank -) { - TORCH_CHECK(world_size <= 8, "world_size > 8 is not supported"); - - int64_t n = input.numel(); - int64_t nb = (n + block_size - 1) / block_size; - - int threads = 256; - int blocks = nb < 132 ? nb : 132; - - PeerPtrs peer_ptrs; - for (int i = 0; i < world_size; i++) { - peer_ptrs.ptrs[i] = reinterpret_cast(ptrs[i]); - } - - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - float inv_world_size = 1.0f / world_size; - - if (input.dtype() == torch::kBFloat16) { - fused_quant_dequant_reduce_kernel<__nv_bfloat16><<>>( - reinterpret_cast(input.data_ptr()), - symm_buf.data_ptr(), - peer_ptrs, - reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), - n, - block_size, - d_signal, - world_size, - rank, - inv_world_size - ); - } else if (input.dtype() == torch::kFloat32) { - fused_quant_dequant_reduce_kernel<<>>( - input.data_ptr(), - symm_buf.data_ptr(), - peer_ptrs, - out.data_ptr(), - n, - block_size, - d_signal, - world_size, - rank, - inv_world_size - ); - } else { - TORCH_CHECK(false, "Unsupported dtype. Only BF16 and FP32 are supported."); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_fused_quant_reduce", &launch_fused_quant_reduce, "Fused quant-dequant and chunked P2P reduce"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_quant_reduce_ext", CUDA_SRC) - return _ext - -_resource_cache = {} - -def _get_resources(n: int, dtype: torch.dtype, device: torch.device): - key = (n, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - # symm_buf must be explicitly float32 to perfectly align with reference FP32 intermediate accumulation. - buf = symm_mem.empty(n, device=device, dtype=torch.float32) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - out = torch.empty(n, device=device, dtype=dtype) - - res = (buf, hdl, out) - _resource_cache[key] = res - return res - -@torch.no_grad() -def solution( - flat_grad: torch.Tensor, - block_size: int, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert block_size >= 1 - - world_size = dist.get_world_size() - assert world_size <= 8, "world_size > 8 is not supported by this NVLink optimized kernel" - - orig_shape = flat_grad.shape - x = flat_grad.contiguous().view(-1) - n = x.numel() - dtype = x.dtype - - if n == 0: - return flat_grad.clone() - - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - - buf, hdl, out = _get_resources(n, dtype, x.device) - - # Issue a global barrier prior to starting the persistent kernel to prevent - # ranks racing ahead and overwriting the symmetric buffer before slow peers read it. - hdl.barrier(channel=0) - - ext.launch_fused_quant_reduce( - x, - buf, - hdl.buffer_ptrs, - out, - block_size, - hdl.signal_pad_ptrs_dev, - world_size, - dist.get_rank() - ) - - return out.view(orig_shape) - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/46_reducescatter_fused_rmsnorm_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/46_reducescatter_fused_rmsnorm_cuda.py deleted file mode 100755 index 7a4fef6..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/46_reducescatter_fused_rmsnorm_cuda.py +++ /dev/null @@ -1,307 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from torch import Tensor - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -using bf16x8 = float4; // 16 bytes for 8 bfloat16s - -__global__ void fused_rs_rmsnorm_kernel( - const long long* __restrict__ peer_ptrs, - __nv_bfloat16* __restrict__ out, - const __nv_bfloat16* __restrict__ gamma, - int world_size, - int rank, - int chunk, - int hidden, - int rows, - float eps -) { - int i = blockIdx.x; // process one row per block - if (i >= rows) return; - - int tid = threadIdx.x; - int stride = blockDim.x; - - long long rank_offset = (long long)rank * chunk; - long long row_offset = (long long)i * hidden; - long long base_idx = rank_offset + row_offset; - - bool aligned = (hidden % 8 == 0); - int vec_hidden = aligned ? hidden / 8 : 0; - int tail_start = aligned ? hidden : 0; - - float sq_sum = 0.0f; - - // Pass 1: Reduce-scatter (sum and div), write to `out` intermediate, accumulate sq_sum - for (int k_vec = tid; k_vec < vec_hidden; k_vec += stride) { - float sums[8] = {0.0f}; - - for (int p = 0; p < world_size; ++p) { - const bf16x8* peer = reinterpret_cast(peer_ptrs[p] + base_idx); - bf16x8 vals = peer[k_vec]; - - const __nv_bfloat162* v2 = reinterpret_cast(&vals); - #pragma unroll - for (int j = 0; j < 4; ++j) { - __nv_bfloat162 pair = v2[j]; - const __nv_bfloat16* p_ptr = reinterpret_cast(&pair); - sums[j*2 + 0] += __bfloat162float(p_ptr[0]); - sums[j*2 + 1] += __bfloat162float(p_ptr[1]); - } - } - - __nv_bfloat162 out_v2[4]; - #pragma unroll - for (int j = 0; j < 4; ++j) { - float v0 = sums[j*2 + 0] / world_size; - float v1 = sums[j*2 + 1] / world_size; - - __nv_bfloat16 b0 = __float2bfloat16(v0); - __nv_bfloat16 b1 = __float2bfloat16(v1); - - __nv_bfloat162 b_pair; - __nv_bfloat16* b_ptr = reinterpret_cast<__nv_bfloat16*>(&b_pair); - b_ptr[0] = b0; - b_ptr[1] = b1; - out_v2[j] = b_pair; - - float f0 = __bfloat162float(b0); - float f1 = __bfloat162float(b1); - sq_sum += f0 * f0 + f1 * f1; - } - - bf16x8* out_vec = reinterpret_cast(out + row_offset); - out_vec[k_vec] = *reinterpret_cast(out_v2); - } - - // Tail / scalar pass - for (int k = tail_start + tid; k < hidden; k += stride) { - float sum = 0.0f; - for (int p = 0; p < world_size; ++p) { - const __nv_bfloat16* peer = reinterpret_cast(peer_ptrs[p]); - sum += __bfloat162float(peer[base_idx + k]); - } - sum /= world_size; - __nv_bfloat16 bval = __float2bfloat16(sum); - out[row_offset + k] = bval; - - float fval = __bfloat162float(bval); - sq_sum += fval * fval; - } - - // Warp and Block reduction for sq_sum - static __shared__ float shared_sq_sum[32]; // Accommodates up to 1024 threads - unsigned int mask = 0xffffffff; - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sq_sum += __shfl_down_sync(mask, sq_sum, offset); - } - - int lane = tid % 32; - int wid = tid / 32; - if (lane == 0) { - shared_sq_sum[wid] = sq_sum; - } - __syncthreads(); - - float total_sq_sum = 0.0f; - int num_warps = (blockDim.x + 31) / 32; - - if (wid == 0) { - if (lane < num_warps) { - total_sq_sum = shared_sq_sum[lane]; - } - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - total_sq_sum += __shfl_down_sync(mask, total_sq_sum, offset); - } - if (lane == 0) { - shared_sq_sum[0] = total_sq_sum; - } - } - __syncthreads(); - - total_sq_sum = shared_sq_sum[0]; - float mean_sq = total_sq_sum / hidden; - float rms = rsqrtf(mean_sq + eps); - - // Pass 2: Apply RMSNorm using `out` intermediate - for (int k_vec = tid; k_vec < vec_hidden; k_vec += stride) { - bf16x8* out_vec = reinterpret_cast(out + row_offset); - bf16x8 vals = out_vec[k_vec]; - - const bf16x8* gamma_vec_ptr = reinterpret_cast(gamma); - bf16x8 g_vals = gamma_vec_ptr[k_vec]; - - const __nv_bfloat162* v2 = reinterpret_cast(&vals); - const __nv_bfloat162* g2 = reinterpret_cast(&g_vals); - - __nv_bfloat162 out_v2[4]; - #pragma unroll - for (int j = 0; j < 4; ++j) { - __nv_bfloat162 pair = v2[j]; - const __nv_bfloat16* p_ptr = reinterpret_cast(&pair); - float f_x = __bfloat162float(p_ptr[0]); - float f_y = __bfloat162float(p_ptr[1]); - - __nv_bfloat162 g_pair = g2[j]; - const __nv_bfloat16* g_ptr = reinterpret_cast(&g_pair); - float g_x = __bfloat162float(g_ptr[0]); - float g_y = __bfloat162float(g_ptr[1]); - - float v0 = f_x * rms * g_x; - float v1 = f_y * rms * g_y; - - __nv_bfloat16 b0 = __float2bfloat16(v0); - __nv_bfloat16 b1 = __float2bfloat16(v1); - - __nv_bfloat162 b_pair; - __nv_bfloat16* b_ptr = reinterpret_cast<__nv_bfloat16*>(&b_pair); - b_ptr[0] = b0; - b_ptr[1] = b1; - out_v2[j] = b_pair; - } - out_vec[k_vec] = *reinterpret_cast(out_v2); - } - - // Tail / scalar pass - for (int k = tail_start + tid; k < hidden; k += stride) { - float val = __bfloat162float(out[row_offset + k]); - float g = __bfloat162float(gamma[k]); - out[row_offset + k] = __float2bfloat16(val * rms * g); - } -} - -void launch_fused_rs_rmsnorm( - torch::Tensor ptrs_tensor, - torch::Tensor out, - torch::Tensor gamma, - int world_size, - int rank, - int chunk, - int hidden, - int rows, - float eps -) { - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - __nv_bfloat16* d_out = reinterpret_cast<__nv_bfloat16*>(out.data_ptr()); - const __nv_bfloat16* d_gamma = reinterpret_cast(gamma.data_ptr()); - - int threads = 256; - int blocks = rows; - if (blocks == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - fused_rs_rmsnorm_kernel<<>>( - d_ptrs, d_out, d_gamma, world_size, rank, chunk, hidden, rows, eps - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_fused_rs_rmsnorm", &launch_fused_rs_rmsnorm, "Fused Reduce-Scatter and RMSNorm"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_rs_rmsnorm_ext", CUDA_SRC) - return _ext - - -_resource_cache = {} - -def _get_resources(n: int, dtype: torch.dtype, device: torch.device): - """ - Returns handles, pointer tensors, and buffers. - Uses double-buffering mapped to symm_mem to avoid blocking CPU syncs - while preventing buffer overwrites in tight recurrent loops. - """ - key = (n, dtype, device) - if key not in _resource_cache: - bufs = [] - hdls = [] - ptrs = [] - for _ in range(2): - buf = symm_mem.empty(n, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - bufs.append(buf) - hdls.append(hdl) - ptrs.append(ptrs_tensor) - _resource_cache[key] = {'bufs': bufs, 'hdls': hdls, 'ptrs': ptrs, 'idx': 0} - - cache = _resource_cache[key] - idx = cache['idx'] - cache['idx'] = (idx + 1) % 2 - return cache['bufs'][idx], cache['hdls'][idx], cache['ptrs'][idx] - - -@torch.no_grad() -def solution( - rs_input_1d: Tensor, - gamma: Tensor, - eps: float, -) -> Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - - world_size = dist.get_world_size() - rank = dist.get_rank() - - n = rs_input_1d.numel() - assert n % world_size == 0 - chunk = n // world_size - - hidden = gamma.numel() - assert chunk % hidden == 0, f"chunk ({chunk}) must divide hidden ({hidden})" - rows = chunk // hidden - - input_bf16 = rs_input_1d.contiguous() - if input_bf16.dtype != torch.bfloat16: - input_bf16 = input_bf16.to(torch.bfloat16) - - gamma_bf16 = gamma.contiguous() - if gamma_bf16.dtype != torch.bfloat16: - gamma_bf16 = gamma_bf16.to(torch.bfloat16) - - if rank == 0: - _get_ext() - dist.barrier() # Synchronize to ensure cleanly initialized CUDA compilation step limits out-of-order execution - - buf, hdl, ptrs_tensor = _get_resources(n, torch.bfloat16, input_bf16.device) - - # Ensure current writes to symm_mem complete correctly - buf.copy_(input_bf16) - hdl.barrier(channel=0) - - out = torch.empty((rows, hidden), dtype=torch.bfloat16, device=input_bf16.device) - - _get_ext().launch_fused_rs_rmsnorm( - ptrs_tensor, - out, - gamma_bf16, - world_size, - rank, - chunk, - hidden, - rows, - eps - ) - - if out.dtype != rs_input_1d.dtype: - out = out.to(rs_input_1d.dtype) - - return out - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/47_fsdp_adamw_sharded_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/47_fsdp_adamw_sharded_cuda.py deleted file mode 100755 index 5d2a810..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/47_fsdp_adamw_sharded_cuda.py +++ /dev/null @@ -1,209 +0,0 @@ -""" -Strategy: -- **Single Fused CUDA Kernel**: We replace the sequence of native PyTorch element-wise operations (clone, mul, add, addcmul, div, sqrt) with one fused C++ extension kernel. This eliminates multiple memory roundtrips to HBM and slashes CPU kernel launch overhead. -- **Precision Management**: Following PyTorch's mixed-precision optimizer standards, the kernel dynamically casts memory-bound types (like bfloat16 or float16) to float32 for all intermediate calculations (moments, bias correction, weight decay), writing the accurate results back out in the requested tensor dtype. -- **Maximal H100 Bandwidth**: Uses a simple grid-stride loop mapped onto the H100 SMs, leaning on the massive L2 cache for perfect memory coalescing on flat FSDP parameter shards without requiring restrictive vector alignment. -""" - -from __future__ import annotations - -import math -import torch -from torch import Tensor -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include - -template -__global__ void adamw_kernel( - const scalar_t_p* __restrict__ p_in, - const scalar_t_p* __restrict__ g_in, - const scalar_t_m* __restrict__ m_in, - const scalar_t_m* __restrict__ v_in, - scalar_t_p* __restrict__ p_out, - scalar_t_m* __restrict__ m_out, - scalar_t_m* __restrict__ v_out, - float lr, - float beta1, - float beta2, - float eps, - float weight_decay, - float bc1, - float bc2, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - #pragma unroll 4 - for (; idx < n; idx += stride) { - float p = static_cast(p_in[idx]); - float g = static_cast(g_in[idx]); - float m = static_cast(m_in[idx]); - float v = static_cast(v_in[idx]); - - // Update biased first moment estimate - m = m * beta1 + g * (1.0f - beta1); - - // Update biased second raw moment estimate - v = v * beta2 + g * g * (1.0f - beta2); - - // Compute bias-corrected moments - float m_hat = m / bc1; - float v_hat = v / bc2; - - float denom = sqrtf(v_hat) + eps; - - // Decoupled weight decay and Adam step: - // theta_new = theta - lr * (m_hat / denom) - lr * weight_decay * theta - float p_new = p - lr * (m_hat / denom); - p_new = p_new - (lr * weight_decay) * p; - - p_out[idx] = static_cast(p_new); - m_out[idx] = static_cast(m); - v_out[idx] = static_cast(v); - } -} - -void launch_adamw( - torch::Tensor p_in, - torch::Tensor g_in, - torch::Tensor m_in, - torch::Tensor v_in, - torch::Tensor p_out, - torch::Tensor m_out, - torch::Tensor v_out, - float lr, - float beta1, - float beta2, - float eps, - float weight_decay, - float bc1, - float bc2 -) { - int64_t n = p_in.numel(); - if (n == 0) return; - - // Use 512 threads per block and enough blocks to saturate H100 SMs, - // relying on the grid-stride loop for larger element counts. - const int threads = 512; - const int blocks = std::min((int)((n + threads - 1) / threads), 2048); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - // Dispatch across the common precision configurations - if (p_in.scalar_type() == torch::kBFloat16 && m_in.scalar_type() == torch::kBFloat16) { - adamw_kernel<<>>( - p_in.data_ptr(), g_in.data_ptr(), - m_in.data_ptr(), v_in.data_ptr(), - p_out.data_ptr(), m_out.data_ptr(), v_out.data_ptr(), - lr, beta1, beta2, eps, weight_decay, bc1, bc2, n - ); - } else if (p_in.scalar_type() == torch::kFloat32 && m_in.scalar_type() == torch::kFloat32) { - adamw_kernel<<>>( - p_in.data_ptr(), g_in.data_ptr(), - m_in.data_ptr(), v_in.data_ptr(), - p_out.data_ptr(), m_out.data_ptr(), v_out.data_ptr(), - lr, beta1, beta2, eps, weight_decay, bc1, bc2, n - ); - } else if (p_in.scalar_type() == torch::kBFloat16 && m_in.scalar_type() == torch::kFloat32) { - adamw_kernel<<>>( - p_in.data_ptr(), g_in.data_ptr(), - m_in.data_ptr(), v_in.data_ptr(), - p_out.data_ptr(), m_out.data_ptr(), v_out.data_ptr(), - lr, beta1, beta2, eps, weight_decay, bc1, bc2, n - ); - } else if (p_in.scalar_type() == torch::kHalf && m_in.scalar_type() == torch::kFloat32) { - adamw_kernel<<>>( - p_in.data_ptr(), g_in.data_ptr(), - m_in.data_ptr(), v_in.data_ptr(), - p_out.data_ptr(), m_out.data_ptr(), v_out.data_ptr(), - lr, beta1, beta2, eps, weight_decay, bc1, bc2, n - ); - } else if (p_in.scalar_type() == torch::kHalf && m_in.scalar_type() == torch::kHalf) { - adamw_kernel<<>>( - p_in.data_ptr(), g_in.data_ptr(), - m_in.data_ptr(), v_in.data_ptr(), - p_out.data_ptr(), m_out.data_ptr(), v_out.data_ptr(), - lr, beta1, beta2, eps, weight_decay, bc1, bc2, n - ); - } else { - TORCH_CHECK(false, "Unsupported dtype combination for Fused AdamW"); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_adamw", &launch_adamw, "Fused AdamW C++ kernel"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_adamw_sharded_ext", CUDA_SRC) - return _ext - - -@torch.no_grad() -def solution( - flat_param_shard: Tensor, - flat_grad_shard: Tensor, - exp_avg_shard: Tensor, - exp_avg_sq_shard: Tensor, - lr: float, - beta1: float, - beta2: float, - eps: float, - weight_decay: float, - step: int, -) -> tuple[Tensor, Tensor, Tensor]: - """ - Decoupled AdamW (Loshchilov & Hutter) on one rank's shards. - """ - assert step >= 1 - assert ( - flat_param_shard.shape == flat_grad_shard.shape == exp_avg_shard.shape == exp_avg_sq_shard.shape - ) - - # Ensure tensors are contiguous and valid for the CUDA kernel pointers - flat_param_shard = flat_param_shard.contiguous() - flat_grad_shard = flat_grad_shard.contiguous() - exp_avg_shard = exp_avg_shard.contiguous() - exp_avg_sq_shard = exp_avg_sq_shard.contiguous() - - # Allocate outputs matching the out-of-place signature - out_param = torch.empty_like(flat_param_shard) - out_m = torch.empty_like(exp_avg_shard) - out_v = torch.empty_like(exp_avg_sq_shard) - - # Pre-calculate bias correction factors on the host - bc1 = 1.0 - math.pow(beta1, step) - bc2 = 1.0 - math.pow(beta2, step) - - # Dispatch to customized fused kernel - _get_ext().launch_adamw( - flat_param_shard, - flat_grad_shard, - exp_avg_shard, - exp_avg_sq_shard, - out_param, - out_m, - out_v, - lr, - beta1, - beta2, - eps, - weight_decay, - bc1, - bc2 - ) - - return out_param, out_m, out_v - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/48_fsdp_step_e2e_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/48_fsdp_step_e2e_cuda.py deleted file mode 100755 index de1e2c5..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/48_fsdp_step_e2e_cuda.py +++ /dev/null @@ -1,287 +0,0 @@ -""" -Strategy: -1. P2P All-Gather: We allocate a symmetric buffer for the full model parameters. Each rank writes its param shard, then we use a single fast custom CUDA kernel to directly fetch peers' shards over NVLink. -2. Forward/Backward Pass: Executed using PyTorch native ops directly on the unflattened views of the symmetric buffer. -3. Fused Reduce-Scatter & AdamW: We allocate a symmetric buffer for the full gradients. After PyTorch computes the full flat gradient, each rank writes it to symmetric memory. A second custom CUDA kernel reads each rank's assigned gradient slice directly from all peers, performs the reduction, and natively applies the AdamW update directly into the newly allocated output param/momentum shards in a single pass. This avoids multiple read/writes and intermediate gradient clones. -""" - -from __future__ import annotations - -import math -from typing import Sequence - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -// Exactly emulate PyTorch's elementwise eager bfloat16 truncations -__device__ __forceinline__ float trunc_bf16(float x) { - return __bfloat162float(__float2bfloat16(x)); -} - -__global__ void all_gather_kernel( - const long long* __restrict__ peer_full_flats, - int64_t p, - int world_size, - int rank -) { - int64_t total = p * world_size; - for (int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - idx < total; - idx += (int64_t)gridDim.x * blockDim.x) { - - int target_rank = idx / p; - if (target_rank != rank) { - __nv_bfloat16* local_ptr = (__nv_bfloat16*)peer_full_flats[rank]; - const __nv_bfloat16* remote_ptr = (const __nv_bfloat16*)peer_full_flats[target_rank]; - local_ptr[idx] = remote_ptr[idx]; - } - } -} - -__global__ void rs_adamw_kernel( - const long long* __restrict__ peer_grads, - __nv_bfloat16* __restrict__ theta_out, - __nv_bfloat16* __restrict__ m_out, - __nv_bfloat16* __restrict__ v_out, - const __nv_bfloat16* __restrict__ theta_in, - const __nv_bfloat16* __restrict__ m_in, - const __nv_bfloat16* __restrict__ v_in, - int64_t p, - int world_size, - int rank, - float lr, - float beta1, - float beta2, - float eps, - float weight_decay, - float bc1, - float bc2 -) { - for (int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - i < p; - i += (int64_t)gridDim.x * blockDim.x) { - - int64_t offset = rank * p + i; - float g_sum = 0.0f; - - #pragma unroll - for (int w = 0; w < world_size; ++w) { - const __nv_bfloat16* src = (const __nv_bfloat16*)peer_grads[w]; - g_sum += __bfloat162float(src[offset]); - } - - float g = g_sum / world_size; - g = trunc_bf16(g); - - float m_val = __bfloat162float(m_in[i]); - float v_val = __bfloat162float(v_in[i]); - float orig_theta = __bfloat162float(theta_in[i]); - float theta_val = orig_theta; - - m_val = trunc_bf16(m_val * beta1); - m_val = trunc_bf16(m_val + g * (1.0f - beta1)); - - v_val = trunc_bf16(v_val * beta2); - v_val = trunc_bf16(v_val + g * g * (1.0f - beta2)); - - float m_hat = trunc_bf16(m_val / bc1); - float v_hat = trunc_bf16(v_val / bc2); - float denom = trunc_bf16(trunc_bf16(sqrtf(v_hat)) + eps); - - float step_term = trunc_bf16(m_hat / denom); - theta_val = trunc_bf16(theta_val - lr * step_term); - theta_val = trunc_bf16(theta_val - lr * weight_decay * orig_theta); - - m_out[i] = __float2bfloat16(m_val); - v_out[i] = __float2bfloat16(v_val); - theta_out[i] = __float2bfloat16(theta_val); - } -} - -void launch_all_gather( - torch::Tensor peer_ptrs, - int64_t p, - int world_size, - int rank -) { - int64_t total = p * world_size; - if (total == 0) return; - int threads = 512; - int blocks = std::max(1, std::min((int)((total + threads - 1) / threads), 65535)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - all_gather_kernel<<>>( - (const long long*)peer_ptrs.data_ptr(), - p, - world_size, - rank - ); -} - -void launch_rs_adamw( - torch::Tensor peer_grads_ptrs, - torch::Tensor theta_out, - torch::Tensor m_out, - torch::Tensor v_out, - torch::Tensor theta_in, - torch::Tensor m_in, - torch::Tensor v_in, - int64_t p, - int world_size, - int rank, - float lr, - float beta1, - float beta2, - float eps, - float weight_decay, - float bc1, - float bc2 -) { - if (p == 0) return; - int threads = 512; - int blocks = std::max(1, std::min((int)((p + threads - 1) / threads), 65535)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - rs_adamw_kernel<<>>( - (const long long*)peer_grads_ptrs.data_ptr(), - (__nv_bfloat16*)theta_out.data_ptr(), - (__nv_bfloat16*)m_out.data_ptr(), - (__nv_bfloat16*)v_out.data_ptr(), - (const __nv_bfloat16*)theta_in.data_ptr(), - (const __nv_bfloat16*)m_in.data_ptr(), - (const __nv_bfloat16*)v_in.data_ptr(), - p, - world_size, - rank, - lr, - beta1, - beta2, - eps, - weight_decay, - bc1, - bc2 - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_all_gather", &launch_all_gather, "Custom P2P all gather"); - m.def("launch_rs_adamw", &launch_rs_adamw, "Fused ReduceScatter and AdamW"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fsdp_step_e2e_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(p: int, world_size: int, dtype: torch.dtype, device: torch.device): - key = (p, world_size, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - full_flat_buf = symm_mem.empty(world_size * p, device=device, dtype=dtype) - full_flat_hdl = symm_mem.rendezvous(full_flat_buf, dist.group.WORLD) - full_flat_ptrs = torch.tensor(full_flat_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - sym_grad_buf = symm_mem.empty(world_size * p, device=device, dtype=dtype) - sym_grad_hdl = symm_mem.rendezvous(sym_grad_buf, dist.group.WORLD) - sym_grad_ptrs = torch.tensor(sym_grad_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (full_flat_buf, full_flat_hdl, full_flat_ptrs, sym_grad_buf, sym_grad_hdl, sym_grad_ptrs) - _symm_cache[key] = res - return res - -def solution( - X_local: Tensor, - y_local: Tensor, - flat_param_shard: Tensor, - param_shapes: Sequence[tuple[int, ...]], - exp_avg_shard: Tensor, - exp_avg_sq_shard: Tensor, - lr: float, - beta1: float, - beta2: float, - eps: float, - weight_decay: float, - step: int, -) -> tuple[Tensor, Tensor, Tensor]: - - assert dist.is_initialized(), "torch.distributed must be initialized" - assert step >= 1 - - world_size = dist.get_world_size() - rank = dist.get_rank() - - flat_param_shard = flat_param_shard.contiguous() - exp_avg_shard = exp_avg_shard.contiguous() - exp_avg_sq_shard = exp_avg_sq_shard.contiguous() - - p = flat_param_shard.numel() - dtype = flat_param_shard.dtype - device = flat_param_shard.device - - full_flat, full_flat_hdl, full_flat_ptrs, sym_grad, sym_grad_hdl, sym_grad_ptrs = \ - _get_symm_state(p, world_size, dtype, device) - - if rank == 0: - _get_ext() - dist.barrier() - - # 1. P2P All-Gather parameter chunks - full_flat[rank * p : (rank + 1) * p].copy_(flat_param_shard) - full_flat_hdl.barrier(channel=0) - - _get_ext().launch_all_gather(full_flat_ptrs, p, world_size, rank) - - # 2. PyTorch Forward & Backward Pass on unflattened views - templates = [torch.empty(shape, dtype=dtype, device=device) for shape in param_shapes] - params_f = _unflatten_dense_tensors(full_flat, templates) - params = [t.detach().requires_grad_(True) for t in params_f] - - h = F.relu(F.linear(X_local, params[0], params[1])) - out = F.linear(h, params[2], params[3]) - loss = F.mse_loss(out, y_local) - loss.backward() - - # 3. Share gradients across peers - flat_g = _flatten_dense_tensors([x.grad for x in params]) - sym_grad.copy_(flat_g) - sym_grad_hdl.barrier(channel=0) - - theta_out = torch.empty_like(flat_param_shard) - m_out = torch.empty_like(exp_avg_shard) - v_out = torch.empty_like(exp_avg_sq_shard) - - bc1 = 1.0 - math.pow(beta1, step) - bc2 = 1.0 - math.pow(beta2, step) - - # 4. Read gradient chunks, Reduce, and Step AdamW seamlessly - _get_ext().launch_rs_adamw( - sym_grad_ptrs, - theta_out, m_out, v_out, - flat_param_shard, exp_avg_shard, exp_avg_sq_shard, - p, world_size, rank, - float(lr), float(beta1), float(beta2), float(eps), float(weight_decay), float(bc1), float(bc2) - ) - - sym_grad_hdl.barrier(channel=1) - - return theta_out, m_out, v_out - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/49_fsdp_and_tp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/49_fsdp_and_tp_cuda.py deleted file mode 100755 index f6daaf0..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/49_fsdp_and_tp_cuda.py +++ /dev/null @@ -1,339 +0,0 @@ -from __future__ import annotations - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from torch import Tensor -import torch.nn.functional as F - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// 1. FSDP Gather Kernels (Directly copying full weights from peer's UVA memory) -// --------------------------------------------------------------------------- -__global__ void gather_w1_w2_kernel( - const uint64_t* __restrict__ ptrs, - __nv_bfloat16* __restrict__ w1_full, - __nv_bfloat16* __restrict__ w2_full, - int n_fsdp, - int n_tp, - int tp_rank, - int64_t K -) { - int64_t total_elements = n_fsdp * K; - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - for (; tid < total_elements; tid += stride) { - int j = tid / K; - int64_t idx = tid % K; - int peer_rank = j * n_tp + tp_rank; // Ranks sharing the same tp_rank - const __nv_bfloat16* peer_buf = (const __nv_bfloat16*)ptrs[peer_rank]; - - w1_full[tid] = peer_buf[idx]; - w2_full[tid] = peer_buf[K + idx]; - } -} - -__global__ void gather_w3_kernel( - const uint64_t* __restrict__ ptrs, - __nv_bfloat16* __restrict__ w3_full, - int n_fsdp, - int n_tp, - int tp_rank, - int rows, - int cols -) { - int64_t K = (int64_t)rows * cols; - int64_t total_elements = n_fsdp * K; - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - for (; tid < total_elements; tid += stride) { - int j = tid / K; - int64_t idx = tid % K; - int peer_rank = j * n_tp + tp_rank; - - int r = idx / cols; - int c = idx % cols; - - const __nv_bfloat16* peer_buf = (const __nv_bfloat16*)ptrs[peer_rank]; - - // Strided copy since W3 is gathered along dim=1 - int64_t out_idx = (int64_t)r * (cols * n_fsdp) + j * cols + c; - w3_full[out_idx] = peer_buf[2 * K + idx]; - } -} - -// --------------------------------------------------------------------------- -// 2. Fused SwiGLU (z = silu(x1) * x2) -// --------------------------------------------------------------------------- -__global__ void swiglu_bf16x2_kernel( - const __nv_bfloat162* __restrict__ x1, - const __nv_bfloat162* __restrict__ x2, - __nv_bfloat162* __restrict__ z, - int64_t numel_2 -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - for (int64_t i = idx; i < numel_2; i += stride) { - float2 v1 = __bfloat1622float2(x1[i]); - float2 v2 = __bfloat1622float2(x2[i]); - - float sig_x = 1.0f / (1.0f + expf(-v1.x)); - float sig_y = 1.0f / (1.0f + expf(-v1.y)); - - float2 res; - res.x = v1.x * sig_x * v2.x; - res.y = v1.y * sig_y * v2.y; - - z[i] = __float22bfloat162_rn(res); - } -} - -__global__ void swiglu_odd_kernel( - const __nv_bfloat16* __restrict__ x1, - const __nv_bfloat16* __restrict__ x2, - __nv_bfloat16* __restrict__ z, - int64_t idx -) { - float val1 = __bfloat162float(x1[idx]); - float val2 = __bfloat162float(x2[idx]); - float sig = 1.0f / (1.0f + expf(-val1)); - z[idx] = __float2bfloat16(val1 * sig * val2); -} - -// --------------------------------------------------------------------------- -// 3. Tensor Parallel All-Reduce -// --------------------------------------------------------------------------- -__global__ void tp_allreduce_bf16x2_kernel( - const uint64_t* __restrict__ ptrs, - __nv_bfloat16* __restrict__ y_out, - int n_tp, - int n_fsdp, - int fsdp_rank, - int64_t numel -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - int64_t numel_2 = numel / 2; - - for (int64_t i = idx; i < numel_2; i += stride) { - float2 sum = {0.0f, 0.0f}; - for (int p = 0; p < n_tp; ++p) { - int peer_rank = fsdp_rank * n_tp + p; - const __nv_bfloat162* peer_buf = (const __nv_bfloat162*)ptrs[peer_rank]; - float2 val = __bfloat1622float2(peer_buf[i]); - sum.x += val.x; - sum.y += val.y; - } - ((__nv_bfloat162*)y_out)[i] = __float22bfloat162_rn(sum); - } - - if (idx == 0 && (numel % 2) != 0) { - int64_t last_idx = numel - 1; - float sum = 0.0f; - for (int p = 0; p < n_tp; ++p) { - int peer_rank = fsdp_rank * n_tp + p; - const __nv_bfloat16* peer_buf = (const __nv_bfloat16*)ptrs[peer_rank]; - sum += __bfloat162float(peer_buf[last_idx]); - } - y_out[last_idx] = __float2bfloat16(sum); - } -} - -// --------------------------------------------------------------------------- -// Host Bindings -// --------------------------------------------------------------------------- -void launch_gather_w1_w2( - torch::Tensor ptrs, torch::Tensor w1_full, torch::Tensor w2_full, - int n_fsdp, int n_tp, int tp_rank, int64_t K -) { - int threads = 256; - int blocks = (n_fsdp * K + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - gather_w1_w2_kernel<<>>( - reinterpret_cast(ptrs.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(w1_full.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(w2_full.data_ptr()), - n_fsdp, n_tp, tp_rank, K - ); -} - -void launch_gather_w3( - torch::Tensor ptrs, torch::Tensor w3_full, - int n_fsdp, int n_tp, int tp_rank, int rows, int cols -) { - int64_t K = (int64_t)rows * cols; - int threads = 256; - int blocks = (n_fsdp * K + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - gather_w3_kernel<<>>( - reinterpret_cast(ptrs.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(w3_full.data_ptr()), - n_fsdp, n_tp, tp_rank, rows, cols - ); -} - -void launch_swiglu(torch::Tensor x1, torch::Tensor x2, torch::Tensor z, int64_t numel) { - int threads = 256; - int64_t numel_2 = numel / 2; - if (numel_2 > 0) { - int blocks = (numel_2 + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - swiglu_bf16x2_kernel<<>>( - reinterpret_cast(x1.data_ptr()), - reinterpret_cast(x2.data_ptr()), - reinterpret_cast<__nv_bfloat162*>(z.data_ptr()), - numel_2 - ); - } - if (numel % 2 != 0) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - swiglu_odd_kernel<<<1, 1, 0, stream>>>( - reinterpret_cast(x1.data_ptr()), - reinterpret_cast(x2.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(z.data_ptr()), - numel - 1 - ); - } -} - -void launch_tp_allreduce( - torch::Tensor ptrs, torch::Tensor y_out, - int n_tp, int n_fsdp, int fsdp_rank, int64_t numel -) { - int threads = 256; - int blocks = (numel / 2 + threads - 1) / threads; - if (blocks == 0 && numel > 0) blocks = 1; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (numel > 0) { - tp_allreduce_bf16x2_kernel<<>>( - reinterpret_cast(ptrs.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(y_out.data_ptr()), - n_tp, n_fsdp, fsdp_rank, numel - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_gather_w1_w2", &launch_gather_w1_w2, "Gather W1 and W2 via P2P"); - m.def("launch_gather_w3", &launch_gather_w3, "Gather W3 via P2P"); - m.def("launch_swiglu", &launch_swiglu, "Fused SwiGLU bf16"); - m.def("launch_tp_allreduce", &launch_tp_allreduce, "TP AllReduce via P2P"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fsdp_tp_fused_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(key, shape, dtype, device): - if key in _symm_cache: - return _symm_cache[key] - buf = symm_mem.empty(shape, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - _symm_cache[key] = (buf, hdl, ptrs) - return buf, hdl, ptrs - -_local_cache = {} -def _get_local_buffer(key, shape, dtype, device): - if key in _local_cache: - return _local_cache[key] - buf = torch.empty(shape, dtype=dtype, device=device) - _local_cache[key] = buf - return buf - -_stream2 = None -def _get_stream2(): - global _stream2 - if _stream2 is None: - _stream2 = torch.cuda.Stream() - return _stream2 - - -@torch.no_grad() -def solution( - x_local: Tensor, - W1_shard: Tensor, - W2_shard: Tensor, - W3_shard: Tensor, - n_tp: int, - n_fsdp: int, -) -> Tensor: - rank = dist.get_rank() - fsdp_rank = rank // n_tp - tp_rank = rank % n_tp - - d_fsdp, d_ff_tp = W1_shard.shape - K = d_fsdp * d_ff_tp - device = x_local.device - dtype = x_local.dtype - ext = _get_ext() - - # 1. Acquire symmetric buffer for FSDP weights and copy our local shards over - shards_key = ("shards", 3 * K, dtype, device) - shards_symm, hdl_shards, ptrs_shards = _get_symm_state(shards_key, [3 * K], dtype, device) - - shards_symm[0:K].copy_(W1_shard.view(-1)) - shards_symm[K:2*K].copy_(W2_shard.view(-1)) - shards_symm[2*K:3*K].copy_(W3_shard.view(-1)) - - # Barrier ensures all peers have flushed their local weights to Symmetric Memory - hdl_shards.barrier(channel=0) - - # 2. Reconstruct locally missing chunks by grabbing them off peer devices - W1_full = _get_local_buffer(("w1", d_fsdp * n_fsdp, d_ff_tp), (d_fsdp * n_fsdp, d_ff_tp), dtype, device) - W2_full = _get_local_buffer(("w2", d_fsdp * n_fsdp, d_ff_tp), (d_fsdp * n_fsdp, d_ff_tp), dtype, device) - W3_full = _get_local_buffer(("w3", d_ff_tp, d_fsdp * n_fsdp), (d_ff_tp, d_fsdp * n_fsdp), dtype, device) - - # Pull W1 and W2 onto Default Stream to unlock first matmuls - ext.launch_gather_w1_w2(ptrs_shards, W1_full, W2_full, n_fsdp, n_tp, tp_rank, K) - - # Overlap: Schedule W3's Gather independently on background stream - stream2 = _get_stream2() - stream2.wait_stream(torch.cuda.current_stream()) # Stream 2 awaits the Barrier flush prior to pulling - with torch.cuda.stream(stream2): - ext.launch_gather_w3(ptrs_shards, W3_full, n_fsdp, n_tp, tp_rank, d_ff_tp, d_fsdp) - - # Overlap: Compute hidden states (x1, x2) and execute customized SwiGLU alongside W3's comm - x1 = torch.mm(x_local, W1_full) - x2 = torch.mm(x_local, W2_full) - - z = _get_local_buffer(("z", x1.shape[0], x1.shape[1]), x1.shape, dtype, device) - ext.launch_swiglu(x1, x2, z, x1.numel()) - - # Re-sync prior to resolving W3 - torch.cuda.current_stream().wait_stream(stream2) - y_partial = torch.mm(z, W3_full) - - # 3. Complete chunk reduction using an identical P2P paradigm across the TP ranks - y_numel = y_partial.numel() - y_key = ("y", y_partial.shape, dtype, device) - y_symm, hdl_y, ptrs_y = _get_symm_state(y_key, y_partial.shape, dtype, device) - y_out = torch.empty_like(y_partial) - - y_symm.copy_(y_partial.view(-1)) - hdl_y.barrier(channel=0) - - ext.launch_tp_allreduce(ptrs_y, y_out, n_tp, n_fsdp, fsdp_rank, y_numel) - - return y_out - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/4_reduce_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/4_reduce_cuda.py deleted file mode 100755 index fcd2579..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/4_reduce_cuda.py +++ /dev/null @@ -1,192 +0,0 @@ -import math -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Hopper NVSwitch Multimem Reduce (dst rank only) -// --------------------------------------------------------------------------- -__global__ void multimem_reduce_bf16_kernel( - uint64_t multicast_base, - __nv_bfloat16* __restrict__ out, - int64_t numel_128 -) { - const int num_programs = gridDim.x * blockDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int64_t idx = tid; idx < numel_128; idx += num_programs) { - uint64_t* ptr = reinterpret_cast(multicast_base) + idx * 2; - uint32_t x, y, z, w; - - // Fetch and reduce 128-bits (8 x bfloat16) across all ranks in hardware - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(x), "=r"(y), "=r"(z), "=r"(w) - : "l"(ptr) - : "memory"); - - // 128-bit vectorized store to the destination buffer - uint4* out_ptr = reinterpret_cast(out) + idx; - *out_ptr = make_uint4(x, y, z, w); - } -} - -// --------------------------------------------------------------------------- -// UVA Peer-Pointer Fallback Reduce (dst rank only) -// --------------------------------------------------------------------------- -template -__global__ void reduce_generic_kernel( - const long long* __restrict__ ptrs, - T* __restrict__ out, - int world_size, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const T* src = reinterpret_cast(ptrs[r]); - sum += static_cast(src[idx]); - } - out[idx] = static_cast(sum); - } -} - -// --------------------------------------------------------------------------- -// Launchers -// --------------------------------------------------------------------------- -void launch_multimem_reduce_bf16( - uint64_t multicast_ptr, - torch::Tensor out, - int64_t numel_128 -) { - if (numel_128 == 0) return; - int threads = 512; - int blocks = (numel_128 + threads - 1) / threads; - if (blocks > 1024) blocks = 1024; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - multimem_reduce_bf16_kernel<<>>( - multicast_ptr, - reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), - numel_128 - ); -} - -void launch_reduce( - torch::Tensor ptrs_tensor, - torch::Tensor out, - int64_t n -) { - if (n == 0) return; - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = reinterpret_cast(ptrs_tensor.data_ptr()); - - int threads = 512; - int blocks = (n + threads - 1) / threads; - if (blocks > 1024) blocks = 1024; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, out.scalar_type(), "reduce_kernel", ([&] { - reduce_generic_kernel<<>>( - d_ptrs, - out.data_ptr(), - world_size, - n - ); - })); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_reduce_bf16", &launch_multimem_reduce_bf16, "Multimem hardware reduce to dest"); - m.def("launch_reduce", &launch_reduce, "Custom P2P UVA reduce fallback"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("reduce_cuda_opt_ext", CUDA_SRC) - return _ext - -_resource_cache = {} -def _get_resources(shape, dtype, device): - key = (shape, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - n = math.prod(shape) - - # Pad allocations to multiples of 8 elements for pure 128-bit vectorization in BF16 - pad_n = (n + 7) & ~7 if dtype == torch.bfloat16 else n - - buf = symm_mem.empty(pad_n, device=device, dtype=dtype) - buf.zero_() # Zero padding elements to safely accumulate +0.0 during multimem tail fetches - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - out_pad = torch.empty(pad_n, device=device, dtype=dtype) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, out_pad, ptrs_tensor) - _resource_cache[key] = res - return res - - -@torch.no_grad() -def solution( - tensor: torch.Tensor, - dst: int = 0, -) -> torch.Tensor: - """ - Optimized device-side collective to replace dist.reduce(). - Favors NVSwitch multimem on Hopper or optimized peer UVA reads. - """ - if not dist.is_initialized(): - return tensor.clone() - - input_tensor = tensor.contiguous() - n = input_tensor.numel() - dtype = input_tensor.dtype - device = input_tensor.device - rank = dist.get_rank() - - # Pre-registered resources avoiding allocation on the hot path - buf, hdl, out_pad, ptrs_tensor = _get_resources(input_tensor.shape, dtype, device) - - # 1. Fill registered symmetric buffer (tail elements safely untouched and remain zeroed) - buf[:n].copy_(input_tensor.flatten()) - - # 2. Synchronize all ranks before destination reads - hdl.barrier(channel=0) - - # 3. Pull and reduce via Switch / NVLink (Executed solely on dst rank) - multicast_ptr = getattr(hdl, 'multicast_ptr', 0) - - if multicast_ptr != 0 and dtype == torch.bfloat16: - if rank == dst: - numel_128 = out_pad.numel() // 8 - _get_ext().launch_multimem_reduce_bf16(multicast_ptr, out_pad, numel_128) - else: - if rank == dst: - _get_ext().launch_reduce(ptrs_tensor, out_pad, n) - - # 4. Enforce buffer lifespan: ensure dst completes reads before next collective overwrites buf - hdl.barrier(channel=0) - - # 5. Result isolation - if rank == dst: - return out_pad[:n].reshape(input_tensor.shape).clone() - else: - return input_tensor.clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/50_moe_ep_balanced_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/50_moe_ep_balanced_cuda.py deleted file mode 100755 index ee2ae63..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/50_moe_ep_balanced_cuda.py +++ /dev/null @@ -1,458 +0,0 @@ -""" -Strategy: -1. **Direct P2P Dispatch/Combine**: Fuses the `permute`, `all_to_all`, and `sort` into a single Push-based direct-memory-access CUDA kernel. Fuses the reverse `all_to_all` and `unpermute` into a single Pull-based CUDA kernel. -2. **Symmetric Memory via UVA**: Tokens and gradients are exchanged by writing/reading directly to/from symmetric buffers (`symm_mem.rendezvous`) over NVLink. This entirely bypasses NCCL host launch overheads and intermediate tensor copies. -3. **Zero Atomics Backward**: By generating token-to-expert mapping offsets dynamically on the forward pass, the backward passes for both dispatch and combine are entirely conflict-free and require zero atomics. -4. **Hardware Barriers**: Compute and communication are seamlessly ordered using fast device-side barriers (`hdl.barrier(channel=0)`), avoiding CPU synchronization stalls. -""" - -from typing import Optional -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// Helper to convert to float for accumulation -template __device__ __forceinline__ float to_float(T val); -template <> __device__ __forceinline__ float to_float(float val) { return val; } -template <> __device__ __forceinline__ float to_float(double val) { return static_cast(val); } -template <> __device__ __forceinline__ float to_float(at::Half val) { return __half2float(val); } -template <> __device__ __forceinline__ float to_float(at::BFloat16 val) { return __bfloat162float(val); } - -// Helper to convert from float -template __device__ __forceinline__ T from_float(float val); -template <> __device__ __forceinline__ float from_float(float val) { return val; } -template <> __device__ __forceinline__ double from_float(float val) { return static_cast(val); } -template <> __device__ __forceinline__ at::Half from_float(float val) { return __float2half(val); } -template <> __device__ __forceinline__ at::BFloat16 from_float(float val) { return __float2bfloat16(val); } - -template -__global__ void dispatch_forward_kernel( - const scalar_t* __restrict__ hidden_states, - const int64_t* __restrict__ selected_experts, - int* __restrict__ expert_counters, - int* __restrict__ token_local_offsets, - const int* __restrict__ recv_offsets, - const uint64_t* __restrict__ remote_ptrs, - int N, int K, int H -) { - int nk = blockIdx.x; - int i = nk / K; - int e = selected_experts[nk]; - - __shared__ int shared_local_offset; - if (threadIdx.x == 0) { - shared_local_offset = atomicAdd(&expert_counters[e], 1); - token_local_offsets[nk] = shared_local_offset; - } - __syncthreads(); - - int local_offset = shared_local_offset; - int remote_offset = recv_offsets[e] + local_offset; - scalar_t* remote_buf = reinterpret_cast(remote_ptrs[e]); - - for (int h = threadIdx.x; h < H; h += blockDim.x) { - remote_buf[remote_offset * H + h] = hidden_states[i * H + h]; - } -} - -template -__global__ void dispatch_backward_kernel( - scalar_t* __restrict__ grad_hidden_states, - const int64_t* __restrict__ selected_experts, - const int* __restrict__ token_local_offsets, - const int* __restrict__ recv_offsets, - const uint64_t* __restrict__ remote_ptrs, - int N, int K, int H -) { - int i = blockIdx.x; - for (int h = threadIdx.x; h < H; h += blockDim.x) { - float sum = 0.0f; - for (int k = 0; k < K; ++k) { - int e = selected_experts[i * K + k]; - int local_offset = token_local_offsets[i * K + k]; - int remote_offset = recv_offsets[e] + local_offset; - const scalar_t* remote_buf = reinterpret_cast(remote_ptrs[e]); - sum += to_float(remote_buf[remote_offset * H + h]); - } - grad_hidden_states[i * H + h] = from_float(sum); - } -} - -template -__global__ void combine_forward_kernel( - scalar_t* __restrict__ combined_output, - const int64_t* __restrict__ selected_experts, - const scalar_t* __restrict__ routing_weights, - const int* __restrict__ token_local_offsets, - const int* __restrict__ recv_offsets, - const uint64_t* __restrict__ remote_ptrs, - int N, int K, int H -) { - int i = blockIdx.x; - for (int h = threadIdx.x; h < H; h += blockDim.x) { - float sum = 0.0f; - for (int k = 0; k < K; ++k) { - int e = selected_experts[i * K + k]; - float w = to_float(routing_weights[i * K + k]); - int local_offset = token_local_offsets[i * K + k]; - int remote_offset = recv_offsets[e] + local_offset; - const scalar_t* remote_buf = reinterpret_cast(remote_ptrs[e]); - sum += w * to_float(remote_buf[remote_offset * H + h]); - } - combined_output[i * H + h] = from_float(sum); - } -} - -template -__global__ void combine_backward_kernel( - const scalar_t* __restrict__ grad_combined_output, - const int64_t* __restrict__ selected_experts, - const scalar_t* __restrict__ routing_weights, - const int* __restrict__ token_local_offsets, - const int* __restrict__ recv_offsets, - const uint64_t* __restrict__ remote_grad_ptrs, - const uint64_t* __restrict__ remote_expert_out_ptrs, - scalar_t* __restrict__ grad_weights, - int N, int K, int H -) { - int nk = blockIdx.x; - int i = nk / K; - - int e = selected_experts[nk]; - float w = to_float(routing_weights[nk]); - int local_offset = token_local_offsets[nk]; - int remote_offset = recv_offsets[e] + local_offset; - - scalar_t* remote_grad_buf = reinterpret_cast(remote_grad_ptrs[e]); - const scalar_t* remote_expert_out_buf = reinterpret_cast(remote_expert_out_ptrs[e]); - - float dot_product = 0.0f; - for (int h = threadIdx.x; h < H; h += blockDim.x) { - float grad_out = to_float(grad_combined_output[i * H + h]); - float expert_out = to_float(remote_expert_out_buf[remote_offset * H + h]); - - remote_grad_buf[remote_offset * H + h] = from_float(grad_out * w); - dot_product += grad_out * expert_out; - } - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - dot_product += __shfl_down_sync(0xffffffff, dot_product, offset); - } - if (threadIdx.x == 0) { - grad_weights[nk] = from_float(dot_product); - } -} - -void launch_dispatch_forward( - torch::Tensor hidden_states, - torch::Tensor selected_experts, - torch::Tensor expert_counters, - torch::Tensor token_local_offsets, - torch::Tensor recv_offsets, - torch::Tensor remote_ptrs, - int N, int K, int H -) { - int threads = std::min(H, 1024); - int blocks = N * K; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, hidden_states.scalar_type(), "dispatch_forward", ([&] { - dispatch_forward_kernel<<>>( - hidden_states.data_ptr(), - selected_experts.data_ptr(), - expert_counters.data_ptr(), - token_local_offsets.data_ptr(), - recv_offsets.data_ptr(), - reinterpret_cast(remote_ptrs.data_ptr()), - N, K, H - ); - })); -} - -void launch_dispatch_backward( - torch::Tensor grad_hidden_states, - torch::Tensor selected_experts, - torch::Tensor token_local_offsets, - torch::Tensor recv_offsets, - torch::Tensor remote_ptrs, - int N, int K, int H -) { - int threads = std::min(H, 1024); - int blocks = N; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, grad_hidden_states.scalar_type(), "dispatch_backward", ([&] { - dispatch_backward_kernel<<>>( - grad_hidden_states.data_ptr(), - selected_experts.data_ptr(), - token_local_offsets.data_ptr(), - recv_offsets.data_ptr(), - reinterpret_cast(remote_ptrs.data_ptr()), - N, K, H - ); - })); -} - -void launch_combine_forward( - torch::Tensor combined_output, - torch::Tensor selected_experts, - torch::Tensor routing_weights, - torch::Tensor token_local_offsets, - torch::Tensor recv_offsets, - torch::Tensor remote_ptrs, - int N, int K, int H -) { - int threads = std::min(H, 1024); - int blocks = N; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, combined_output.scalar_type(), "combine_forward", ([&] { - combine_forward_kernel<<>>( - combined_output.data_ptr(), - selected_experts.data_ptr(), - routing_weights.data_ptr(), - token_local_offsets.data_ptr(), - recv_offsets.data_ptr(), - reinterpret_cast(remote_ptrs.data_ptr()), - N, K, H - ); - })); -} - -void launch_combine_backward( - torch::Tensor grad_combined_output, - torch::Tensor selected_experts, - torch::Tensor routing_weights, - torch::Tensor token_local_offsets, - torch::Tensor recv_offsets, - torch::Tensor remote_grad_ptrs, - torch::Tensor remote_expert_out_ptrs, - torch::Tensor grad_weights, - int N, int K, int H -) { - int threads = 32; - int blocks = N * K; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, grad_combined_output.scalar_type(), "combine_backward", ([&] { - combine_backward_kernel<<>>( - grad_combined_output.data_ptr(), - selected_experts.data_ptr(), - routing_weights.data_ptr(), - token_local_offsets.data_ptr(), - recv_offsets.data_ptr(), - reinterpret_cast(remote_grad_ptrs.data_ptr()), - reinterpret_cast(remote_expert_out_ptrs.data_ptr()), - grad_weights.data_ptr(), - N, K, H - ); - })); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_forward", &launch_dispatch_forward); - m.def("dispatch_backward", &launch_dispatch_backward); - m.def("combine_forward", &launch_combine_forward); - m.def("combine_backward", &launch_combine_backward); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_moe_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def get_symm_buffer(name: str, shape: tuple, dtype: torch.dtype, device: torch.device, group: dist.ProcessGroup): - key = (name, shape, dtype, device, group) - if key not in _symm_cache: - buf = symm_mem.empty(shape, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - _symm_cache[key] = (buf, hdl, ptrs) - return _symm_cache[key] - -class FusedMoEDispatch(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_states, selected_experts, recv_offsets, M, dispatch_buf, dispatch_ptrs, group): - ctx.group = group - N, H = hidden_states.shape - K = selected_experts.shape[1] - W = recv_offsets.shape[0] - - expert_counters = torch.zeros(W, dtype=torch.int32, device=hidden_states.device) - token_local_offsets = torch.empty_like(selected_experts, dtype=torch.int32) - - _, dispatch_hdl, _ = get_symm_buffer("dispatch", dispatch_buf.shape, hidden_states.dtype, hidden_states.device, group) - dispatch_hdl.barrier(channel=0) - - _get_ext().dispatch_forward( - hidden_states, selected_experts, expert_counters, token_local_offsets, - recv_offsets, dispatch_ptrs, N, K, H - ) - - dispatch_hdl.barrier(channel=0) - expert_inputs = dispatch_buf[:M].clone() - - ctx.save_for_backward(selected_experts, token_local_offsets, recv_offsets, dispatch_ptrs) - ctx.N, ctx.K, ctx.H, ctx.W = N, K, H, W - ctx.mark_non_differentiable(token_local_offsets) - return expert_inputs, token_local_offsets - - @staticmethod - def backward(ctx, grad_expert_inputs, grad_token_local_offsets): - selected_experts, token_local_offsets, recv_offsets, dispatch_ptrs = ctx.saved_tensors - - grad_dispatch_buf, grad_dispatch_hdl, grad_dispatch_ptrs = get_symm_buffer( - "grad_dispatch", (ctx.W * ctx.N * ctx.K, ctx.H), - grad_expert_inputs.dtype, grad_expert_inputs.device, ctx.group - ) - grad_dispatch_buf[:grad_expert_inputs.shape[0]].copy_(grad_expert_inputs) - - grad_dispatch_hdl.barrier(channel=0) - - grad_hidden_states = torch.empty((ctx.N, ctx.H), dtype=grad_expert_inputs.dtype, device=grad_expert_inputs.device) - _get_ext().dispatch_backward( - grad_hidden_states, selected_experts, token_local_offsets, - recv_offsets, grad_dispatch_ptrs, ctx.N, ctx.K, ctx.H - ) - - grad_dispatch_hdl.barrier(channel=0) - return grad_hidden_states, None, None, None, None, None, None - -class FusedMoECombine(torch.autograd.Function): - @staticmethod - def forward(ctx, expert_outputs, selected_experts, routing_weights, token_local_offsets, recv_offsets, combine_buf, combine_ptrs, group): - ctx.group = group - N, K = selected_experts.shape - M, H = expert_outputs.shape - - combine_buf[:M].copy_(expert_outputs) - - _, combine_hdl, _ = get_symm_buffer("combine", combine_buf.shape, expert_outputs.dtype, expert_outputs.device, group) - combine_hdl.barrier(channel=0) - - combined_output = torch.empty((N, H), dtype=expert_outputs.dtype, device=expert_outputs.device) - _get_ext().combine_forward( - combined_output, selected_experts, routing_weights, token_local_offsets, - recv_offsets, combine_ptrs, N, K, H - ) - - combine_hdl.barrier(channel=0) - - ctx.save_for_backward(expert_outputs, selected_experts, routing_weights, token_local_offsets, recv_offsets) - ctx.N, ctx.K, ctx.H, ctx.M, ctx.W = N, K, H, M, recv_offsets.shape[0] - return combined_output - - @staticmethod - def backward(ctx, grad_combined_output): - expert_outputs, selected_experts, routing_weights, token_local_offsets, recv_offsets = ctx.saved_tensors - MAX_TOKENS = ctx.W * ctx.N * ctx.K - - combine_bwd_buf, combine_bwd_hdl, combine_bwd_ptrs = get_symm_buffer( - "combine_bwd_expert_out", (MAX_TOKENS, ctx.H), - expert_outputs.dtype, expert_outputs.device, ctx.group - ) - combine_bwd_buf[:ctx.M].copy_(expert_outputs) - - grad_combine_buf, grad_combine_hdl, grad_combine_ptrs = get_symm_buffer( - "grad_combine", (MAX_TOKENS, ctx.H), - grad_combined_output.dtype, grad_combined_output.device, ctx.group - ) - grad_combine_buf[:ctx.M].zero_() - - combine_bwd_hdl.barrier(channel=0) - grad_combine_hdl.barrier(channel=0) - - grad_weights = torch.empty_like(routing_weights) - _get_ext().combine_backward( - grad_combined_output, selected_experts, routing_weights, - token_local_offsets, recv_offsets, grad_combine_ptrs, - combine_bwd_ptrs, grad_weights, ctx.N, ctx.K, ctx.H - ) - - grad_combine_hdl.barrier(channel=0) - grad_expert_outputs = grad_combine_buf[:ctx.M].clone() - return grad_expert_outputs, None, grad_weights, None, None, None, None, None - -def expert_forward( - x: torch.Tensor, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, -) -> torch.Tensor: - gate = torch.nn.functional.silu(gate_proj(x)) - up = up_proj(x) - return down_proj(gate * up) - -def solution( - hidden_states: torch.Tensor, - gate_weight: torch.Tensor, - gate_bias: Optional[torch.Tensor], - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - num_experts: int, - top_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - if rank == 0: - _get_ext() - dist.barrier(group=group) - _get_ext() - - hidden_dim = hidden_states.size(-1) - original_shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_dim) - N, H = hidden_states.shape - K = top_k - W = world_size - MAX_TOKENS = W * N * K - - router_logits = torch.nn.functional.linear(hidden_states, gate_weight, gate_bias) - routing_weights, selected_experts = torch.topk( - torch.softmax(router_logits, dim=-1), top_k, dim=-1 - ) - - send_counts = torch.bincount(selected_experts.view(-1), minlength=W).to(torch.int32) - counts_matrix = torch.empty((W, W), dtype=torch.int32, device=hidden_states.device) - dist.all_gather_into_tensor(counts_matrix.view(-1), send_counts, group=group) - - recv_offsets_matrix = counts_matrix.cumsum(dim=0) - counts_matrix - recv_offsets = recv_offsets_matrix[rank].contiguous() - M = counts_matrix[:, rank].sum().item() - - dispatch_buf, _, dispatch_ptrs = get_symm_buffer( - "dispatch", (MAX_TOKENS, H), hidden_states.dtype, hidden_states.device, group - ) - combine_buf, _, combine_ptrs = get_symm_buffer( - "combine", (MAX_TOKENS, H), hidden_states.dtype, hidden_states.device, group - ) - - expert_inputs, token_local_offsets = FusedMoEDispatch.apply( - hidden_states, selected_experts, recv_offsets, M, dispatch_buf, dispatch_ptrs, group - ) - - expert_outputs = expert_forward(expert_inputs, gate_proj, up_proj, down_proj) - - out = FusedMoECombine.apply( - expert_outputs, selected_experts, routing_weights, token_local_offsets, - recv_offsets, combine_buf, combine_ptrs, group - ) - - return out.reshape(original_shape) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/51_moe_ep_wide_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/51_moe_ep_wide_cuda.py deleted file mode 100755 index df08d90..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/51_moe_ep_wide_cuda.py +++ /dev/null @@ -1,491 +0,0 @@ -""" -Strategy: -1. Replace NCCL AllToAll and AllGather with custom UVA memory movement kernels over symmetric memory buffers. -2. Directly compute the global layout offsets for each expert. This allows ranks to scatter tokens *directly* to their sorted, final destination in peers' memory, completely removing the need for intermediary sorting (e.g., `_sort_chunks_by_idxs`). -3. Encapsulate the scatter/gather data movements in custom `torch.autograd.Function`s (`MoEScatter`, `MoEGather`), ensuring smooth and direct gradient propagation using identical reverse UVA paths. -""" - -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void push_chunks_kernel( - const int8_t* __restrict__ src, - const int64_t* __restrict__ src_offsets, - const int64_t* __restrict__ dst_ptrs, - const int* __restrict__ chunk_sizes, - int hidden_dim_bytes, - int num_chunks, - int blocks_per_chunk -) { - int chunk_idx = blockIdx.x / blocks_per_chunk; - int block_offset = blockIdx.x % blocks_per_chunk; - if (chunk_idx >= num_chunks) return; - - int size = chunk_sizes[chunk_idx]; - if (size == 0) return; - - const int8_t* src_chunk = src + src_offsets[chunk_idx] * hidden_dim_bytes; - int8_t* dst_chunk = reinterpret_cast(dst_ptrs[chunk_idx]); - - int total_bytes = size * hidden_dim_bytes; - int total_vec = total_bytes / 16; - - const float4* src_vec = reinterpret_cast(src_chunk); - float4* dst_vec = reinterpret_cast(dst_chunk); - - for (int i = block_offset * blockDim.x + threadIdx.x; i < total_vec; i += blocks_per_chunk * blockDim.x) { - dst_vec[i] = src_vec[i]; - } - - if (block_offset == 0 && threadIdx.x == 0) { - for(int i = total_vec * 16; i < total_bytes; ++i) { - dst_chunk[i] = src_chunk[i]; - } - } -} - -__global__ void gather_counts_kernel( - const int64_t* __restrict__ peer_ptrs, - int* __restrict__ out, - int world_size, - int num_experts -) { - int r = blockIdx.x; - if (r >= world_size) return; - const int* peer_count = reinterpret_cast(peer_ptrs[r]); - for (int e = threadIdx.x; e < num_experts; e += blockDim.x) { - out[r * num_experts + e] = peer_count[e]; - } -} - -void launch_push_chunks( - torch::Tensor src, - torch::Tensor src_offsets, - torch::Tensor dst_ptrs, - torch::Tensor chunk_sizes, - int hidden_dim_bytes -) { - int num_chunks = chunk_sizes.size(0); - int blocks_per_chunk = 16; - int total_blocks = num_chunks * blocks_per_chunk; - int threads = 256; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - push_chunks_kernel<<>>( - reinterpret_cast(src.data_ptr()), - src_offsets.data_ptr(), - dst_ptrs.data_ptr(), - chunk_sizes.data_ptr(), - hidden_dim_bytes, - num_chunks, - blocks_per_chunk - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -void launch_gather_counts( - torch::Tensor peer_ptrs, - torch::Tensor out, - int world_size, - int num_experts -) { - int threads = std::min(num_experts, 1024); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - gather_counts_kernel<<>>( - peer_ptrs.data_ptr(), - out.data_ptr(), - world_size, - num_experts - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push_chunks", &launch_push_chunks, "Push chunks via UVA"); - m.def("launch_gather_counts", &launch_gather_counts, "Gather counts via UVA"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("moe_uva_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_buffers(device, hidden_dim, world_size, num_experts, dtype): - key = (device, dtype) - if key in _symm_cache: - return _symm_cache[key] - - MAX_TOKENS = 16384 # Abundant safety buffer for MoE tokens - - buf_counts = symm_mem.empty((num_experts,), dtype=torch.int32, device=device) - hdl_counts = symm_mem.rendezvous(buf_counts, dist.group.WORLD) - - buf_global_permuted = symm_mem.empty((MAX_TOKENS, hidden_dim), dtype=dtype, device=device) - hdl_global_permuted = symm_mem.rendezvous(buf_global_permuted, dist.group.WORLD) - - buf_grad_local = symm_mem.empty((MAX_TOKENS, hidden_dim), dtype=dtype, device=device) - hdl_grad_local = symm_mem.rendezvous(buf_grad_local, dist.group.WORLD) - - buf_unpermute = symm_mem.empty((MAX_TOKENS, hidden_dim), dtype=dtype, device=device) - hdl_unpermute = symm_mem.rendezvous(buf_unpermute, dist.group.WORLD) - - buf_grad_expert = symm_mem.empty((MAX_TOKENS, hidden_dim), dtype=dtype, device=device) - hdl_grad_expert = symm_mem.rendezvous(buf_grad_expert, dist.group.WORLD) - - res = { - "buf_counts": buf_counts, - "hdl_counts": hdl_counts, - "buf_global_permuted": buf_global_permuted, - "hdl_global_permuted": hdl_global_permuted, - "buf_grad_local": buf_grad_local, - "hdl_grad_local": hdl_grad_local, - "buf_unpermute": buf_unpermute, - "hdl_unpermute": hdl_unpermute, - "buf_grad_expert": buf_grad_expert, - "hdl_grad_expert": hdl_grad_expert, - } - _symm_cache[key] = res - return res - - -# ----- UVA Offset Calculation Utils ----- - -def compute_forward_scatter_args(global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu): - num_local_experts = num_experts // world_size - src_offsets = torch.zeros(num_experts, dtype=torch.int64) - dst_ptrs = torch.zeros(num_experts, dtype=torch.int64) - - gc = global_counts_cpu.tolist() - current_src_offset = 0 - - for e in range(num_experts): - src_offsets[e] = current_src_offset - current_src_offset += gc[rank][e] - - dest_rank = e // num_local_experts - expert_base = 0 - for e_prime in range(dest_rank * num_local_experts, e): - expert_base += sum(gc[r][e_prime] for r in range(world_size)) - - rank_offset = sum(gc[r][e] for r in range(rank)) - write_ptr = expert_base + rank_offset - dst_ptrs[e] = symm_ptrs_cpu[dest_rank] + write_ptr * hidden_dim * element_size - - return src_offsets, dst_ptrs - -def compute_backward_scatter_args(global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu): - num_local_experts = num_experts // world_size - num_chunks = num_local_experts * world_size - src_offsets = torch.zeros(num_chunks, dtype=torch.int64) - dst_ptrs = torch.zeros(num_chunks, dtype=torch.int64) - chunk_sizes = torch.zeros(num_chunks, dtype=torch.int32) - - gc = global_counts_cpu.tolist() - current_src_offset = 0 - - for local_e in range(num_local_experts): - e = rank * num_local_experts + local_e - for dest_rank in range(world_size): - c = local_e * world_size + dest_rank - size = gc[dest_rank][e] - chunk_sizes[c] = size - src_offsets[c] = current_src_offset - current_src_offset += size - offset = sum(gc[dest_rank][ep] for ep in range(e)) - dst_ptrs[c] = symm_ptrs_cpu[dest_rank] + offset * hidden_dim * element_size - - return src_offsets, dst_ptrs, chunk_sizes - -def compute_forward_gather_args(global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu): - num_local_experts = num_experts // world_size - num_chunks = num_local_experts * world_size - src_offsets = torch.zeros(num_chunks, dtype=torch.int64) - dst_ptrs = torch.zeros(num_chunks, dtype=torch.int64) - chunk_sizes = torch.zeros(num_chunks, dtype=torch.int32) - - gc = global_counts_cpu.tolist() - current_src_offset = 0 - - for local_e in range(num_local_experts): - e = rank * num_local_experts + local_e - for dest_rank in range(world_size): - c = local_e * world_size + dest_rank - size = gc[dest_rank][e] - chunk_sizes[c] = size - src_offsets[c] = current_src_offset - current_src_offset += size - write_ptr = sum(gc[dest_rank][ep] for ep in range(e)) - dst_ptrs[c] = symm_ptrs_cpu[dest_rank] + write_ptr * hidden_dim * element_size - - return src_offsets, dst_ptrs, chunk_sizes - -def compute_backward_gather_args(global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu): - num_local_experts = num_experts // world_size - src_offsets = torch.zeros(num_experts, dtype=torch.int64) - dst_ptrs = torch.zeros(num_experts, dtype=torch.int64) - chunk_sizes = torch.zeros(num_experts, dtype=torch.int32) - - gc = global_counts_cpu.tolist() - current_src_offset = 0 - - for e in range(num_experts): - size = gc[rank][e] - chunk_sizes[e] = size - src_offsets[e] = current_src_offset - current_src_offset += size - - dest_rank = e // num_local_experts - local_e = e % num_local_experts - - offset = 0 - for le in range(local_e): - ep = dest_rank * num_local_experts + le - offset += sum(gc[r][ep] for r in range(world_size)) - - offset += sum(gc[r][e] for r in range(rank)) - dst_ptrs[e] = symm_ptrs_cpu[dest_rank] + offset * hidden_dim * element_size - - return src_offsets, dst_ptrs, chunk_sizes - - -# ----- Autograd Collectives ----- - -class MoEScatter(torch.autograd.Function): - @staticmethod - def forward(ctx, local_permuted, global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu, group, symm_grad_local_ptrs_cpu): - src_offsets, dst_ptrs = compute_forward_scatter_args( - global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu) - chunk_sizes = global_counts_cpu[rank] - - dist.barrier(group=group) - _get_ext().launch_push_chunks( - local_permuted, - src_offsets.to(local_permuted.device), - dst_ptrs.to(local_permuted.device), - chunk_sizes.to(local_permuted.device), - hidden_dim * element_size - ) - dist.barrier(group=group) - - num_local_experts = num_experts // world_size - total_received = global_counts_cpu[:, rank * num_local_experts : (rank + 1) * num_local_experts].sum().item() - - ctx.save_for_backward(global_counts_cpu) - ctx.rank = rank - ctx.world_size = world_size - ctx.num_experts = num_experts - ctx.hidden_dim = hidden_dim - ctx.element_size = element_size - ctx.symm_grad_local_ptrs_cpu = symm_grad_local_ptrs_cpu - ctx.group = group - ctx.dtype = local_permuted.dtype - - symm_global_permuted = _get_symm_buffers(local_permuted.device, hidden_dim, world_size, num_experts, local_permuted.dtype)["buf_global_permuted"] - return symm_global_permuted[:total_received].clone() - - @staticmethod - def backward(ctx, grad_global_permuted): - global_counts_cpu, = ctx.saved_tensors - src_offsets, dst_ptrs, chunk_sizes = compute_backward_scatter_args( - global_counts_cpu, ctx.rank, ctx.world_size, ctx.num_experts, ctx.hidden_dim, ctx.element_size, ctx.symm_grad_local_ptrs_cpu) - - dist.barrier(group=ctx.group) - _get_ext().launch_push_chunks( - grad_global_permuted.contiguous(), - src_offsets.to(grad_global_permuted.device), - dst_ptrs.to(grad_global_permuted.device), - chunk_sizes.to(grad_global_permuted.device), - ctx.hidden_dim * ctx.element_size - ) - dist.barrier(group=ctx.group) - - symm_grad_local = _get_symm_buffers(grad_global_permuted.device, ctx.hidden_dim, ctx.world_size, ctx.num_experts, ctx.dtype)["buf_grad_local"] - total_local = global_counts_cpu[ctx.rank].sum().item() - return symm_grad_local[:total_local].clone(), None, None, None, None, None, None, None, None, None - - -class MoEGather(torch.autograd.Function): - @staticmethod - def forward(ctx, expert_outputs, global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu, group, symm_grad_expert_ptrs_cpu): - src_offsets, dst_ptrs, chunk_sizes = compute_forward_gather_args( - global_counts_cpu, rank, world_size, num_experts, hidden_dim, element_size, symm_ptrs_cpu) - - dist.barrier(group=group) - _get_ext().launch_push_chunks( - expert_outputs.contiguous(), - src_offsets.to(expert_outputs.device), - dst_ptrs.to(expert_outputs.device), - chunk_sizes.to(expert_outputs.device), - hidden_dim * element_size - ) - dist.barrier(group=group) - - total_received = global_counts_cpu[rank].sum().item() - - ctx.save_for_backward(global_counts_cpu) - ctx.rank = rank - ctx.world_size = world_size - ctx.num_experts = num_experts - ctx.hidden_dim = hidden_dim - ctx.element_size = element_size - ctx.symm_grad_expert_ptrs_cpu = symm_grad_expert_ptrs_cpu - ctx.group = group - ctx.dtype = expert_outputs.dtype - - symm_unpermute = _get_symm_buffers(expert_outputs.device, hidden_dim, world_size, num_experts, expert_outputs.dtype)["buf_unpermute"] - return symm_unpermute[:total_received].clone() - - @staticmethod - def backward(ctx, grad_unpermute): - global_counts_cpu, = ctx.saved_tensors - src_offsets, dst_ptrs, chunk_sizes = compute_backward_gather_args( - global_counts_cpu, ctx.rank, ctx.world_size, ctx.num_experts, ctx.hidden_dim, ctx.element_size, ctx.symm_grad_expert_ptrs_cpu) - - dist.barrier(group=ctx.group) - _get_ext().launch_push_chunks( - grad_unpermute.contiguous(), - src_offsets.to(grad_unpermute.device), - dst_ptrs.to(grad_unpermute.device), - chunk_sizes.to(grad_unpermute.device), - ctx.hidden_dim * ctx.element_size - ) - dist.barrier(group=ctx.group) - - symm_grad_expert = _get_symm_buffers(grad_unpermute.device, ctx.hidden_dim, ctx.world_size, ctx.num_experts, ctx.dtype)["buf_grad_expert"] - num_local_experts = ctx.num_experts // ctx.world_size - total_local = global_counts_cpu[:, ctx.rank * num_local_experts : (ctx.rank + 1) * num_local_experts].sum().item() - - return symm_grad_expert[:total_local].clone(), None, None, None, None, None, None, None, None, None - - -def expert_forward( - x: torch.Tensor, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, -) -> torch.Tensor: - gate = torch.nn.functional.silu(gate_proj(x)) - up = up_proj(x) - return down_proj(gate * up) - - -def solution( - hidden_states: torch.Tensor, - gate_weight: torch.Tensor, - gate_bias: Optional[torch.Tensor], - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - num_experts: int, - top_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - device = hidden_states.device - hidden_dim = hidden_states.size(-1) - element_size = hidden_states.element_size() - - if rank == 0: - _get_ext() - dist.barrier(group=group) - - # 1. Routing - router_logits = torch.nn.functional.linear( - hidden_states.reshape(-1, hidden_dim), gate_weight, gate_bias - ) - routing_weights, selected_experts = torch.topk( - torch.softmax(router_logits, dim=-1), top_k, dim=-1 - ) - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=num_experts - ).permute(2, 1, 0) - - routing_map = expert_mask.sum(dim=1).bool() - local_counts = routing_map.sum(dim=1).to(torch.int32) - - symm_res = _get_symm_buffers(device, hidden_dim, world_size, num_experts, hidden_states.dtype) - - # 2. Collect precise sizes with UVA - symm_res["buf_counts"].copy_(local_counts) - dist.barrier(group=group) - - global_counts = torch.empty((world_size, num_experts), dtype=torch.int32, device=device) - peer_ptrs = torch.tensor(symm_res["hdl_counts"].buffer_ptrs, dtype=torch.int64, device=device) - _get_ext().launch_gather_counts(peer_ptrs, global_counts, world_size, num_experts) - global_counts_cpu = global_counts.cpu() - - # 3. Fast PyTorch-local permute - num_tokens = hidden_states.size(0) - token_indices = torch.arange(num_tokens, device=device).unsqueeze(0).expand(num_experts, -1) - sorted_indices = token_indices.masked_select(routing_map) - local_permuted_hidden_states = hidden_states.index_select(0, sorted_indices) - - # 4. Forward Dispatch (UVA Scatter) - global_permuted_hidden_states = MoEScatter.apply( - local_permuted_hidden_states, - global_counts_cpu, - rank, - world_size, - num_experts, - hidden_dim, - element_size, - symm_res["hdl_global_permuted"].buffer_ptrs, - group, - symm_res["hdl_grad_local"].buffer_ptrs - ) - - # 5. Local Expert - expert_outputs = expert_forward( - global_permuted_hidden_states, gate_proj, up_proj, down_proj - ) - - # 6. Gather (UVA Reverse Scatter) - unpermute_outputs = MoEGather.apply( - expert_outputs, - global_counts_cpu, - rank, - world_size, - num_experts, - hidden_dim, - element_size, - symm_res["hdl_unpermute"].buffer_ptrs, - group, - symm_res["hdl_grad_expert"].buffer_ptrs - ) - - # 7. Unpermute via natively propagated Autograd hooks - weights_idx = torch.zeros( - (num_tokens, num_experts), - dtype=routing_weights.dtype, - device=device, - ) - weights_idx.scatter_add_(1, selected_experts, routing_weights) - tokens_weight = weights_idx.T.contiguous().masked_select(routing_map) - tokens = unpermute_outputs * tokens_weight.unsqueeze(-1) - - unpermuted_tokens = torch.zeros( - hidden_states.shape, device=device, dtype=hidden_states.dtype - ) - expanded_mapping = sorted_indices.unsqueeze(1).expand(-1, hidden_dim) - unpermuted_tokens.scatter_add_(0, expanded_mapping, tokens) - - return unpermuted_tokens.reshape(hidden_states.shape) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/52_moe_ep_narrow_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/52_moe_ep_narrow_cuda.py deleted file mode 100755 index 73743b3..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/52_moe_ep_narrow_cuda.py +++ /dev/null @@ -1,480 +0,0 @@ -""" -Strategy: -1. Replace NCCL `all_to_all` and PyTorch chunk sorting with a single fused CUDA operator. -2. We compute token routing distributions (`G`) using a fast symmetric memory `all_gather` instead of `dist.all_gather_into_tensor`. -3. We implement a custom autograd function (`FusedP2PAllToAll`) that calculates chunk offsets exactly mapping the local tokens to their final grouped-by-expert destination positions in remote symmetric memory. -4. A single vectorized UVA push kernel simultaneously sends data over NVLink AND sorts the chunks, replacing both `_all_to_all` and `_sort_chunks_by_idxs` in one step. -5. In the backward pass, the exact inverse push operation efficiently returns gradients to the source buffers, minimizing host-device syncs and maximizing bandwidth utilization. -""" - -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void push_chunks_kernel_vec( - const float4* __restrict__ local_data, - const long long* __restrict__ remote_ptrs, - const int* __restrict__ src_offsets, - const int* __restrict__ dst_offsets, - const int* __restrict__ dst_ranks, - int num_chunks, - int vec_hidden_dim, - int total_vecs -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_vecs) return; - - int token_idx = idx / vec_hidden_dim; - int dim_idx = idx % vec_hidden_dim; - - int chunk = 0; - // Linear search is fast since num_chunks is very small (e.g., <= 64) - for (int i = 0; i < num_chunks; ++i) { - if (token_idx >= src_offsets[i] && token_idx < src_offsets[i+1]) { - chunk = i; - break; - } - } - - int token_offset_in_chunk = token_idx - src_offsets[chunk]; - int dst_rank = dst_ranks[chunk]; - int dst_off = dst_offsets[chunk]; - - float4* dst = (float4*)remote_ptrs[dst_rank]; - dst[(dst_off + token_offset_in_chunk) * vec_hidden_dim + dim_idx] = local_data[idx]; -} - -void push_chunks_vec( - torch::Tensor local_data, - torch::Tensor remote_ptrs, - int64_t src_offsets_ptr, - int64_t dst_offsets_ptr, - int64_t dst_ranks_ptr, - int num_chunks, - int vec_hidden_dim, - int total_vecs -) { - int threads = 256; - int blocks = (total_vecs + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const long long* ptrs = (const long long*)remote_ptrs.data_ptr(); - const int* src_off = reinterpret_cast(src_offsets_ptr); - const int* dst_off = reinterpret_cast(dst_offsets_ptr); - const int* dst_r = reinterpret_cast(dst_ranks_ptr); - - push_chunks_kernel_vec<<>>( - (const float4*)local_data.data_ptr(), - ptrs, - src_off, dst_off, dst_r, - num_chunks, vec_hidden_dim, total_vecs - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -__global__ void symm_all_gather_kernel( - const int* __restrict__ local_data, - const long long* __restrict__ remote_ptrs, - int rank, - int ep_size, - int num_experts -) { - int tid = threadIdx.x; - if (tid < num_experts) { - int val = local_data[tid]; - for (int dst = 0; dst < ep_size; ++dst) { - int* dst_ptr = (int*)remote_ptrs[dst]; - dst_ptr[rank * num_experts + tid] = val; - } - } -} - -void symm_all_gather( - torch::Tensor local_data, - torch::Tensor remote_ptrs, - int rank, - int ep_size, - int num_experts -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const long long* ptrs = (const long long*)remote_ptrs.data_ptr(); - - symm_all_gather_kernel<<<1, 32, 0, stream>>>( - local_data.data_ptr(), - ptrs, - rank, ep_size, num_experts - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("push_chunks_vec", &push_chunks_vec, "UVA chunk copy and sort"); - m.def("symm_all_gather", &symm_all_gather, "UVA all gather for G"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_moe_symm_p2p", CUDA_SRC) - return _ext - - -_EP_SUBGROUP_CACHE: dict[tuple[int, int], Union[None, list]] = {} -def _resolve_ep_group_for_narrow_moe(num_experts: int) -> dist.ProcessGroup: - if not dist.is_initialized(): - raise RuntimeError("torch.distributed must be initialized") - ws = dist.get_world_size() - rank = dist.get_rank() - key = (ws, num_experts) - if key not in _EP_SUBGROUP_CACHE: - if num_experts >= ws: - _EP_SUBGROUP_CACHE[key] = None - elif ws % num_experts != 0: - raise ValueError(f"narrow EP requires world_size ({ws}) % num_experts ({num_experts}) == 0") - else: - groups: list = [] - for r in range(ws // num_experts): - ranks = list(range(r * num_experts, (r + 1) * num_experts)) - groups.append(dist.new_group(ranks)) - _EP_SUBGROUP_CACHE[key] = groups - entry = _EP_SUBGROUP_CACHE[key] - if entry is None: - return dist.group.WORLD - return entry[rank // num_experts] - - -_SYMM_BUFS = {} -def get_symm_buf(ep_group, buffer_id, max_elements, dtype, device): - global _SYMM_BUFS - key = (ep_group, buffer_id, dtype, device) - if key not in _SYMM_BUFS: - buf = symm_mem.empty((max_elements,), dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group=ep_group) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - _SYMM_BUFS[key] = (buf, hdl, ptrs) - return _SYMM_BUFS[key] - - -def _preprocess_symm(selected_experts: torch.Tensor, num_experts: int, ep_group: dist.ProcessGroup): - ep_size = dist.get_world_size(ep_group) - rank = dist.get_rank(ep_group) - device = selected_experts.device - - routing_map = torch.zeros((num_experts, selected_experts.size(0)), dtype=torch.bool, device=device) - routing_map.scatter_(0, selected_experts.T, 1) - num_local_tokens_per_expert = routing_map.sum(dim=1, dtype=torch.int32) - - buf, hdl, remote_ptrs = get_symm_buf(ep_group, 'G', ep_size * num_experts, torch.int32, device) - - hdl.barrier() - _get_ext().symm_all_gather(num_local_tokens_per_expert, remote_ptrs, rank, ep_size, num_experts) - hdl.barrier() - - G = buf[:ep_size * num_experts].view(ep_size, num_experts) - G_cpu = G.cpu().tolist() - - return routing_map, G_cpu - - -class FusedP2PAllToAll(torch.autograd.Function): - @staticmethod - def forward(ctx, permuted_input, op_type, G_cpu, ep_group, num_experts): - ctx.op_type = op_type - ctx.G_cpu = G_cpu - ctx.ep_group = ep_group - ctx.num_experts = num_experts - - ep_size = dist.get_world_size(ep_group) - rank = dist.get_rank(ep_group) - num_local_experts = num_experts // ep_size - E_start = rank * num_local_experts - E_end = E_start + num_local_experts - - device = permuted_input.device - hidden_dim = permuted_input.size(-1) - - if op_type == 1: - total_send = permuted_input.size(0) - total_recv = sum(G_cpu[r][e] for e in range(E_start, E_end) for r in range(ep_size)) - num_chunks = num_experts - - src_offsets = [0] * (num_chunks + 1) - dst_offsets = [0] * num_chunks - dst_ranks = [0] * num_chunks - - for e in range(num_experts): - dst = e // num_local_experts - dst_ranks[e] = dst - src_offsets[e] = sum(G_cpu[rank][k] for k in range(e)) - base_e = sum(G_cpu[r][k] for k in range(dst * num_local_experts, e) for r in range(ep_size)) - rank_off = sum(G_cpu[r][e] for r in range(rank)) - dst_offsets[e] = base_e + rank_off - src_offsets[num_experts] = total_send - buf_id = 'pre_fwd' - - else: - total_send = permuted_input.size(0) - total_recv = sum(G_cpu[rank][e] for e in range(num_experts)) - num_chunks = ep_size * num_local_experts - - src_offsets = [0] * (num_chunks + 1) - dst_offsets = [0] * num_chunks - dst_ranks = [0] * num_chunks - - chunk_idx = 0 - cur_off = 0 - for e in range(E_start, E_end): - for s in range(ep_size): - dst_ranks[chunk_idx] = s - src_offsets[chunk_idx] = cur_off - dst_offsets[chunk_idx] = sum(G_cpu[s][k] for k in range(e)) - cur_off += G_cpu[s][e] - chunk_idx += 1 - src_offsets[num_chunks] = cur_off - buf_id = 'post_fwd' - - ctx.total_recv_fwd = total_recv - ctx.total_send_fwd = total_send - - max_tokens = 65536 - buf, hdl, remote_ptrs = get_symm_buf(ep_group, buf_id, max_tokens * hidden_dim, permuted_input.dtype, device) - - hdl.barrier() - - offsets_tensor = torch.tensor(src_offsets + dst_offsets + dst_ranks, dtype=torch.int32, device=device) - d_src_offsets = offsets_tensor.data_ptr() - d_dst_offsets = d_src_offsets + len(src_offsets) * 4 - d_dst_ranks = d_dst_offsets + len(dst_offsets) * 4 - - vec_hidden_dim = hidden_dim - if hidden_dim % 8 == 0 and permuted_input.dtype == torch.bfloat16: - vec_hidden_dim = hidden_dim // 8 - elif hidden_dim % 4 == 0 and permuted_input.dtype == torch.float32: - vec_hidden_dim = hidden_dim // 4 - - total_vecs = total_send * vec_hidden_dim - if total_vecs > 0: - _get_ext().push_chunks_vec( - permuted_input.contiguous(), remote_ptrs, - d_src_offsets, d_dst_offsets, d_dst_ranks, - num_chunks, vec_hidden_dim, total_vecs - ) - - hdl.barrier() - - out = torch.empty((total_recv, hidden_dim), dtype=permuted_input.dtype, device=device) - if total_recv > 0: - out.copy_(buf[:total_recv * hidden_dim].view(total_recv, hidden_dim)) - - return out - - @staticmethod - def backward(ctx, grad_output): - op_type = ctx.op_type - G_cpu = ctx.G_cpu - ep_group = ctx.ep_group - num_experts = ctx.num_experts - - ep_size = dist.get_world_size(ep_group) - rank = dist.get_rank(ep_group) - num_local_experts = num_experts // ep_size - E_start = rank * num_local_experts - E_end = E_start + num_local_experts - - device = grad_output.device - hidden_dim = grad_output.size(-1) - - total_send = grad_output.size(0) - total_recv = ctx.total_send_fwd - - if op_type == 1: - num_chunks = ep_size * num_local_experts - src_offsets = [0] * (num_chunks + 1) - dst_offsets = [0] * num_chunks - dst_ranks = [0] * num_chunks - - chunk_idx = 0 - cur_off = 0 - for e in range(E_start, E_end): - for s in range(ep_size): - dst_ranks[chunk_idx] = s - src_offsets[chunk_idx] = cur_off - dst_offsets[chunk_idx] = sum(G_cpu[s][k] for k in range(e)) - cur_off += G_cpu[s][e] - chunk_idx += 1 - src_offsets[num_chunks] = cur_off - buf_id = 'pre_bwd' - else: - num_chunks = num_experts - src_offsets = [0] * (num_chunks + 1) - dst_offsets = [0] * num_chunks - dst_ranks = [0] * num_chunks - - for e in range(num_experts): - dst = e // num_local_experts - dst_ranks[e] = dst - src_offsets[e] = sum(G_cpu[rank][k] for k in range(e)) - base_e = sum(G_cpu[r][k] for k in range(dst * num_local_experts, e) for r in range(ep_size)) - rank_off = sum(G_cpu[r][e] for r in range(rank)) - dst_offsets[e] = base_e + rank_off - src_offsets[num_experts] = total_send - buf_id = 'post_bwd' - - max_tokens = 65536 - buf, hdl, remote_ptrs = get_symm_buf(ep_group, buf_id, max_tokens * hidden_dim, grad_output.dtype, device) - - hdl.barrier() - - offsets_tensor = torch.tensor(src_offsets + dst_offsets + dst_ranks, dtype=torch.int32, device=device) - d_src_offsets = offsets_tensor.data_ptr() - d_dst_offsets = d_src_offsets + len(src_offsets) * 4 - d_dst_ranks = d_dst_offsets + len(dst_offsets) * 4 - - vec_hidden_dim = hidden_dim - if hidden_dim % 8 == 0 and grad_output.dtype == torch.bfloat16: - vec_hidden_dim = hidden_dim // 8 - elif hidden_dim % 4 == 0 and grad_output.dtype == torch.float32: - vec_hidden_dim = hidden_dim // 4 - - total_vecs = total_send * vec_hidden_dim - if total_vecs > 0: - _get_ext().push_chunks_vec( - grad_output.contiguous(), remote_ptrs, - d_src_offsets, d_dst_offsets, d_dst_ranks, - num_chunks, vec_hidden_dim, total_vecs - ) - - hdl.barrier() - - grad_input = torch.empty((total_recv, hidden_dim), dtype=grad_output.dtype, device=device) - if total_recv > 0: - grad_input.copy_(buf[:total_recv * hidden_dim].view(total_recv, hidden_dim)) - - return grad_input, None, None, None, None - - -def _permute(tokens: torch.Tensor, routing_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - num_tokens, _ = tokens.shape - num_experts = routing_map.shape[0] - token_indices = torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1) - sorted_indices = token_indices.masked_select(routing_map) - permuted_input = tokens.index_select(0, sorted_indices) - return permuted_input, sorted_indices - - -def _generate_weights_idx(routing_weights: torch.Tensor, selected_experts: torch.Tensor, num_experts: int) -> torch.Tensor: - num_tokens, topk = routing_weights.shape - weights_idx = torch.zeros((num_tokens, num_experts), dtype=routing_weights.dtype, device=routing_weights.device) - weights_idx.scatter_add_(1, selected_experts, routing_weights) - return weights_idx - - -def _unpermute( - tokens: torch.Tensor, routing_weights: torch.Tensor, hidden_states_shape: torch.Size, - permutation_mapping: torch.Tensor, routing_map: torch.Tensor -) -> torch.Tensor: - tokens_weight = routing_weights.T.contiguous().masked_select(routing_map) - tokens = tokens * tokens_weight.unsqueeze(-1) - hidden_dim = hidden_states_shape[-1] - unpermuted_tokens = torch.zeros(hidden_states_shape, device=tokens.device, dtype=tokens.dtype) - expanded_mapping = permutation_mapping.unsqueeze(1).expand(-1, hidden_dim) - unpermuted_tokens.scatter_add_(0, expanded_mapping, tokens) - return unpermuted_tokens - - -def token_pre_all2all( - hidden_states: torch.Tensor, routing_map: torch.Tensor, G_cpu: List[List[int]], - group: dist.ProcessGroup, num_experts: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Size]: - hidden_dim = hidden_states.size(-1) - hidden_states = hidden_states.reshape(-1, hidden_dim) - org_hidden_states_shape = hidden_states.shape - - local_permuted_hidden_states, local_input_permutation_mapping = _permute(hidden_states, routing_map) - - global_permuted_hidden_states = FusedP2PAllToAll.apply( - local_permuted_hidden_states, 1, G_cpu, group, num_experts - ) - return global_permuted_hidden_states, local_input_permutation_mapping, org_hidden_states_shape - - -def tokens_post_all2all( - expert_outputs: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor, - num_experts: int, routing_map: torch.Tensor, local_input_permutation_mapping: torch.Tensor, - org_hidden_states_shape: torch.Size, G_cpu: List[List[int]], group: dist.ProcessGroup -) -> torch.Tensor: - unpermute_outputs = FusedP2PAllToAll.apply( - expert_outputs, 2, G_cpu, group, num_experts - ) - weights_idx = _generate_weights_idx(routing_weights, selected_experts, num_experts) - out = _unpermute( - unpermute_outputs, weights_idx, org_hidden_states_shape, - local_input_permutation_mapping, routing_map - ) - return out - - -def expert_forward( - x: torch.Tensor, gate_proj: torch.nn.Linear, up_proj: torch.nn.Linear, down_proj: torch.nn.Linear -) -> torch.Tensor: - gate = torch.nn.functional.silu(gate_proj(x)) - up = up_proj(x) - return down_proj(gate * up) - - -def solution( - hidden_states: torch.Tensor, - gate_weight: torch.Tensor, - gate_bias: Optional[torch.Tensor], - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - num_experts: int, - top_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - """ - One MoE forward with completely custom Fused UVA backend replacing all collectives. - """ - if group is None: - group = _resolve_ep_group_for_narrow_moe(num_experts) - - hidden_dim = hidden_states.size(-1) - - router_logits = torch.nn.functional.linear( - hidden_states.reshape(-1, hidden_dim), gate_weight, gate_bias - ) - routing_weights, selected_experts = torch.topk( - torch.softmax(router_logits, dim=-1), top_k, dim=-1 - ) - - _get_ext() # Init JIT extension - - routing_map, G_cpu = _preprocess_symm(selected_experts, num_experts, group) - - global_permuted, local_input_permutation_mapping, org_shape = token_pre_all2all( - hidden_states, routing_map, G_cpu, group, num_experts - ) - - expert_outputs = expert_forward( - global_permuted, gate_proj, up_proj, down_proj - ) - - out = tokens_post_all2all( - expert_outputs, routing_weights, selected_experts, num_experts, - routing_map, local_input_permutation_mapping, org_shape, - G_cpu, group - ) - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/53_fp8_reduce_scatter_grads_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/53_fp8_reduce_scatter_grads_cuda.py deleted file mode 100755 index 7108108..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/53_fp8_reduce_scatter_grads_cuda.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -Strategy: -- **Device-Side Fusion:** Replaced the host-driven PyTorch simulation (FP32 conversion -> FP8 quantize -> BF16 dequantize -> NCCL reduce-scatter) with pure device-side kernels. The rolling history and dynamic scaling operate asynchronously on the GPU stream, entirely sidestepping host synchronization. -- **True FP8 Wire Protocol via Symmetric Memory:** Skipped standard NCCL completely. A custom fused `quantize_kernel` directly scales and converts BF16 gradients into FP8 E4M3, then pushes them to a `symm_mem` buffer alongside the local scaling factor. This inherently cuts the peer-to-peer NVLink communication payload in half (from BF16 to FP8). -- **Zero-Copy Dequantize & Reduce-Scatter:** A custom `reduce_scatter_kernel` directly taps into the FP8 symmetric buffers of all peers. It performs vectorized (4-element `uint32_t`) zero-copy reads, dequantizes on the fly using the gathered per-rank scales, accurately accumulates the global sum in float precision, divides by `world_size`, and stores the final averaged BF16 shard directly to the local output tensor. -- **Compute-Communication Overlap & Maximised Bandwidth:** The kernels utilise vectorized 64-bit/32-bit loads to push Hopper's memory bandwidth to the limit. We broadcast the scale factors via fast block shared memory. Inter-rank synchronisation strictly relies on lightweight device-stream barriers via `symm_mem`. -""" - -from __future__ import annotations - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from torch import Tensor -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void quantize_kernel_vec4( - const at::BFloat16* __restrict__ input, - const float* __restrict__ scale, - c10::Float8_e4m3fn* __restrict__ output, - float* __restrict__ symm_scale, - int64_t n -) { - float s = *scale; - if (threadIdx.x == 0 && blockIdx.x == 0) { - *symm_scale = s; - } - float inv_s = 1.0f / s; - int64_t n_vec4 = n / 4; - - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n_vec4; idx += (int64_t)gridDim.x * blockDim.x) { - uint64_t in4 = reinterpret_cast(input)[idx]; - at::BFloat16 v0 = reinterpret_cast(&in4)[0]; - at::BFloat16 v1 = reinterpret_cast(&in4)[1]; - at::BFloat16 v2 = reinterpret_cast(&in4)[2]; - at::BFloat16 v3 = reinterpret_cast(&in4)[3]; - - c10::Float8_e4m3fn q0 = static_cast(static_cast(v0) * inv_s); - c10::Float8_e4m3fn q1 = static_cast(static_cast(v1) * inv_s); - c10::Float8_e4m3fn q2 = static_cast(static_cast(v2) * inv_s); - c10::Float8_e4m3fn q3 = static_cast(static_cast(v3) * inv_s); - - uint32_t out4; - reinterpret_cast(&out4)[0] = q0; - reinterpret_cast(&out4)[1] = q1; - reinterpret_cast(&out4)[2] = q2; - reinterpret_cast(&out4)[3] = q3; - - reinterpret_cast(output)[idx] = out4; - } - - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int64_t i = n_vec4 * 4; i < n; ++i) { - output[i] = static_cast(static_cast(input[i]) * inv_s); - } - } -} - -__global__ void reduce_scatter_kernel_vec4( - const uint64_t* __restrict__ peer_fp8_ptrs, - const uint64_t* __restrict__ peer_scale_ptrs, - at::BFloat16* __restrict__ out_shard, - int world_size, - int rank, - int64_t shard_elems -) { - extern __shared__ float shared_scales[]; - if (threadIdx.x < world_size) { - const float* scale_ptr = reinterpret_cast(peer_scale_ptrs[threadIdx.x]); - shared_scales[threadIdx.x] = *scale_ptr; - } - __syncthreads(); - - int64_t vec_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t shard_vec4 = shard_elems / 4; - - for (; vec_idx < shard_vec4; vec_idx += (int64_t)gridDim.x * blockDim.x) { - float sum[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - int64_t global_vec_idx = rank * shard_vec4 + vec_idx; - - #pragma unroll - for (int p = 0; p < world_size; ++p) { - float scale = shared_scales[p]; - const uint32_t* fp8_ptr = reinterpret_cast(peer_fp8_ptrs[p]); - uint32_t q4 = fp8_ptr[global_vec_idx]; - - c10::Float8_e4m3fn q0 = reinterpret_cast(&q4)[0]; - c10::Float8_e4m3fn q1 = reinterpret_cast(&q4)[1]; - c10::Float8_e4m3fn q2 = reinterpret_cast(&q4)[2]; - c10::Float8_e4m3fn q3 = reinterpret_cast(&q4)[3]; - - sum[0] += static_cast(q0) * scale; - sum[1] += static_cast(q1) * scale; - sum[2] += static_cast(q2) * scale; - sum[3] += static_cast(q3) * scale; - } - - float inv_ws = 1.0f / world_size; - at::BFloat16 out0 = static_cast(sum[0] * inv_ws); - at::BFloat16 out1 = static_cast(sum[1] * inv_ws); - at::BFloat16 out2 = static_cast(sum[2] * inv_ws); - at::BFloat16 out3 = static_cast(sum[3] * inv_ws); - - uint64_t out4; - reinterpret_cast(&out4)[0] = out0; - reinterpret_cast(&out4)[1] = out1; - reinterpret_cast(&out4)[2] = out2; - reinterpret_cast(&out4)[3] = out3; - - reinterpret_cast(out_shard)[vec_idx] = out4; - } - - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int64_t i = shard_vec4 * 4; i < shard_elems; ++i) { - float sum = 0.0f; - int64_t global_idx = rank * shard_elems + i; - for (int p = 0; p < world_size; ++p) { - float scale = shared_scales[p]; - const c10::Float8_e4m3fn* fp8_ptr = reinterpret_cast(peer_fp8_ptrs[p]); - c10::Float8_e4m3fn q = fp8_ptr[global_idx]; - sum += static_cast(q) * scale; - } - out_shard[i] = static_cast(sum / world_size); - } - } -} - -void launch_quantize( - torch::Tensor input, - torch::Tensor scale, - torch::Tensor output, - torch::Tensor symm_scale -) { - int64_t n = input.numel(); - int threads = 512; - int blocks = std::max(1, std::min((int)((n/4 + threads - 1) / threads), 65535)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - quantize_kernel_vec4<<>>( - input.data_ptr(), - scale.data_ptr(), - reinterpret_cast(output.data_ptr()), - symm_scale.data_ptr(), - n - ); -} - -void launch_reduce_scatter( - torch::Tensor peer_fp8_ptrs_tensor, - torch::Tensor peer_scale_ptrs_tensor, - torch::Tensor out_shard, - int world_size, - int rank, - int64_t shard_elems -) { - int threads = 512; - int blocks = std::max(1, std::min((int)((shard_elems/4 + threads - 1) / threads), 65535)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - size_t shared_mem_size = world_size * sizeof(float); - - reduce_scatter_kernel_vec4<<>>( - reinterpret_cast(peer_fp8_ptrs_tensor.data_ptr()), - reinterpret_cast(peer_scale_ptrs_tensor.data_ptr()), - out_shard.data_ptr(), - world_size, - rank, - shard_elems - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_quantize", &launch_quantize); - m.def("launch_reduce_scatter", &launch_reduce_scatter); -} -''' - -_ext = None -_symm_cache = {} - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fp8_rs_cuda_ext", CUDA_SRC) - return _ext - -def _get_symm_state(n: int, device: torch.device): - global _symm_cache - if n in _symm_cache: - return _symm_cache[n] - - # FP8 E4M3 symmetric buffer - fp8_buf = symm_mem.empty(n, dtype=torch.float8_e4m3fn, device=device) - hdl_fp8 = symm_mem.rendezvous(fp8_buf, dist.group.WORLD) - fp8_ptrs = torch.tensor(hdl_fp8.buffer_ptrs, dtype=torch.int64, device=device) - - # Scale per-rank symmetric buffer - scale_buf = symm_mem.empty(1, dtype=torch.float32, device=device) - hdl_scale = symm_mem.rendezvous(scale_buf, dist.group.WORLD) - scale_ptrs = torch.tensor(hdl_scale.buffer_ptrs, dtype=torch.int64, device=device) - - state = (fp8_buf, hdl_fp8, fp8_ptrs, scale_buf, hdl_scale, scale_ptrs) - _symm_cache[n] = state - return state - - -@torch.no_grad() -def solution(flat_grads: Tensor, amax_history: Tensor) -> tuple[Tensor, Tensor]: - assert dist.is_initialized(), "torch.distributed must be initialized" - - world_size = dist.get_world_size() - rank = dist.get_rank() - flat_grads = flat_grads.contiguous() - n = flat_grads.numel() - shard_elems = n // world_size - - assert n % world_size == 0, f"flat_grads numel {n} must be divisible by world_size {world_size}" - - ext = _get_ext() - fp8_buf, hdl_fp8, fp8_ptrs, scale_buf, hdl_scale, scale_ptrs = _get_symm_state(n, flat_grads.device) - - # Completely asynchronous calculations on the device; no host/GPU sync invoked. - cur_abs_max = flat_grads.abs().max().to(torch.float32) - updated_hist = torch.roll(amax_history, shifts=-1, dims=0) - updated_hist[-1] = cur_abs_max - - scale = updated_hist.max().clamp(min=1e-12) / 448.0 - - # Ensure any previous reduction's reads on the symmetric buffers have cleanly finished. - hdl_fp8.barrier(channel=0) - - # 1. Fuse scale-out and convert BF16 -> FP8 + store scale to device symm memory - ext.launch_quantize(flat_grads, scale, fp8_buf, scale_buf) - - # Wait for all peers to write their FP8 arrays and scalar multipliers to symmetric memory - hdl_fp8.barrier(channel=0) - - out_shard = torch.empty(shard_elems, dtype=flat_grads.dtype, device=flat_grads.device) - - # 2. Fully fused peer-reads of FP8, dequantize using respective scales, average, and save to BF16 shard - ext.launch_reduce_scatter(fp8_ptrs, scale_ptrs, out_shard, world_size, rank, shard_elems) - - return out_shard, updated_hist - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/54_fp8_allgather_params_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/54_fp8_allgather_params_cuda.py deleted file mode 100755 index d03d99e..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/54_fp8_allgather_params_cuda.py +++ /dev/null @@ -1,280 +0,0 @@ -""" -Strategy: -1. **Pipelined device-side communication**: Instead of a PyTorch-level BF16 all-gather after a local FP8 round-trip, we directly fuse the quantization and communicate the *compressed FP8 buffers* across peers using `torch.distributed._symmetric_memory`. -2. **Fused Gather + Dequantize**: A custom multi-block CUDA kernel utilizes direct peer-to-peer memory access over NVLink (pull-based). Each block pulls FP8 parameter shards and `scale` variables from its designated peer, dequantizes them locally on the fly, and streams the restored BF16 values straight into the full output tensor. -3. **Optimized NVLink throughput**: Uses vectorized 16-byte memory instructions (`uint4`) for all global memory reads and writes over NVLink, saturating the bus bandwidth and significantly speeding up the memory-bound all-gather. -4. **No host syncs**: We use `scale` variables directly via device pointers between our custom kernels without copying back to CPU, keeping the execution entirely asynchronous. -""" - -from __future__ import annotations - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from torch import Tensor -from utils.cuda_helpers import compile_cuda_extension - -_FP8_E4M3_MAX = 448.0 - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__device__ __forceinline__ float e4m3_to_float(uint8_t x) { - float res; - uint32_t ix = x; - asm volatile("cvt.f32.e4m3 %0, %1;" : "=f"(res) : "r"(ix)); - return res; -} - -__global__ void quantize_fused_kernel( - const __nv_bfloat16* __restrict__ input, - uint8_t* __restrict__ out_fp8, - const float* __restrict__ scale_ptr, - int64_t p -) { - float scale = *scale_ptr; - float inv_scale = 1.0f / scale; - - bool aligned = (((uintptr_t)input % 16) == 0) && (((uintptr_t)out_fp8 % 16) == 0); - - if (aligned) { - int64_t p_16 = p / 16; - int64_t offset_16 = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride_16 = (int64_t)gridDim.x * blockDim.x; - - for (int64_t i = offset_16; i < p_16; i += stride_16) { - uint4 in_bf16_0 = reinterpret_cast(input)[i * 2]; - uint4 in_bf16_1 = reinterpret_cast(input)[i * 2 + 1]; - - __nv_bfloat16 bf16_vals[16]; - ((uint4*)bf16_vals)[0] = in_bf16_0; - ((uint4*)bf16_vals)[1] = in_bf16_1; - - uint8_t bytes[16]; - #pragma unroll - for(int j=0; j<16; ++j) { - float val_f32 = __bfloat162float(bf16_vals[j]); - float scaled = val_f32 * inv_scale; - uint32_t fp8_val; - asm volatile("cvt.rn.satfinite.e4m3.f32 %0, %1;" : "=r"(fp8_val) : "f"(scaled)); - bytes[j] = (uint8_t)fp8_val; - } - - reinterpret_cast(out_fp8)[i] = *(uint4*)bytes; - } - - if (threadIdx.x == 0 && blockIdx.x == 0) { - for (int64_t i = p_16 * 16; i < p; ++i) { - float val_f32 = __bfloat162float(input[i]); - float scaled = val_f32 * inv_scale; - uint32_t fp8_val; - asm volatile("cvt.rn.satfinite.e4m3.f32 %0, %1;" : "=r"(fp8_val) : "f"(scaled)); - out_fp8[i] = (uint8_t)fp8_val; - } - } - } else { - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - for (int64_t i = tid; i < p; i += stride) { - float val_f32 = __bfloat162float(input[i]); - float scaled = val_f32 * inv_scale; - uint32_t fp8_val; - asm volatile("cvt.rn.satfinite.e4m3.f32 %0, %1;" : "=r"(fp8_val) : "f"(scaled)); - out_fp8[i] = (uint8_t)fp8_val; - } - } -} - -__global__ void dequantize_and_gather_kernel( - const uint64_t* __restrict__ peer_fp8_ptrs, - const uint64_t* __restrict__ peer_scale_ptrs, - __nv_bfloat16* __restrict__ out_full, - int64_t p, - int world_size -) { - int r = blockIdx.y; - if (r >= world_size) return; - - __shared__ float shared_scale; - if (threadIdx.x == 0) { - const float* scale_ptr = reinterpret_cast(peer_scale_ptrs[r]); - shared_scale = *scale_ptr; - } - __syncthreads(); - - const uint8_t* src_fp8 = reinterpret_cast(peer_fp8_ptrs[r]); - float scale = shared_scale; - __nv_bfloat16* out_rank_ptr = out_full + (int64_t)r * p; - - bool aligned = (((uintptr_t)src_fp8 % 16) == 0) && (((uintptr_t)out_rank_ptr % 16) == 0); - - if (aligned) { - int64_t p_16 = p / 16; - int64_t offset_16 = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride_16 = (int64_t)gridDim.x * blockDim.x; - - for (int64_t i = offset_16; i < p_16; i += stride_16) { - uint4 fp8_vec = reinterpret_cast(src_fp8)[i]; - - uint8_t bytes[16]; - *(uint4*)bytes = fp8_vec; - - __nv_bfloat16 out_bf16[16]; - #pragma unroll - for(int j=0; j<16; ++j) { - float val_f32 = e4m3_to_float(bytes[j]); - out_bf16[j] = __float2bfloat16(val_f32 * scale); - } - - reinterpret_cast(out_rank_ptr)[i * 2] = ((uint4*)out_bf16)[0]; - reinterpret_cast(out_rank_ptr)[i * 2 + 1] = ((uint4*)out_bf16)[1]; - } - - if (threadIdx.x == 0 && blockIdx.x == 0) { - for (int64_t i = p_16 * 16; i < p; ++i) { - uint8_t val_fp8 = src_fp8[i]; - float val_f32 = e4m3_to_float(val_fp8); - out_rank_ptr[i] = __float2bfloat16(val_f32 * scale); - } - } - } else { - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - for (int64_t i = tid; i < p; i += stride) { - uint8_t val_fp8 = src_fp8[i]; - float val_f32 = e4m3_to_float(val_fp8); - out_rank_ptr[i] = __float2bfloat16(val_f32 * scale); - } - } -} - -void launch_quantize( - torch::Tensor input, - torch::Tensor out_fp8, - torch::Tensor scale_tensor, - int64_t p -) { - int threads = 256; - int blocks = std::min((int)((p/16 + threads - 1) / threads), 2048); - if (blocks == 0) blocks = 1; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - quantize_fused_kernel<<>>( - reinterpret_cast(input.data_ptr()), - out_fp8.data_ptr(), - scale_tensor.data_ptr(), - p - ); -} - -void launch_gather( - torch::Tensor peer_fp8_ptrs, - torch::Tensor peer_scale_ptrs, - torch::Tensor out_full, - int64_t p, - int world_size -) { - int threads = 256; - int blocks_x = std::min((int)((p/16 + threads - 1) / threads), 1024); - if (blocks_x == 0) blocks_x = 1; - dim3 blocks(blocks_x, world_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - dequantize_and_gather_kernel<<>>( - reinterpret_cast(peer_fp8_ptrs.data_ptr()), - reinterpret_cast(peer_scale_ptrs.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(out_full.data_ptr()), - p, - world_size - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_quantize", &launch_quantize, "Quantize BF16 to FP8 E4M3"); - m.def("launch_gather", &launch_gather, "Gather FP8 and dequantize to BF16"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fp8_allgather_ext", CUDA_SRC) - return _ext - -_resource_cache = {} - -def _get_resources(p: int, device: torch.device): - key = (p, device) - if key in _resource_cache: - return _resource_cache[key] - - fp8_buf = symm_mem.empty(p, device=device, dtype=torch.uint8) - fp8_hdl = symm_mem.rendezvous(fp8_buf, dist.group.WORLD) - fp8_ptrs = torch.tensor(fp8_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - scale_buf = symm_mem.empty(1, device=device, dtype=torch.float32) - scale_hdl = symm_mem.rendezvous(scale_buf, dist.group.WORLD) - scale_ptrs = torch.tensor(scale_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (fp8_buf, fp8_hdl, fp8_ptrs, scale_buf, scale_hdl, scale_ptrs) - _resource_cache[key] = res - return res - -@torch.no_grad() -def solution(flat_param_shard: Tensor, amax_history: Tensor) -> tuple[Tensor, Tensor]: - assert dist.is_initialized(), "torch.distributed must be initialized" - - world_size = dist.get_world_size() - p = flat_param_shard.numel() - - # Fallback to PyTorch reference if it is not BF16 - if flat_param_shard.dtype != torch.bfloat16: - cur_abs_max = flat_param_shard.abs().max().to(torch.float32) - out_hist = torch.roll(amax_history, shifts=-1, dims=0) - out_hist[-1] = cur_abs_max.to(dtype=out_hist.dtype) - scale = out_hist.max().clamp(min=1e-12).to(torch.float32) / _FP8_E4M3_MAX - - xf = flat_param_shard.float() - qs = xf / scale - q = qs.to(torch.float8_e4m3fn) - recon = (q.float() * scale).to(dtype=flat_param_shard.dtype) - - full = torch.empty(world_size * p, dtype=flat_param_shard.dtype, device=flat_param_shard.device) - dist.all_gather_into_tensor(full, recon.contiguous()) - return full, out_hist - - # Accelerated path using Custom NVLink Gathering - ext = _get_ext() - flat_param_shard = flat_param_shard.contiguous() - fp8_buf, fp8_hdl, fp8_ptrs, scale_buf, scale_hdl, scale_ptrs = _get_resources(p, flat_param_shard.device) - - # 1. Update AMAX purely on-device using native PyTorch - cur_abs_max = flat_param_shard.abs().max().float() - updated_hist = torch.roll(amax_history, shifts=-1, dims=0) - updated_hist[-1] = cur_abs_max.to(updated_hist.dtype) - - # 2. Compute dynamic scale and deposit it onto symmetric pointer - scale = updated_hist.max().clamp(min=1e-12).float() / _FP8_E4M3_MAX - scale_buf.copy_(scale.view(-1)) - - # 3. Fast device-local quantization into our outgoing buffer - ext.launch_quantize(flat_param_shard, fp8_buf, scale_buf, p) - - # 4. Synchronize so all symmetric scales and fp8 buffers are fully written across the group - fp8_hdl.barrier(channel=0) - - # 5. Pull from peers and execute inline unpacking - full = torch.empty(world_size * p, dtype=torch.bfloat16, device=flat_param_shard.device) - ext.launch_gather(fp8_ptrs, scale_ptrs, full, p, world_size) - - # 6. Safety barrier ensuring no rank will overwrite its buffer in immediate consecutive loops - fp8_hdl.barrier(channel=0) - - return full, updated_hist - -__all__ = ["solution"] \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/55_ring_attention_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/55_ring_attention_cuda.py deleted file mode 100755 index 81581bf..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/55_ring_attention_cuda.py +++ /dev/null @@ -1,291 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// 1. Asynchronous P2P copy using UVA over NVLink -void uva_copy_async(int64_t dst_ptr, int64_t src_ptr, int64_t bytes, int64_t stream_ptr) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - cudaMemcpyAsync(reinterpret_cast(dst_ptr), - reinterpret_cast(src_ptr), - bytes, - cudaMemcpyDeviceToDevice, - stream); -} - -// 2. Fused Scores -> P (Softmax) + LSE tracking -__global__ void fused_scores_to_p_lse_kernel( - float* __restrict__ scores, // [B, H, S_q, S_k] - float* __restrict__ block_lse, // [B, S_q, H] - const int S_q, - const int S_k, - const int H, - const int total_rows, - const bool causal -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < total_rows) { - int bh_idx = idx / S_q; - int sq_idx = idx % S_q; - int b_idx = bh_idx / H; - int h_idx = bh_idx % H; - - float* row = scores + bh_idx * S_q * S_k + sq_idx * S_k; - - float max_val = -1e20f; - int limit = causal ? (sq_idx + 1) : S_k; - - for (int i = 0; i < limit; ++i) { - float val = row[i]; - if (val > max_val) max_val = val; - } - - float sum_exp = 0.0f; - for (int i = 0; i < limit; ++i) { - float e = expf(row[i] - max_val); - row[i] = e; - sum_exp += e; - } - - // Zero out masked components (equivalent to -inf prior to softmax) - for (int i = limit; i < S_k; ++i) { - row[i] = 0.0f; - } - - int lse_idx = b_idx * (S_q * H) + sq_idx * H + h_idx; - block_lse[lse_idx] = max_val + logf(sum_exp); - - float inv_sum = 1.0f / sum_exp; - for (int i = 0; i < limit; ++i) { - row[i] *= inv_sum; - } - } -} - -void launch_fused_scores_to_p_lse( - torch::Tensor scores, - torch::Tensor block_lse, - bool causal -) { - int B_H = scores.size(0) * scores.size(1); - int H = scores.size(1); - int S_q = scores.size(2); - int S_k = scores.size(3); - int total_rows = B_H * S_q; - - int threads = 256; - int blocks = (total_rows + threads - 1) / threads; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - fused_scores_to_p_lse_kernel<<>>( - scores.data_ptr(), - block_lse.data_ptr(), - S_q, - S_k, - H, - total_rows, - causal - ); -} - -// 3. Fused LogSumExp merge & numerically stable update -__global__ void fused_merge_out_lse_kernel( - float* __restrict__ out, - float* __restrict__ lse, - const float* __restrict__ block_out, - const float* __restrict__ block_lse, - const int D, - const int num_elements -) { - int elem_idx = blockIdx.x; // maps to [B, S_q, H] structure natively - int d_idx = threadIdx.x; - - if (elem_idx < num_elements) { - float b_lse = block_lse[elem_idx]; - float c_lse = lse[elem_idx]; - - float x_sig = b_lse - c_lse; - float sig; - if (x_sig >= 0.0f) { - sig = 1.0f / (1.0f + expf(-x_sig)); - } else { - float e = expf(x_sig); - sig = e / (1.0f + e); - } - - for (int d = d_idx; d < D; d += blockDim.x) { - float current_out = out[elem_idx * D + d]; - float b_out = block_out[elem_idx * D + d]; - out[elem_idx * D + d] = current_out - sig * (current_out - b_out); - } - - if (d_idx == 0) { - float x = c_lse - b_lse; - float log_sig; - if (x >= 0.0f) { - log_sig = -log1pf(expf(-x)); - } else { - log_sig = x - log1pf(expf(x)); - } - lse[elem_idx] = c_lse - log_sig; - } - } -} - -void launch_fused_merge_out_lse( - torch::Tensor out, - torch::Tensor lse, - torch::Tensor block_out, - torch::Tensor block_lse -) { - int num_elements = out.size(0) * out.size(1) * out.size(2); - int D = out.size(3); - - int threads = (D < 1024) ? D : 1024; - threads = (threads + 31) / 32 * 32; - - dim3 blocks(num_elements); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - fused_merge_out_lse_kernel<<>>( - out.data_ptr(), - lse.data_ptr(), - block_out.data_ptr(), - block_lse.data_ptr(), - D, - num_elements - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("uva_copy_async", &uva_copy_async, "UVA async copy via NVLink"); - m.def("launch_fused_scores_to_p_lse", &launch_fused_scores_to_p_lse, "Fused max, exp, scale and sum for Attention P"); - m.def("launch_fused_merge_out_lse", &launch_fused_merge_out_lse, "Fused out/lse tracking"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ring_attn_cp_fused", CUDA_SRC) - return _ext - -_symm_cache = {} -def get_symm_buffers(shape, dtype, device, group): - key = (tuple(shape), dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf_k = symm_mem.empty(shape, device=device, dtype=dtype) - hdl_k = symm_mem.rendezvous(buf_k, group) - buf_v = symm_mem.empty(shape, device=device, dtype=dtype) - hdl_v = symm_mem.rendezvous(buf_v, group) - - _symm_cache[key] = (buf_k, hdl_k, buf_v, hdl_v) - return _symm_cache[key] - - -@torch.no_grad() -def solution( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale: Optional[float] = None, - causal: bool = False, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** -0.5 - - ext = _get_ext() - qh = q.transpose(1, 2).float() - - # Fast path: context parallel strictly restricted to single rank - if world_size == 1: - kh = k.transpose(1, 2).float() - vh = v.transpose(1, 2).float() - scores = (torch.matmul(qh, kh.transpose(-2, -1)) * softmax_scale).contiguous() - block_lse = torch.empty(q.size(0), q.size(1), q.size(2), dtype=torch.float32, device=q.device) - ext.launch_fused_scores_to_p_lse(scores, block_lse, causal) - return torch.matmul(scores, vh).transpose(1, 2).contiguous().to(q.dtype) - - # Initialize device-side P2P layout and synchronicity mechanics - buf_k, hdl_k, buf_v, hdl_v = get_symm_buffers(k.shape, k.dtype, k.device, group) - buf_k.copy_(k) - buf_v.copy_(v) - hdl_k.barrier(channel=0) - hdl_v.barrier(channel=0) - - # Double buffers dynamically tracking the rotation step - local_k = [torch.empty_like(k), torch.empty_like(k)] - local_v = [torch.empty_like(v), torch.empty_like(v)] - local_k[0].copy_(k) - local_v[0].copy_(v) - - copy_stream = torch.cuda.Stream() - compute_stream = torch.cuda.current_stream() - events = [torch.cuda.Event() for _ in range(world_size)] - - bytes_to_copy = k.numel() * k.element_size() - out, lse = None, None - - for step in range(world_size): - # 1) Prefetch step overlapping (fetch K and V from peer UVA directly to stream) - if step + 1 < world_size: - next_buf_idx = (step + 1) % 2 - src_rank = (rank - (step + 1)) % world_size - - # Ensures target memory buffer has been totally freed by prior loops' computing workload - copy_stream.wait_stream(compute_stream) - - remote_k_ptr = int(hdl_k.buffer_ptrs[src_rank]) - remote_v_ptr = int(hdl_v.buffer_ptrs[src_rank]) - - with torch.cuda.stream(copy_stream): - ext.uva_copy_async(local_k[next_buf_idx].data_ptr(), remote_k_ptr, bytes_to_copy, copy_stream.cuda_stream) - ext.uva_copy_async(local_v[next_buf_idx].data_ptr(), remote_v_ptr, bytes_to_copy, copy_stream.cuda_stream) - events[step+1].record(copy_stream) - - # 2) Target step compute resolving - if (not causal) or step <= rank: - if step > 0: - compute_stream.wait_event(events[step]) - - curr_k = local_k[step % 2] - curr_v = local_v[step % 2] - - kh = curr_k.transpose(1, 2).float() - vh = curr_v.transpose(1, 2).float() - - # PyTorch `matmul` heavily utilizes float tensor-cores on Hopper. - scores = (torch.matmul(qh, kh.transpose(-2, -1)) * softmax_scale).contiguous() - block_lse = torch.empty(q.size(0), q.size(1), q.size(2), dtype=torch.float32, device=q.device) - - is_causal = causal and (step == 0) - - # Cuda Kernel 1: Apply mask, track max, exp scales natively on device. Overwrites scores pointer via softmax rules. - ext.launch_fused_scores_to_p_lse(scores, block_lse, is_causal) - - block_out = torch.matmul(scores, vh).transpose(1, 2).contiguous() - - if out is None: - out = block_out # Pass-through avoids re-allocation - lse = block_lse - else: - # Cuda Kernel 2: Single block-pass over output space tracking LSE - ext.launch_fused_merge_out_lse(out, lse, block_out, block_lse) - - return out.to(q.dtype) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/56_ring_attention_tp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/56_ring_attention_tp_cuda.py deleted file mode 100755 index 3b24194..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/56_ring_attention_tp_cuda.py +++ /dev/null @@ -1,494 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// 1. CP Ring P2P Copy -// --------------------------------------------------------------------------- - -__global__ void p2p_copy_kv_kernel_128( - const int64_t remote_ptr, - void* __restrict__ next_k, - void* __restrict__ next_v, - int64_t numel_128 -) { - const uint4* src_k = reinterpret_cast(remote_ptr); - const uint4* src_v = src_k + numel_128; - - uint4* dst_k = reinterpret_cast(next_k); - uint4* dst_v = reinterpret_cast(next_v); - - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel_128) { - dst_k[idx] = src_k[idx]; - dst_v[idx] = src_v[idx]; - } -} - -__global__ void p2p_copy_kv_kernel( - const int64_t remote_ptr, - __nv_bfloat16* __restrict__ next_k, - __nv_bfloat16* __restrict__ next_v, - int64_t numel -) { - const __nv_bfloat16* src_k = reinterpret_cast(remote_ptr); - const __nv_bfloat16* src_v = src_k + numel; - - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel) { - next_k[idx] = src_k[idx]; - next_v[idx] = src_v[idx]; - } -} - -void launch_p2p_copy_kv( - int64_t remote_ptr, - torch::Tensor next_k, - torch::Tensor next_v -) { - int64_t numel = next_k.numel(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (numel % 8 == 0) { - int64_t numel_128 = numel / 8; - int threads = 256; - int blocks = (numel_128 + threads - 1) / threads; - p2p_copy_kv_kernel_128<<>>( - remote_ptr, - next_k.data_ptr(), - next_v.data_ptr(), - numel_128 - ); - } else { - int threads = 256; - int blocks = (numel + threads - 1) / threads; - p2p_copy_kv_kernel<<>>( - remote_ptr, - (__nv_bfloat16*)next_k.data_ptr(), - (__nv_bfloat16*)next_v.data_ptr(), - numel - ); - } -} - -// --------------------------------------------------------------------------- -// 2. CP Ring Merge Out & LSE -// --------------------------------------------------------------------------- - -__global__ void merge_out_lse_kernel_block( - float* __restrict__ out, - float* __restrict__ lse, - const float* __restrict__ block_out, - const float* __restrict__ block_lse, - int64_t B, int64_t S, int64_t H, int64_t D -) { - int64_t lse_idx = blockIdx.x; - int64_t h = lse_idx % H; - int64_t tmp = lse_idx / H; - int64_t s = tmp % S; - int64_t b = tmp / S; - - int64_t blse_idx = b * (H * S) + h * S + s; - - __shared__ float sh_sig; - - if (threadIdx.x == 0) { - float curr_lse = lse[lse_idx]; - float b_lse = block_lse[blse_idx]; - - float max_lse = fmaxf(curr_lse, b_lse); - float exp_curr = expf(curr_lse - max_lse); - float exp_b = expf(b_lse - max_lse); - float sum_exp = exp_curr + exp_b; - - sh_sig = exp_b / sum_exp; - - // Write the new updated LSE exclusively - lse[lse_idx] = max_lse + logf(sum_exp); - } - - __syncthreads(); - - float sig = sh_sig; - int64_t out_base = lse_idx * D; - - // Process the inner dimension seamlessly via fast coalesced accesses - for (int64_t d = threadIdx.x; d < D; d += blockDim.x) { - float curr_out = out[out_base + d]; - float b_out = block_out[out_base + d]; - out[out_base + d] = curr_out - sig * (curr_out - b_out); - } -} - -void launch_merge_out_lse( - torch::Tensor out, - torch::Tensor lse, - torch::Tensor block_out, - torch::Tensor block_lse -) { - int64_t B = out.size(0); - int64_t S = out.size(1); - int64_t H = out.size(2); - int64_t D = out.size(3); - - int blocks = B * S * H; - int threads = 128; - if (D < 128) { - threads = 32; - while (threads < D) threads *= 2; - } else if (D > 128) { - threads = 256; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - merge_out_lse_kernel_block<<>>( - out.data_ptr(), - lse.data_ptr(), - block_out.data_ptr(), - block_lse.data_ptr(), - B, S, H, D - ); -} - -// --------------------------------------------------------------------------- -// 3. TP Multimem Allreduce (from switch PTX limits) -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void send_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { asm volatile("atom.global.relaxed.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { asm volatile("atom.global.sys.relaxed.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 1u); -} - -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { asm volatile("atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { asm volatile("atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 1u); -} - -__device__ void blockwise_barrier_relaxed(const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast(remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast(local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_relaxed(send_addr); - wait_signal_relaxed(wait_addr); -} - -__device__ void blockwise_barrier_acq_rel(const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast(remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast(local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_acq_rel(send_addr); - wait_signal_acq_rel(wait_addr); -} - -__device__ __forceinline__ void multimem_ld_reduce_bf16x4(const uint64_t* addr, uint32_t& r0, uint32_t& r1, uint32_t& r2, uint32_t& r3) { - asm volatile("multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "l"(addr) : "memory"); -} - -__device__ __forceinline__ void multimem_st_bf16x4(const uint64_t* addr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" : : "l"(addr), "r"(x), "r"(y), "r"(z), "r"(w) : "memory"); -} - -__global__ void multimem_allreduce_bf16_kernel( - uint64_t multicast_base, - const uint64_t* __restrict__ signal_pad_ptrs, - int64_t numel_128, - int world_size, - int rank, - int block_stride -) { - const uint64_t block_id = static_cast(blockIdx.x); - blockwise_barrier_relaxed(signal_pad_ptrs, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = (numel_128 + (int64_t)world_size - 1) / (int64_t)world_size; - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = (int64_t)block_id * (int64_t)block_stride; block_start < numel_per_rank; block_start += (int64_t)num_programs * (int64_t)block_stride) { - const int64_t offsets = block_start + (int64_t)tid; - if (offsets >= numel_per_rank) continue; - const int64_t idx = (int64_t)rank * numel_per_rank + offsets; - uint64_t* ptrs = reinterpret_cast(multicast_base) + idx * 2; - uint32_t x, y, z, w; - multimem_ld_reduce_bf16x4(ptrs, x, y, z, w); - multimem_st_bf16x4(ptrs, x, y, z, w); - } - __syncthreads(); - blockwise_barrier_acq_rel(signal_pad_ptrs, block_id, rank, world_size); -} - -__global__ void allreduce_bf16_fallback_kernel( - const long long* __restrict__ ptrs, - __nv_bfloat16* __restrict__ out, - int world_size, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const __nv_bfloat16* src = (const __nv_bfloat16*)ptrs[r]; - sum += __bfloat162float(src[idx]); - } - out[idx] = __float2bfloat16(sum); - } -} - -void launch_multimem_allreduce_bf16( - uint64_t multicast_ptr, - torch::Tensor signal_pad_ptrs_tensor, - int64_t numel_128, - int world_size, - int rank, - int num_blocks, - int block_size, - int block_stride -) { - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - multimem_allreduce_bf16_kernel<<>>( - multicast_ptr, d_signal, numel_128, world_size, rank, block_stride); -} - -void launch_allreduce_bf16_fallback( - torch::Tensor ptrs_tensor, - torch::Tensor out, - int64_t n -) { - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - int threads = 512; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - allreduce_bf16_fallback_kernel<<>>( - d_ptrs, (__nv_bfloat16*)out.data_ptr(), world_size, n); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_p2p_copy_kv", &launch_p2p_copy_kv); - m.def("launch_merge_out_lse", &launch_merge_out_lse); - m.def("launch_multimem_allreduce_bf16", &launch_multimem_allreduce_bf16); - m.def("launch_allreduce_bf16_fallback", &launch_allreduce_bf16_fallback); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ring_attn_optim_ext", CUDA_SRC) - return _ext - -WARP_SIZE = 32 -MAX_NUM_BLOCKS = 4 -MAX_BLOCK_SIZE = 1024 -BYTES_PER_THREAD = 16 - -def _multimem_launch_config(numel: int, world_size: int): - numel_per_thread = BYTES_PER_THREAD // 2 # bf16 assumes 2 bytes - num_threads = (numel // numel_per_thread + world_size - 1) // world_size - if num_threads < MAX_BLOCK_SIZE: - block_size = 1 - while block_size < num_threads: - block_size *= 2 - num_blocks = 1 - else: - block_size = MAX_BLOCK_SIZE - num_blocks = min((num_threads + MAX_BLOCK_SIZE - 1) // MAX_BLOCK_SIZE, MAX_NUM_BLOCKS) - return num_blocks, block_size, block_size - -_tp_cache = {} -def get_tp_symm_resources(shape, dtype, device, group): - key = (shape, dtype, device, id(group)) - if key in _tp_cache: - return _tp_cache[key] - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - out = torch.empty(shape, device=device, dtype=dtype) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - res = (buf, hdl, out, ptrs_tensor) - _tp_cache[key] = res - return res - -_cp_cache = {} -def get_cp_symm_resources(shape_KV, dtype, device, group): - key = (shape_KV, dtype, device, id(group)) - if key in _cp_cache: - return _cp_cache[key] - - buf = symm_mem.empty(shape_KV, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - _cp_cache[key] = (buf, hdl) - return buf, hdl - -def tp_allreduce(tensor, group): - tp_size = dist.get_world_size(group) - if tp_size == 1: - return tensor - - n = tensor.numel() - buf, hdl, out, ptrs_tensor = get_tp_symm_resources(tensor.shape, tensor.dtype, tensor.device, group) - buf.copy_(tensor) - - numel_per_thread = BYTES_PER_THREAD // 2 - if n % numel_per_thread != 0: - hdl.barrier(channel=0) - _get_ext().launch_allreduce_bf16_fallback(ptrs_tensor, out, n) - return out - - numel_128 = n // numel_per_thread - num_blocks, block_size, block_stride = _multimem_launch_config(n, tp_size) - - dist.barrier(group=group) - - multicast_ptr = int(hdl.multicast_ptr) - signal_dev = hdl.signal_pad_ptrs_dev - _get_ext().launch_multimem_allreduce_bf16( - multicast_ptr, signal_dev, numel_128, tp_size, dist.get_rank(group), - num_blocks, block_size, block_stride - ) - return buf.reshape_as(tensor).clone() - -def _local_attn( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - scale: float, causal: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: - qh = q.float().transpose(1, 2) - kh = k.float().transpose(1, 2) - vh = v.float().transpose(1, 2) - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - if causal: - mask = torch.triu(torch.ones(q.size(1), k.size(1), device=q.device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - block_lse = torch.logsumexp(scores, dim=-1) - block_out = torch.matmul(torch.softmax(scores, dim=-1), vh).transpose(1, 2).contiguous() - return block_out, block_lse - -def solution( - hidden_states: torch.Tensor, - w_qkv: torch.Tensor, - w_o: torch.Tensor, - num_heads: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - tp_group: Optional[dist.ProcessGroup] = None, - cp_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - """ - Per-rank Megatron-style CP+TP ring attention forward via Device Overlapped P2P/Multimem logic. - """ - tp_group = tp_group or dist.group.WORLD - cp_group = cp_group or dist.group.WORLD - - tp_size = dist.get_world_size(tp_group) - cp_size = dist.get_world_size(cp_group) - - heads_local = num_heads // tp_size - head_dim = w_qkv.shape[0] // 3 // heads_local - if softmax_scale is None: - softmax_scale = head_dim ** -0.5 - - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - - B, S = hidden_states.shape[:2] - qkv = F.linear(hidden_states, w_qkv).view(B, S, 3, heads_local, head_dim) - q, k, v = qkv.unbind(dim=2) - - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - - out = None - lse = None - - if cp_size == 1: - block_out, block_lse = _local_attn(q, k, v, float(softmax_scale), causal) - out = block_out.to(q.dtype) - else: - # CP Buffer Setup and Remote Ring Allocation - cp_rank = dist.get_rank(cp_group) - shape_KV = (2, B, S, heads_local, head_dim) - symm_KV_buf, cp_hdl = get_cp_symm_resources(shape_KV, k.dtype, k.device, cp_group) - - symm_KV_buf[0].copy_(k) - symm_KV_buf[1].copy_(v) - cp_hdl.barrier(channel=0) - - curr_K, curr_V = k, v - next_K = torch.empty_like(k) - next_V = torch.empty_like(v) - - copy_stream = torch.cuda.Stream() - compute_stream = torch.cuda.current_stream() - copy_event = torch.cuda.Event() - compute_event = torch.cuda.Event() - - for step in range(cp_size): - # Fetch directly from the requisite peer skipping traditional ring passes - if step + 1 < cp_size: - next_source = (cp_rank - step - 1) % cp_size - remote_ptr = cp_hdl.buffer_ptrs[next_source] - - with torch.cuda.stream(copy_stream): - copy_stream.wait_event(compute_event) - _get_ext().launch_p2p_copy_kv(remote_ptr, next_K, next_V) - copy_event.record(copy_stream) - - if (not causal) or step <= cp_rank: - is_causal = causal and (step == 0) - block_out, block_lse = _local_attn(q, curr_K, curr_V, float(softmax_scale), is_causal) - - if out is None: - out = block_out.clone() - lse = block_lse.transpose(-2, -1).contiguous() - else: - _get_ext().launch_merge_out_lse(out, lse, block_out, block_lse) - - compute_event.record(compute_stream) - - if step + 1 < cp_size: - compute_stream.wait_event(copy_event) - - curr_K, next_K = next_K, curr_K - curr_V, next_V = next_V, curr_V - - out = out.to(q.dtype) - - out = F.linear(out.view(B, S, -1), w_o) - - if tp_size > 1: - out = tp_allreduce(out, tp_group) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/57_ring_attention_pp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/57_ring_attention_pp_cuda.py deleted file mode 100755 index 61f7560..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/57_ring_attention_pp_cuda.py +++ /dev/null @@ -1,385 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - -# --------------------------------------------------------------------------- -# Custom CUDA Extension for P2P via Symmetric Memory & Fused Math -# --------------------------------------------------------------------------- - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// Atomic release flag setting -__global__ void set_flag_kernel(uint32_t* addr, uint32_t val) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - asm volatile("atom.global.release.sys.exch.b32 %0, [%1], %2;" - : "=r"(val) : "l"(addr), "r"(val) : "memory"); - } -} - -// Single-thread spin wait for stream synchronization -__global__ void wait_kernel(uint32_t* flag_addr, uint32_t wait_val) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - uint32_t val; - do { - asm volatile("ld.global.acquire.sys.b32 %0, [%1];" - : "=r"(val) : "l"(flag_addr) : "memory"); - } while (val < wait_val); - } -} - -// Push K and V to peer and signal -void push_kv_and_signal( - torch::Tensor local_k, torch::Tensor local_v, - int64_t remote_k_ptr, int64_t remote_v_ptr, - int64_t remote_flag_ptr, uint32_t flag_val, - int64_t stream_ptr -) { - cudaStream_t stream = stream_ptr ? reinterpret_cast(stream_ptr) : at::cuda::getCurrentCUDAStream().stream(); - int64_t bytes = local_k.numel() * sizeof(at::BFloat16); - - cudaMemcpyAsync(reinterpret_cast(remote_k_ptr), local_k.data_ptr(), bytes, cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(reinterpret_cast(remote_v_ptr), local_v.data_ptr(), bytes, cudaMemcpyDeviceToDevice, stream); - - set_flag_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(remote_flag_ptr), flag_val); -} - -// Push PP buffer to peer and signal -void push_pp_and_signal( - torch::Tensor local_data, - int64_t remote_data_ptr, - int64_t remote_flag_ptr, uint32_t flag_val, - int64_t stream_ptr -) { - cudaStream_t stream = stream_ptr ? reinterpret_cast(stream_ptr) : at::cuda::getCurrentCUDAStream().stream(); - int64_t bytes = local_data.numel() * sizeof(at::BFloat16); - - cudaMemcpyAsync(reinterpret_cast(remote_data_ptr), local_data.data_ptr(), bytes, cudaMemcpyDeviceToDevice, stream); - - set_flag_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(remote_flag_ptr), flag_val); -} - -void wait_signal(int64_t flag_ptr, uint32_t wait_val, int64_t stream_ptr) { - cudaStream_t stream = stream_ptr ? reinterpret_cast(stream_ptr) : at::cuda::getCurrentCUDAStream().stream(); - wait_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(flag_ptr), wait_val); -} - -void set_flag_python(int64_t flag_ptr, uint32_t flag_val, int64_t stream_ptr) { - cudaStream_t stream = stream_ptr ? reinterpret_cast(stream_ptr) : at::cuda::getCurrentCUDAStream().stream(); - set_flag_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(flag_ptr), flag_val); -} - -// Fused init kernel -__global__ void init_out_lse_kernel( - float* out, float* lse, - const at::BFloat16* block_out, const float* block_lse, - int B, int S, int H, int D -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t total = (int64_t)B * S * H * D; - if (idx >= total) return; - - int d = idx % D; - int tmp = idx / D; - int h = tmp % H; - tmp = tmp / H; - int s = tmp % S; - int b = tmp / S; - - out[idx] = __bfloat162float(block_out[idx]); - if (d == 0) { - lse[(int64_t)b * (S * H) + s * H + h] = block_lse[(int64_t)b * (H * S) + h * S + s]; - } -} - -// Fused in-place sigmoid LSE and output block merge -__global__ void merge_out_lse_kernel( - float* out, float* lse, - const at::BFloat16* block_out, const float* block_lse, - int B, int S, int H, int D -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t total = (int64_t)B * S * H * D; - if (idx >= total) return; - - int d = idx % D; - int tmp = idx / D; - int h = tmp % H; - tmp = tmp / H; - int s = tmp % S; - int b = tmp / S; - - int64_t lse_idx_block = (int64_t)b * (H * S) + h * S + s; - int64_t lse_idx_out = (int64_t)b * (S * H) + s * H + h; - - float current_lse = lse[lse_idx_out]; - float b_lse = block_lse[lse_idx_block]; - - float diff = b_lse - current_lse; - float sig = 1.0f / (1.0f + expf(-diff)); - - float current_out = out[idx]; - float b_out = __bfloat162float(block_out[idx]); - - out[idx] = current_out - sig * (current_out - b_out); - - if (d == 0) { - float x = current_lse - b_lse; - float log_sig = (x >= 0) ? -log1pf(expf(-x)) : (x - log1pf(expf(x))); - lse[lse_idx_out] = current_lse - log_sig; - } -} - -void init_out_lse(torch::Tensor out, torch::Tensor lse, torch::Tensor block_out, torch::Tensor block_lse) { - int B = out.size(0); int S = out.size(1); int H = out.size(2); int D = out.size(3); - int threads = 256; int blocks = ((int64_t)B * S * H * D + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - init_out_lse_kernel<<>>( - out.data_ptr(), lse.data_ptr(), - reinterpret_cast(block_out.data_ptr()), - block_lse.data_ptr(), B, S, H, D - ); -} - -void merge_out_lse(torch::Tensor out, torch::Tensor lse, torch::Tensor block_out, torch::Tensor block_lse) { - int B = out.size(0); int S = out.size(1); int H = out.size(2); int D = out.size(3); - int threads = 256; int blocks = ((int64_t)B * S * H * D + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - merge_out_lse_kernel<<>>( - out.data_ptr(), lse.data_ptr(), - reinterpret_cast(block_out.data_ptr()), - block_lse.data_ptr(), B, S, H, D - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("push_kv_and_signal", &push_kv_and_signal); - m.def("push_pp_and_signal", &push_pp_and_signal); - m.def("wait_signal", &wait_signal); - m.def("set_flag_python", &set_flag_python); - m.def("init_out_lse", &init_out_lse); - m.def("merge_out_lse", &merge_out_lse); -} -''' - -_ext_module = None - -def _get_ext(): - global _ext_module - if _ext_module is None: - _ext_module = compile_cuda_extension("ring_attn_pp_ext", CUDA_SRC) - return _ext_module - - -_cache = {} - -def _get_resources(B, S, hidden_size, num_heads, head_dim, dtype, device): - key = (B, S, hidden_size, num_heads, head_dim, dtype, device) - if key in _cache: - return _cache[key] - - # Global buffers since symmetric memory acts on WORLD - kv_shape = (2, 2, B, S, num_heads, head_dim) - kv_buf = symm_mem.empty(kv_shape, dtype=dtype, device=device) - kv_hdl = symm_mem.rendezvous(kv_buf, group=dist.group.WORLD) - - pp_shape = (B, S, hidden_size) - pp_buf = symm_mem.empty(pp_shape, dtype=dtype, device=device) - pp_hdl = symm_mem.rendezvous(pp_buf, group=dist.group.WORLD) - - # 4 x uint32 [cp_flag, pp_flag, reserved, pp_ack] - flags_buf = symm_mem.empty((4,), dtype=torch.int32, device=device) - flags_buf.zero_() - flags_hdl = symm_mem.rendezvous(flags_buf, group=dist.group.WORLD) - - out_buf = torch.empty((B, S, num_heads, head_dim), dtype=torch.float32, device=device) - lse_buf = torch.empty((B, S, num_heads), dtype=torch.float32, device=device) - - comm_stream = torch.cuda.Stream() - - state = { - "kv_buf": kv_buf, "kv_hdl": kv_hdl, - "pp_buf": pp_buf, "pp_hdl": pp_hdl, - "flags_buf": flags_buf, "flags_hdl": flags_hdl, - "out_buf": out_buf, "lse_buf": lse_buf, - "comm_stream": comm_stream, - "cp_push_count": 0, "cp_wait_count": 0, - "pp_push_count": 0, "pp_wait_count": 0, - "pp_ack_push_count": 0, "pp_ack_wait_count": 0, - } - _cache[key] = state - return state - - -@torch.no_grad() -def solution( - hidden_states: torch.Tensor, - w_qkv: torch.Tensor, - w_o: torch.Tensor, - num_heads: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - cp_group: Optional[dist.ProcessGroup] = None, - pp_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - cp_group = cp_group or dist.group.WORLD - head_dim = w_qkv.shape[0] // 3 // num_heads - scale = float(softmax_scale if softmax_scale is not None else head_dim ** -0.5) - - device = hidden_states.device - dtype = hidden_states.dtype - B, S_local, hidden_size = hidden_states.shape - - is_first, is_last = True, True - pp_rank, pp_size = 0, 1 - if pp_group is not None and dist.get_world_size(pp_group) > 1: - pp_rank = dist.get_rank(pp_group) - pp_size = dist.get_world_size(pp_group) - is_first = (pp_rank == 0) - is_last = (pp_rank == pp_size - 1) - - cp_rank = dist.get_rank(cp_group) - cp_size = dist.get_world_size(cp_group) - - _ext = _get_ext() - - # Fast path for standalone - if cp_size == 1 and pp_size == 1: - qkv = F.linear(hidden_states, w_qkv).view(B, S_local, 3, num_heads, head_dim) - q, k, v = qkv.unbind(dim=2) - q, k, v = q.contiguous(), k.contiguous(), v.contiguous() - - qh, kh, vh = q.transpose(1, 2).float(), k.transpose(1, 2).float(), v.transpose(1, 2).float() - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - if causal: - mask = torch.triu(torch.ones(q.size(1), k.size(1), device=device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - block_out = torch.matmul(torch.softmax(scores, dim=-1), vh).transpose(1, 2).contiguous() - return F.linear(block_out.to(dtype).reshape(B, S_local, -1), w_o) - - state = _get_resources(B, S_local, hidden_size, num_heads, head_dim, dtype, device) - - flags_ptrs = state["flags_hdl"].buffer_ptrs - my_flags_base = flags_ptrs[dist.get_rank()] - my_cp_flag_ptr = my_flags_base + 0 - my_pp_flag_ptr = my_flags_base + 4 - my_pp_ack_ptr = my_flags_base + 12 - pp_ptrs = state["pp_hdl"].buffer_ptrs - - pp_push_count, pp_wait_count = state["pp_push_count"], state["pp_wait_count"] - pp_ack_push_count, pp_ack_wait_count = state["pp_ack_push_count"], state["pp_ack_wait_count"] - - # 1. Pipeline-Parallel Recv - if not is_first: - pp_wait_count += 1 - _ext.wait_signal(my_pp_flag_ptr, pp_wait_count, 0) - stage_input = state["pp_buf"].clone() - - pp_ack_push_count += 1 - prev_rank = dist.get_global_rank(pp_group, (pp_rank - 1) % pp_size) - _ext.set_flag_python(flags_ptrs[prev_rank] + 12, pp_ack_push_count, 0) - else: - stage_input = hidden_states - - # 2. Local Context Parallel / Attention QKV Split - qkv = F.linear(stage_input, w_qkv).view(B, S_local, 3, num_heads, head_dim) - q, k, v = qkv.unbind(dim=2) - q, k, v = q.contiguous(), k.contiguous(), v.contiguous() - - if cp_size == 1: - qh, kh, vh = q.transpose(1, 2).float(), k.transpose(1, 2).float(), v.transpose(1, 2).float() - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - if causal: - mask = torch.triu(torch.ones(q.size(1), k.size(1), device=device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - block_out = torch.matmul(torch.softmax(scores, dim=-1), vh).transpose(1, 2).contiguous() - ctx = block_out.to(dtype) - else: - # Pipelined CP Ring (Overlapped Communication + Computation via double buffering) - dist.barrier(group=cp_group) - - kv_buf = state["kv_buf"] - kv_buf[0, 0].copy_(k) - kv_buf[0, 1].copy_(v) - - dist.barrier(group=cp_group) - - out_buf, lse_buf, comm_stream = state["out_buf"], state["lse_buf"], state["comm_stream"] - cp_push_count, cp_wait_count = state["cp_push_count"], state["cp_wait_count"] - - global_next_cp = dist.get_global_rank(cp_group, (cp_rank + 1) % cp_size) - peer_kv_base = state["kv_hdl"].buffer_ptrs[global_next_cp] - peer_cp_flag_ptr = flags_ptrs[global_next_cp] + 0 - - buf_elements = 2 * B * S_local * num_heads * head_dim - chunk_elements = B * S_local * num_heads * head_dim - element_size = dtype.itemsize - - for step in range(cp_size): - curr_buf_idx, next_buf_idx = step % 2, (step + 1) % 2 - - if step > 0: - cp_wait_count += 1 - _ext.wait_signal(my_cp_flag_ptr, cp_wait_count, 0) - - curr_k, curr_v = kv_buf[curr_buf_idx, 0], kv_buf[curr_buf_idx, 1] - - if step + 1 != cp_size: - cp_push_count += 1 - peer_k_ptr = peer_kv_base + next_buf_idx * buf_elements * element_size - peer_v_ptr = peer_kv_base + (next_buf_idx * buf_elements + chunk_elements) * element_size - - # Copy async into peer's symmetric memory and signal - comm_stream.wait_stream(torch.cuda.current_stream()) - _ext.push_kv_and_signal( - curr_k, curr_v, peer_k_ptr, peer_v_ptr, peer_cp_flag_ptr, cp_push_count, comm_stream.cuda_stream - ) - - # Perform fused matmuls - if not (causal and step > cp_rank): - qh, kh, vh = q.transpose(1, 2).float(), curr_k.transpose(1, 2).float(), curr_v.transpose(1, 2).float() - scores = torch.matmul(qh, kh.transpose(-2, -1)) * scale - - if causal and (step == 0): - mask = torch.triu(torch.ones(q.size(1), curr_k.size(1), device=device, dtype=torch.bool), 1) - scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float("-inf")) - - block_lse = torch.logsumexp(scores, dim=-1) - block_out = torch.matmul(torch.softmax(scores, dim=-1), vh).transpose(1, 2).contiguous().to(dtype) - - if step == 0: - _ext.init_out_lse(out_buf, lse_buf, block_out, block_lse) - else: - _ext.merge_out_lse(out_buf, lse_buf, block_out, block_lse) - - state["cp_push_count"], state["cp_wait_count"] = cp_push_count, cp_wait_count - ctx = out_buf.to(dtype) - - stage_output = F.linear(ctx.reshape(B, S_local, -1), w_o) - - # 3. Pipeline-Parallel Send - if not is_last and pp_group is not None: - if pp_push_count > 0: - pp_ack_wait_count += 1 - _ext.wait_signal(my_pp_ack_ptr, pp_ack_wait_count, 0) - - pp_push_count += 1 - peer_rank = dist.get_global_rank(pp_group, (pp_rank + 1) % pp_size) - _ext.push_pp_and_signal( - stage_output, pp_ptrs[peer_rank], flags_ptrs[peer_rank] + 4, pp_push_count, 0 - ) - - state["pp_push_count"], state["pp_wait_count"] = pp_push_count, pp_wait_count - state["pp_ack_push_count"], state["pp_ack_wait_count"] = pp_ack_push_count, pp_ack_wait_count - - return stage_output \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/58_ring_attention_backward_dp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/58_ring_attention_backward_dp_cuda.py deleted file mode 100755 index 9c08824..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/58_ring_attention_backward_dp_cuda.py +++ /dev/null @@ -1,377 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional, Tuple -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// P2P Signal and Wait Kernels (Acquire/Release Semantics) -// --------------------------------------------------------------------------- - -__global__ void p2p_signal_kernel(uint32_t* addr) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - uint32_t tmp; - do { - asm volatile("atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 0); - } -} - -__global__ void p2p_wait_kernel(uint32_t* addr) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - uint32_t tmp; - do { - asm volatile("atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 1); - } -} - -void p2p_signal(int64_t addr, int64_t stream_ptr) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - p2p_signal_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(addr)); -} - -void p2p_wait(int64_t addr, int64_t stream_ptr) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - p2p_wait_kernel<<<1, 1, 0, stream>>>(reinterpret_cast(addr)); -} - -void p2p_memcpy_async(int64_t dst_ptr, int64_t src_ptr, int64_t size_bytes, int64_t stream_ptr) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - C10_CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast(dst_ptr), - reinterpret_cast(src_ptr), - size_bytes, cudaMemcpyDefault, stream)); -} - -// --------------------------------------------------------------------------- -// Fused Elementwise Kernel for Attention Backward -// --------------------------------------------------------------------------- - -__global__ void fused_elementwise_bf16_kernel( - __nv_bfloat16* __restrict__ scores, - __nv_bfloat16* __restrict__ dP, - const float* __restrict__ lse, - const float* __restrict__ row_dot, - int BH, int Sq, int Sk, bool causal, float scale -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = BH * Sq * Sk; - if (idx < total) { - int sk_idx = idx % Sk; - int tmp = idx / Sk; - int sq_idx = tmp % Sq; - int bh_idx = tmp / Sq; - - float score = __bfloat162float(scores[idx]) * scale; - if (causal && sq_idx < sk_idx) { - score = -INFINITY; - } - - float cur_lse = lse[bh_idx * Sq + sq_idx]; - float prob = expf(score - cur_lse); - - float dp_val = __bfloat162float(dP[idx]); - float rd_val = row_dot[bh_idx * Sq + sq_idx]; - float ds_val = prob * (dp_val - rd_val); - - scores[idx] = __float2bfloat16(prob); - dP[idx] = __float2bfloat16(ds_val * scale); - } -} - -void fused_elementwise_bf16( - torch::Tensor scores, torch::Tensor dP, - torch::Tensor lse, torch::Tensor row_dot, - int BH, int Sq, int Sk, bool causal, float scale, int64_t stream_ptr -) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - int total = BH * Sq * Sk; - int threads = 256; - int blocks = (total + threads - 1) / threads; - fused_elementwise_bf16_kernel<<>>( - reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(dP.data_ptr()), - lse.data_ptr(), - row_dot.data_ptr(), - BH, Sq, Sk, causal, scale - ); -} - -// --------------------------------------------------------------------------- -// Multimem DP All-Reduce (Hopper) -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void blockwise_barrier_relaxed( - const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint32_t* send_addr = reinterpret_cast(signal_pad_ptrs[flat_tid] + block_id * world_size + rank); - uint32_t* wait_addr = reinterpret_cast(signal_pad_ptrs[rank] + block_id * world_size + flat_tid); - - uint32_t tmp; - do { asm volatile("atom.global.relaxed.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(send_addr) : "memory"); } while (tmp != 0u); - do { asm volatile("atom.global.sys.relaxed.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(wait_addr) : "memory"); } while (tmp != 1u); -} - -__device__ __forceinline__ void blockwise_barrier_acq_rel( - const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint32_t* send_addr = reinterpret_cast(signal_pad_ptrs[flat_tid] + block_id * world_size + rank); - uint32_t* wait_addr = reinterpret_cast(signal_pad_ptrs[rank] + block_id * world_size + flat_tid); - - uint32_t tmp; - do { asm volatile("atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(send_addr) : "memory"); } while (tmp != 0u); - do { asm volatile("atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(wait_addr) : "memory"); } while (tmp != 1u); -} - -__global__ void multimem_allreduce_bf16_kernel( - uint64_t multicast_base, const uint64_t* __restrict__ signal_pad_ptrs, - int64_t numel_128, int world_size, int rank, int block_stride -) { - const uint64_t block_id = static_cast(blockIdx.x); - blockwise_barrier_relaxed(signal_pad_ptrs, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = (numel_128 + world_size - 1) / world_size; - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = block_id * block_stride; block_start < numel_per_rank; block_start += num_programs * block_stride) { - const int64_t offsets = block_start + tid; - if (offsets >= numel_per_rank) continue; - uint64_t* ptrs = reinterpret_cast(multicast_base) + (rank * numel_per_rank + offsets) * 2; - uint32_t r0, r1, r2, r3; - asm volatile("multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "l"(ptrs) : "memory"); - asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" : : "l"(ptrs), "r"(r0), "r"(r1), "r"(r2), "r"(r3) : "memory"); - } - __syncthreads(); - blockwise_barrier_acq_rel(signal_pad_ptrs, block_id, rank, world_size); -} - -__global__ void allreduce_bf16_kernel(const long long* __restrict__ ptrs, __nv_bfloat16* __restrict__ out, int world_size, int64_t n) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - sum += __bfloat162float(((const __nv_bfloat16*)ptrs[r])[idx]); - } - out[idx] = __float2bfloat16(sum); - } -} - -void launch_multimem_allreduce_bf16( - uint64_t multicast_ptr, torch::Tensor signal_pad_ptrs_tensor, int64_t numel_128, - int world_size, int rank, int num_blocks, int block_size, int block_stride, int64_t stream_ptr -) { - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = reinterpret_cast(stream_ptr); - multimem_allreduce_bf16_kernel<<>>( - multicast_ptr, d_signal, numel_128, world_size, rank, block_stride); -} - -void launch_allreduce(torch::Tensor ptrs_tensor, torch::Tensor out, int64_t n, int64_t stream_ptr) { - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - int threads = 512; - int blocks = min((int)((n + threads - 1) / threads), 65535); - cudaStream_t stream = reinterpret_cast(stream_ptr); - allreduce_bf16_kernel<<>>(d_ptrs, reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), world_size, n); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("p2p_signal", &p2p_signal); - m.def("p2p_wait", &p2p_wait); - m.def("p2p_memcpy_async", &p2p_memcpy_async); - m.def("fused_elementwise_bf16", &fused_elementwise_bf16); - m.def("launch_multimem_allreduce_bf16", &launch_multimem_allreduce_bf16); - m.def("launch_allreduce", &launch_allreduce); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("ring_attention_bwd_ext", CUDA_SRC) - return _ext - -_cp_resource_cache = {} -def get_cp_resources(B, S, H, D, dtype, device, cp_group): - key = (B, S, H, D, dtype, device) - if key in _cp_resource_cache: - return _cp_resource_cache[key] - - big_buf = symm_mem.empty((4, 2, B, S, H, D), dtype=dtype, device=device) - hdl = symm_mem.rendezvous(big_buf, group=cp_group) - ready_buf = symm_mem.empty((2,), dtype=torch.int32, device=device).zero_() - ready_hdl = symm_mem.rendezvous(ready_buf, group=cp_group) - done_buf = symm_mem.empty((2,), dtype=torch.int32, device=device).fill_(1) - done_hdl = symm_mem.rendezvous(done_buf, group=cp_group) - comm_stream = torch.cuda.Stream() - - res = (big_buf, hdl, ready_buf, ready_hdl, done_buf, done_hdl, comm_stream) - _cp_resource_cache[key] = res - return res - -_dp_resource_cache = {} -def allreduce_dp(tensor, dp_group): - n = tensor.numel() - key = (n, tensor.dtype, tensor.device) - if key not in _dp_resource_cache: - buf = symm_mem.empty(n, dtype=tensor.dtype, device=tensor.device) - hdl = symm_mem.rendezvous(buf, group=dp_group) - out = torch.empty(n, dtype=tensor.dtype, device=tensor.device) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=tensor.device, dtype=torch.int64) - _dp_resource_cache[key] = (buf, hdl, out, ptrs_tensor) - - buf, hdl, out, ptrs_tensor = _dp_resource_cache[key] - buf.copy_(tensor) - - numel_per_thread = 8 - if n % numel_per_thread != 0: - hdl.barrier(channel=0) - _get_ext().launch_allreduce(ptrs_tensor, out, n, torch.cuda.current_stream().cuda_stream) - return out - - numel_128 = n // numel_per_thread - num_threads = (numel_128 + hdl.world_size - 1) // hdl.world_size - if num_threads < 1024: - block_size = 1 - while block_size < num_threads: block_size *= 2 - num_blocks = 1 - else: - block_size = 1024 - num_blocks = min((num_threads + 1023) // 1024, 4) - - dist.barrier(group=dp_group) - _get_ext().launch_multimem_allreduce_bf16( - int(hdl.multicast_ptr), hdl.signal_pad_ptrs_dev, numel_128, hdl.world_size, hdl.rank, - num_blocks, block_size, block_size, torch.cuda.current_stream().cuda_stream - ) - return buf.clone() - -def compute_local_attn(q, k, v, dout, out, lse, scale, causal, row_dot): - qh = q.transpose(1, 2) - kh = k.transpose(1, 2) - vh = v.transpose(1, 2) - doh = dout.transpose(1, 2) - - scores = torch.matmul(qh, kh.transpose(-2, -1)) - dP = torch.matmul(doh, vh.transpose(-2, -1)) - - B, H, Sq, D = qh.shape - Sk = kh.shape[2] - - _get_ext().fused_elementwise_bf16( - scores, dP, lse, row_dot, B*H, Sq, Sk, causal, scale, torch.cuda.current_stream().cuda_stream - ) - - dQ = torch.matmul(dP, kh) - dK = torch.matmul(dP.transpose(-2, -1), qh) - dV = torch.matmul(scores.transpose(-2, -1), doh) - return dQ.transpose(1, 2).contiguous(), dK.transpose(1, 2).contiguous(), dV.transpose(1, 2).contiguous() - -@torch.no_grad() -def solution( - dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor, softmax_lse: torch.Tensor, softmax_scale: Optional[float] = None, - causal: bool = False, cp_group: Optional[dist.ProcessGroup] = None, dp_group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - cp_group = cp_group or dist.group.WORLD - if softmax_scale is None: - softmax_scale = q.shape[-1] ** -0.5 - - C = dist.get_world_size(cp_group) - row_dot = (dout.float() * out.float()).sum(dim=-1, keepdim=True).transpose(1, 2).contiguous() - lse = softmax_lse.contiguous() - - if C == 1: - dq, dk, dv = compute_local_attn(q, k, v, dout, out, lse, softmax_scale, causal, row_dot) - else: - rank = dist.get_rank(cp_group) - B, S, H, D = q.shape - big_buf, hdl, ready_buf, ready_hdl, done_buf, done_hdl, comm_stream = get_cp_resources(B, S, H, D, q.dtype, q.device, cp_group) - - global_next = dist.get_global_rank(cp_group, (rank + 1) % C) - global_prev = dist.get_global_rank(cp_group, (rank - 1) % C) - - peer_next_base_ptr = int(hdl.buffer_ptrs[global_next]) - peer_next_ready_ptr = int(ready_hdl.buffer_ptrs[global_next]) - peer_prev_done_ptr = int(done_hdl.buffer_ptrs[global_prev]) - size_per_buf = B * S * H * D * q.element_size() - compute_stream = torch.cuda.current_stream() - - dq, dk_curr, dv_curr = None, None, None - - for i in range(C): - buf_idx = i % 2 - next_buf_idx = (i + 1) % 2 - - if i > 0: - _get_ext().p2p_wait(ready_buf.data_ptr() + buf_idx * 4, compute_stream.cuda_stream) - - k_curr = q if i == 0 else big_buf[0, buf_idx] - v_curr = v if i == 0 else big_buf[1, buf_idx] - - if i <= rank or not causal: - block_dq, block_dk, block_dv = compute_local_attn(q, k_curr, v_curr, dout, out, lse, softmax_scale, causal and (i == 0), row_dot) - if i == 0: - dq, dk_curr, dv_curr = block_dq, block_dk, block_dv - else: - dq.add_(block_dq) - dk_curr = block_dk.add_(big_buf[2, buf_idx]) - dv_curr = block_dv.add_(big_buf[3, buf_idx]) - else: - if i > 0: - dk_curr = big_buf[2, buf_idx] - dv_curr = big_buf[3, buf_idx] - - if i > 0: - _get_ext().p2p_signal(peer_prev_done_ptr + buf_idx * 4, compute_stream.cuda_stream) - - with torch.cuda.stream(comm_stream): - comm_stream.wait_stream(compute_stream) - _get_ext().p2p_wait(done_buf.data_ptr() + next_buf_idx * 4, comm_stream.cuda_stream) - - if i + 1 < C: - _get_ext().p2p_memcpy_async(peer_next_base_ptr + (next_buf_idx) * size_per_buf, k_curr.data_ptr(), size_per_buf, comm_stream.cuda_stream) - _get_ext().p2p_memcpy_async(peer_next_base_ptr + (2 + next_buf_idx) * size_per_buf, v_curr.data_ptr(), size_per_buf, comm_stream.cuda_stream) - - _get_ext().p2p_memcpy_async(peer_next_base_ptr + (4 + next_buf_idx) * size_per_buf, dk_curr.data_ptr(), size_per_buf, comm_stream.cuda_stream) - _get_ext().p2p_memcpy_async(peer_next_base_ptr + (6 + next_buf_idx) * size_per_buf, dv_curr.data_ptr(), size_per_buf, comm_stream.cuda_stream) - - _get_ext().p2p_signal(peer_next_ready_ptr + next_buf_idx * 4, comm_stream.cuda_stream) - - compute_stream.wait_stream(comm_stream) - final_buf_idx = C % 2 - _get_ext().p2p_wait(ready_buf.data_ptr() + final_buf_idx * 4, compute_stream.cuda_stream) - dk = big_buf[2, final_buf_idx].clone() - dv = big_buf[3, final_buf_idx].clone() - _get_ext().p2p_signal(peer_prev_done_ptr + final_buf_idx * 4, compute_stream.cuda_stream) - - if dp_group is not None and dist.get_world_size(dp_group) > 1: - dp_size = dist.get_world_size(dp_group) - packed = torch.cat([dq.flatten(), dk.flatten(), dv.flatten()]) - packed = allreduce_dp(packed, dp_group) - packed.div_(dp_size) - - split_size = dq.numel() - dq = packed[:split_size].view(dq.shape) - dk = packed[split_size:2*split_size].view(dk.shape) - dv = packed[2*split_size:].view(dv.shape) - - return dq, dk, dv \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/59_openclip_contrastive_loss_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/59_openclip_contrastive_loss_cuda.py deleted file mode 100755 index 3863670..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/59_openclip_contrastive_loss_cuda.py +++ /dev/null @@ -1,270 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - - -CUDA_SRC = r''' -#include -#include -#include -#include - -// Vectorized fast copy using 128-bit loads for bandwidth optimization over NVLink -__global__ void copy_128bit_kernel( - uint4* __restrict__ dst, - const uint4* __restrict__ src, - int64_t n_128bit -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n_128bit; idx += (int64_t)gridDim.x * blockDim.x) { - dst[idx] = src[idx]; - } -} - -// Fallback generic elementwise copy -template -__global__ void copy_generic_kernel( - scalar_t* __restrict__ dst, - const scalar_t* __restrict__ src, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - dst[idx] = src[idx]; - } -} - -void async_copy( - torch::Tensor dst, - int64_t src_ptr, - int64_t n_elements, - int64_t stream_ptr -) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - int element_size = dst.element_size(); - - // Check 16-byte alignment to enable vectorized 128-bit transfer - if ((n_elements * element_size) % 16 == 0) { - int64_t n_128 = (n_elements * element_size) / 16; - int threads = 256; - int blocks = std::min(65535, (n_128 + threads - 1) / threads); - copy_128bit_kernel<<>>( - reinterpret_cast(dst.data_ptr()), - reinterpret_cast(src_ptr), - n_128 - ); - } else { - int threads = 256; - int blocks = std::min(65535, (n_elements + threads - 1) / threads); - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dst.scalar_type(), "async_copy_generic", [&] { - copy_generic_kernel<<>>( - dst.data_ptr(), - reinterpret_cast(src_ptr), - n_elements - ); - }); - } -} - -// Fused kernel calculating the stable softplus and reducing to a scalar loss. -template -__global__ void siglip_loss_forward_kernel( - const scalar_t* __restrict__ logits, - float scale, - float bias, - float* __restrict__ loss_out, - int batch_size, - bool is_local -) { - int64_t total_elements = (int64_t)batch_size * batch_size; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - - float local_sum = 0.0f; - for (int64_t i = idx; i < total_elements; i += (int64_t)gridDim.x * blockDim.x) { - int r = i / batch_size; - int c = i % batch_size; - - float x = static_cast(logits[i]); - x = x * scale + bias; - - float label = -1.0f; - if (is_local && r == c) { - label = 1.0f; // Match on local batch diagonal - } - - float z = label * x; - // stable softplus to represent -logsigmoid - float neg_z = -z; - float max_val = neg_z > 0.0f ? neg_z : 0.0f; - float term = expf(-fabsf(neg_z)); - float loss_val = max_val + log1pf(term); - - local_sum += loss_val; - } - - // Warp-level reduction - unsigned int mask = 0xffffffff; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - local_sum += __shfl_down_sync(mask, local_sum, offset); - } - - // Block-level reduction through shared memory - __shared__ float shared_sum[32]; - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - if (lane_id == 0) { - shared_sum[warp_id] = local_sum; - } - __syncthreads(); - - if (warp_id == 0) { - float val = (lane_id < (blockDim.x / 32)) ? shared_sum[lane_id] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(mask, val, offset); - } - if (lane_id == 0) { - atomicAdd(loss_out, val / static_cast(batch_size)); - } - } -} - -void siglip_loss_forward( - torch::Tensor logits, - float scale, - float bias, - torch::Tensor loss_out, - int batch_size, - bool is_local -) { - int threads = 256; - int64_t total = (int64_t)batch_size * batch_size; - int blocks = std::min(1024, (total + threads - 1) / threads); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "siglip_loss_forward", [&] { - siglip_loss_forward_kernel<<>>( - logits.data_ptr(), - scale, - bias, - loss_out.data_ptr(), - batch_size, - is_local - ); - }); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("async_copy", &async_copy, "Async vector copy from device memory over UVA"); - m.def("siglip_loss_forward", &siglip_loss_forward, "Fused SigLIP loss kernel"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("siglip_overlap_ext", CUDA_SRC) - return _ext - -def _init_ext(group: Optional[dist.ProcessGroup] = None): - global _ext - if _ext is None: - if dist.is_initialized(): - if dist.get_rank(group) == 0: - _get_ext() - dist.barrier(group) - _get_ext() - -_symm_cache = {} -def _get_symm_state(shape, dtype, device, group): - key = (shape, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(shape, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group) - _symm_cache[key] = (buf, hdl) - return buf, hdl - - -@torch.no_grad() -def solution( - image_features: torch.Tensor, - text_features: torch.Tensor, - logit_scale: float, - logit_bias: float = 0.0, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - group = group or (dist.group.WORLD if dist.is_initialized() else None) - _init_ext(group) - - world_size = dist.get_world_size(group) if dist.is_initialized() else 1 - rank = dist.get_rank(group) if dist.is_initialized() else 0 - - batch_size = image_features.size(0) - n_elements = text_features.numel() - logit_scale_f = float(logit_scale) - logit_bias_f = float(logit_bias) - - loss_scalar = torch.zeros(1, dtype=torch.float32, device=image_features.device) - - # Fast path for single GPU deployments without ring/overlap - if world_size == 1: - logits0 = torch.matmul(image_features, text_features.T) - _get_ext().siglip_loss_forward(logits0.contiguous(), logit_scale_f, logit_bias_f, loss_scalar, batch_size, True) - return loss_scalar.to(dtype=image_features.dtype).squeeze() - - buf, hdl = _get_symm_state(text_features.shape, text_features.dtype, text_features.device, group) - - # Broadcast current rank's text chunk to symm memory segment - buf.copy_(text_features) - hdl.barrier(channel=0) - - # Establish dual buffers and streams for NVLink overlap prefetching pipeline - bufA = torch.empty_like(text_features) - bufB = torch.empty_like(text_features) - stream_copy = torch.cuda.Stream() - - event_copy_done = [torch.cuda.Event(enable_timing=False) for _ in range(world_size)] - event_compute_done = [torch.cuda.Event(enable_timing=False) for _ in range(world_size)] - - # Step 0: Pre-fetch for step 1 while computing local - p1 = (rank + 1) % world_size - _get_ext().async_copy(bufA, int(hdl.buffer_ptrs[p1]), n_elements, stream_copy.cuda_stream) - event_copy_done[1].record(stream_copy) - - # Step 0: Calculate initial local matching block - logits0 = torch.matmul(image_features, buf.T) - _get_ext().siglip_loss_forward(logits0.contiguous(), logit_scale_f, logit_bias_f, loss_scalar, batch_size, True) - event_compute_done[0].record() - - # Step 1 ... World_size-1: Progress down ring sequence iteratively - for s in range(1, world_size): - curr_buf = bufA if (s % 2) != 0 else bufB - next_buf = bufB if (s % 2) != 0 else bufA - - # Pre-fetch the next peer sequentially into the free secondary buffer - if s < world_size - 1: - next_p = (rank + s + 1) % world_size - # Keep copy stream waiting until primary stream safely consumes this secondary buffer (from compute step s-1) - stream_copy.wait_event(event_compute_done[s-1]) - _get_ext().async_copy(next_buf, int(hdl.buffer_ptrs[next_p]), n_elements, stream_copy.cuda_stream) - event_copy_done[s+1].record(stream_copy) - - # Guarantee prefetching to curr_buf has thoroughly settled before matrix multiplication - torch.cuda.current_stream().wait_event(event_copy_done[s]) - - logits = torch.matmul(image_features, curr_buf.T) - _get_ext().siglip_loss_forward(logits.contiguous(), logit_scale_f, logit_bias_f, loss_scalar, batch_size, False) - event_compute_done[s].record() - - # Hardware block until all parallel reading passes and ensures symm cache coherence - hdl.barrier(channel=0) - - return loss_scalar.to(dtype=image_features.dtype).squeeze() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/5_scatter_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/5_scatter_cuda.py deleted file mode 100755 index f3e2500..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/5_scatter_cuda.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Strategy: -- **Device-Side Push:** Avoids repeated NCCL/host collective calls by allocating symmetric memory buffers (`symm_mem`) on all ranks, enabling direct peer-to-peer memory mappings via UVA. The source rank executes a single custom CUDA kernel that pushes data directly to all peers' symmetric buffers simultaneously. -- **Compute–Communication Overlap:** By launching a grid-stride kernel spanning multiple SMs (one grid Y-dimension per peer), the data push maximally exploits the bidirectional NVLink bandwidth. The copy dynamically falls back from 128-bit vectorization to granular access, ensuring high throughput without blocking host CPU, masked behind stream-ordered barriers (`hdl.barrier()`). -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void scatter_push_kernel( - const char* __restrict__ input, - const uint64_t* __restrict__ peer_ptrs, - int64_t chunk_bytes, - int world_size, - int src_rank -) { - int rank_idx = blockIdx.y; - if (rank_idx >= world_size || rank_idx == src_rank) return; - - char* out = reinterpret_cast(peer_ptrs[rank_idx]); - const char* in = input + rank_idx * chunk_bytes; - - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - // Calculate maximum alignment for vectorization - int align = 1; - if (((reinterpret_cast(in) % 16) == 0) && ((reinterpret_cast(out) % 16) == 0)) align = 16; - else if (((reinterpret_cast(in) % 8) == 0) && ((reinterpret_cast(out) % 8) == 0)) align = 8; - else if (((reinterpret_cast(in) % 4) == 0) && ((reinterpret_cast(out) % 4) == 0)) align = 4; - else if (((reinterpret_cast(in) % 2) == 0) && ((reinterpret_cast(out) % 2) == 0)) align = 2; - - if (align == 16) { - int64_t num_vec = chunk_bytes / 16; - const uint4* in_vec = reinterpret_cast(in); - uint4* out_vec = reinterpret_cast(out); - for (int64_t i = tid; i < num_vec; i += stride) { - out_vec[i] = in_vec[i]; - } - if (tid == 0) { - for (int64_t i = num_vec * 16; i < chunk_bytes; i++) out[i] = in[i]; - } - } else if (align == 8) { - int64_t num_vec = chunk_bytes / 8; - const uint2* in_vec = reinterpret_cast(in); - uint2* out_vec = reinterpret_cast(out); - for (int64_t i = tid; i < num_vec; i += stride) { - out_vec[i] = in_vec[i]; - } - if (tid == 0) { - for (int64_t i = num_vec * 8; i < chunk_bytes; i++) out[i] = in[i]; - } - } else if (align == 4) { - int64_t num_vec = chunk_bytes / 4; - const uint32_t* in_vec = reinterpret_cast(in); - uint32_t* out_vec = reinterpret_cast(out); - for (int64_t i = tid; i < num_vec; i += stride) { - out_vec[i] = in_vec[i]; - } - if (tid == 0) { - for (int64_t i = num_vec * 4; i < chunk_bytes; i++) out[i] = in[i]; - } - } else if (align == 2) { - int64_t num_vec = chunk_bytes / 2; - const uint16_t* in_vec = reinterpret_cast(in); - uint16_t* out_vec = reinterpret_cast(out); - for (int64_t i = tid; i < num_vec; i += stride) { - out_vec[i] = in_vec[i]; - } - if (tid == 0) { - for (int64_t i = num_vec * 2; i < chunk_bytes; i++) out[i] = in[i]; - } - } else { - for (int64_t i = tid; i < chunk_bytes; i += stride) { - out[i] = in[i]; - } - } -} - -void launch_scatter_push( - torch::Tensor input, - torch::Tensor peer_ptrs_tensor, - int64_t chunk_bytes, - int world_size, - int src_rank -) { - const char* d_input = reinterpret_cast(input.data_ptr()); - const uint64_t* d_peer_ptrs = reinterpret_cast(peer_ptrs_tensor.data_ptr()); - - int threads = 256; - int64_t max_vec = (chunk_bytes + 15) / 16; - int64_t blocks_per_rank = (max_vec + threads - 1) / threads; - if (blocks_per_rank > 256) blocks_per_rank = 256; - if (blocks_per_rank < 1) blocks_per_rank = 1; - - dim3 blocks(static_cast(blocks_per_rank), static_cast(world_size)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - scatter_push_kernel<<>>( - d_input, d_peer_ptrs, chunk_bytes, world_size, src_rank - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_scatter_push", &launch_scatter_push, "Scatter push kernel via symmetric memory UVA"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - if dist.is_initialized(): - if dist.get_rank() == 0: - _ext = compile_cuda_extension("scatter_cuda_ext", CUDA_SRC) - dist.barrier() - _ext = compile_cuda_extension("scatter_cuda_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(chunk_shape, dtype, device): - key = (tuple(chunk_shape), dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(chunk_shape, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _symm_cache[key] = (buf, hdl, ptrs) - return buf, hdl, ptrs - -_channel = 0 -def _next_channel(): - global _channel - c = _channel - _channel = (_channel + 1) % 256 - return c - - -@torch.no_grad() -def solution( - tensor: torch.Tensor, - src: int = 0, -) -> torch.Tensor: - if not dist.is_initialized(): - return tensor.clone() - - world_size = dist.get_world_size() - rank = dist.get_rank() - - # Handle potentially empty tensors directly - if tensor.numel() == 0: - if rank == src: - return tensor[src].clone() - else: - return tensor.clone() - - tensor = tensor.contiguous() - - if rank == src: - assert tensor.shape[0] == world_size, f"Source tensor must have {world_size} chunks" - chunk_shape = tensor.shape[1:] - else: - chunk_shape = tensor.shape - - _get_ext() # Ensure extension is loaded - - buf, hdl, ptrs = _get_symm_state(chunk_shape, tensor.dtype, tensor.device) - chunk_bytes = buf.numel() * buf.element_size() - - # Synchronization 1: Ensure peers have finished consuming the symmetric buffer - # from any previous operations before source starts overwriting it. - c1 = _next_channel() - hdl.barrier(channel=c1) - - if rank == src: - _get_ext().launch_scatter_push(tensor, ptrs, chunk_bytes, world_size, src) - - # Synchronization 2: Block peers from reading until source has completed pushing - # data to their symmetric buffer partitions. - c2 = _next_channel() - hdl.barrier(channel=c2) - - if rank == src: - out = tensor[src].clone() - else: - out = buf.clone() - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/60_physicsnemo_distributed_rfft_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/60_physicsnemo_distributed_rfft_cuda.py deleted file mode 100755 index 40f22f2..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/60_physicsnemo_distributed_rfft_cuda.py +++ /dev/null @@ -1,206 +0,0 @@ -import math -from typing import Optional, Sequence - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// A custom push kernel that directly scatters local FFT results into -// the remote, concatenated target tensors via NVLink symmetric memory. -template -__global__ void push_kernel( - const T* __restrict__ x1, - const uint64_t* __restrict__ remote_ptrs, - uint32_t N0, uint32_t N1, uint32_t N2, uint32_t N3, uint32_t N4, - uint32_t chunk0, uint32_t D1, uint32_t D1_mul_W, - uint32_t my_rank, - bool dim0_lt_dim1, - uint64_t total_elements -) { - for (uint64_t flat_idx = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; - flat_idx < total_elements; - flat_idx += (uint64_t)gridDim.x * blockDim.x) { - - // N-Dimensional Indexing mapped over a flattened generic 5D abstraction - uint64_t temp = flat_idx; - uint32_t i4 = temp % N4; temp /= N4; - uint32_t i3 = temp % N3; temp /= N3; - uint32_t i2 = temp % N2; temp /= N2; - uint32_t i1 = temp % N1; temp /= N1; - uint32_t i0 = temp; - - uint32_t r; - uint64_t out_flat; - - if (dim0_lt_dim1) { - r = i1 / chunk0; - uint32_t out_i1 = i1 % chunk0; - uint32_t out_i3 = my_rank * D1 + i3; - // Native concatenated flat offset calculation - out_flat = ((( (uint64_t)i0 * chunk0 + out_i1 ) * N2 + i2 ) * D1_mul_W + out_i3 ) * N4 + i4; - } else { - r = i3 / chunk0; - uint32_t out_i3 = i3 % chunk0; - uint32_t out_i1 = my_rank * D1 + i1; - // Native concatenated flat offset calculation - out_flat = ((( (uint64_t)i0 * D1_mul_W + out_i1 ) * N2 + i2 ) * chunk0 + out_i3 ) * N4 + i4; - } - - // Direct device-side scatter across NVLink into remote peer memory - T* dest = (T*)remote_ptrs[r]; - dest[out_flat] = x1[flat_idx]; - } -} - -void launch_push_kernel( - torch::Tensor x1, - torch::Tensor remote_ptrs, - int64_t N0, int64_t N1, int64_t N2, int64_t N3, int64_t N4, - int64_t chunk0, int64_t D1, int64_t D1_mul_W, - int my_rank, - bool dim0_lt_dim1 -) { - uint64_t total_elements = N0 * N1 * N2 * N3 * N4; - if (total_elements == 0) return; - - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - if (blocks > 1024 * 64) blocks = 1024 * 64; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const uint64_t* ptrs = (const uint64_t*)remote_ptrs.data_ptr(); - - // Copy vectorization mappings depending on standard complex shapes - if (x1.element_size() == 8) { - push_kernel<<>>( - (const int64_t*)x1.data_ptr(), ptrs, - N0, N1, N2, N3, N4, chunk0, D1, D1_mul_W, - my_rank, dim0_lt_dim1, total_elements - ); - } else if (x1.element_size() == 16) { - push_kernel<<>>( - (const int4*)x1.data_ptr(), ptrs, - N0, N1, N2, N3, N4, chunk0, D1, D1_mul_W, - my_rank, dim0_lt_dim1, total_elements - ); - } else { - TORCH_CHECK(false, "Unsupported element size for direct push"); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push_kernel", &launch_push_kernel, "Custom symmetric push kernel"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("physicsnemo_fft_push_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(shape, dtype, device, group): - key = (tuple(shape), dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - _symm_cache[key] = (buf, hdl, ptrs_tensor) - return buf, hdl, ptrs_tensor - -def _truncate(tensor: torch.Tensor, dim: int, size: int) -> torch.Tensor: - """Return a contiguous slice tensor[..., :size, ...] along dim.""" - slices = [slice(None)] * tensor.ndim - slices[dim % tensor.ndim] = slice(0, size) - return tensor[tuple(slices)].contiguous() - -@torch.no_grad() -def solution( - x: torch.Tensor, - s: Sequence[int], - dim: Sequence[int], - norm: str = "ortho", - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - my_rank = dist.get_rank(group) - - dim0, dim1 = int(dim[0]), int(dim[1]) - ndim = x.ndim - - # Handle negative dimension indexing - dim0 = dim0 if dim0 >= 0 else dim0 + ndim - dim1 = dim1 if dim1 >= 0 else dim1 + ndim - - # 1. Transform the replicated spatial dimension. - # Output of fft for float32/bfloat16 returns complex64 (8 bytes per element) - x1 = torch.fft.fft(x, n=int(s[0]), dim=dim0, norm=norm).contiguous() - - if world_size == 1: - x1_tran = x1 - else: - # 2. All-to-all transpose fused into Custom Symmetric Memory Push - shape_in = list(x1.shape) - D0 = shape_in[dim0] - D1 = shape_in[dim1] - chunk0 = D0 // world_size - - shape_tran = list(shape_in) - shape_tran[dim0] = chunk0 - shape_tran[dim1] = D1 * world_size - - buf, hdl, ptrs_tensor = _get_symm_state(shape_tran, x1.dtype, x1.device, group) - - # Convert varying arbitrary N-dimensional space to strict 5D configuration abstraction - d_min = min(dim0, dim1) - d_max = max(dim0, dim1) - - N0 = math.prod(shape_in[:d_min]) - N1 = shape_in[d_min] - N2 = math.prod(shape_in[d_min+1:d_max]) - N3 = shape_in[d_max] - N4 = math.prod(shape_in[d_max+1:]) - - dim0_lt_dim1 = (dim0 < dim1) - - # Build extension serially from Rank 0 locally to prevent compilation race conditions - if my_rank == 0: - _get_ext() - dist.barrier(group=group) - - # Ensure local symmetric buffer is clear and ready to receive writes - hdl.barrier(channel=0) - - _get_ext().launch_push_kernel( - x1, ptrs_tensor, - N0, N1, N2, N3, N4, - chunk0, D1, D1 * world_size, - my_rank, dim0_lt_dim1 - ) - - # Stream sync followed by a dist barrier guarantees all peers have finished writing safely - torch.cuda.current_stream().synchronize() - dist.barrier(group=group) - - x1_tran = buf - - # 3. Transform the now-replicated second dimension natively - x2 = torch.fft.fft(x1_tran, n=int(s[1]), dim=dim1, norm=norm) - - # 4. Truncate returning real-input half spectrum shape mapping - return _truncate(x2, dim1, x2.shape[dim1] // 2 + 1) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/61_physicsnemo_distributed_irfft_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/61_physicsnemo_distributed_irfft_cuda.py deleted file mode 100755 index e4f4a6f..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/61_physicsnemo_distributed_irfft_cuda.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import Optional, Sequence - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void cast_complex_to_bf162_kernel( - const float2* __restrict__ in, - __nv_bfloat162* __restrict__ out, - int64_t numel -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel) { - out[idx] = __float22bfloat162_rn(in[idx]); - } -} - -__global__ void conj_pad_2d_bf16_kernel( - const uint64_t* __restrict__ symm_ptrs, - float2* __restrict__ out, - int B, int N0_local, int N1_half, int N1, - int rank, int world_size -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t total = (int64_t)B * N0_local * N1; - if (idx >= total) return; - - int i1 = idx % N1; - int tmp = idx / N1; - int i0_local = tmp % N0_local; - int b = tmp / N0_local; - - int N0 = N0_local * world_size; - - if (i1 < N1_half) { - const __nv_bfloat162* local_x = reinterpret_cast(symm_ptrs[rank]); - __nv_bfloat162 val = local_x[(b * N0_local + i0_local) * N1_half + i1]; - out[idx] = __bfloat1622float2(val); - } else { - int i0 = rank * N0_local + i0_local; - int j0 = (i0 == 0) ? 0 : N0 - i0; - int j0_rank = j0 / N0_local; - int j0_local = j0 % N0_local; - int orig_i1 = N1 - i1; - - const __nv_bfloat162* remote_x = reinterpret_cast(symm_ptrs[j0_rank]); - __nv_bfloat162 val = remote_x[(b * N0_local + j0_local) * N1_half + orig_i1]; - - float2 fval = __bfloat1622float2(val); - fval.y = -fval.y; // Complex conjugate - out[idx] = fval; - } -} - -__global__ void transpose_bf16_kernel( - const uint64_t* __restrict__ symm_ptrs, - float2* __restrict__ out, - int B, int N0_local, int N1, int N1_local, - int rank, int world_size -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int N0 = N0_local * world_size; - int64_t total = (int64_t)B * N0 * N1_local; - if (idx >= total) return; - - int i1_local = idx % N1_local; - int tmp = idx / N1_local; - int i0 = tmp % N0; - int b = tmp / N0; - - int j0_rank = i0 / N0_local; - int i0_local = i0 % N0_local; - int i1 = rank * N1_local + i1_local; - - const __nv_bfloat162* remote_x1 = reinterpret_cast(symm_ptrs[j0_rank]); - __nv_bfloat162 val = remote_x1[(b * N0_local + i0_local) * N1 + i1]; - out[idx] = __bfloat1622float2(val); -} - -void launch_cast_complex_to_bf162(torch::Tensor in, torch::Tensor out) { - int64_t numel = in.numel(); - int threads = 256; - int blocks = (numel + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cast_complex_to_bf162_kernel<<>>( - reinterpret_cast(in.data_ptr>()), - reinterpret_cast<__nv_bfloat162*>(out.data_ptr()), - numel - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -void launch_conj_pad_2d_bf16( - torch::Tensor symm_ptrs_tensor, - torch::Tensor out, - int B, int N0_local, int N1_half, int N1, - int rank, int world_size -) { - int64_t total = (int64_t)B * N0_local * N1; - int threads = 256; - int blocks = (total + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - conj_pad_2d_bf16_kernel<<>>( - reinterpret_cast(symm_ptrs_tensor.data_ptr()), - reinterpret_cast(out.data_ptr>()), - B, N0_local, N1_half, N1, rank, world_size - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -void launch_transpose_bf16( - torch::Tensor symm_ptrs_tensor, - torch::Tensor out, - int B, int N0_local, int N1, int N1_local, - int rank, int world_size -) { - int N0 = N0_local * world_size; - int64_t total = (int64_t)B * N0 * N1_local; - int threads = 256; - int blocks = (total + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - transpose_bf16_kernel<<>>( - reinterpret_cast(symm_ptrs_tensor.data_ptr()), - reinterpret_cast(out.data_ptr>()), - B, N0_local, N1, N1_local, rank, world_size - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_cast_complex_to_bf162", &launch_cast_complex_to_bf162, "Cast complex64 to bfloat162 pairs"); - m.def("launch_conj_pad_2d_bf16", &launch_conj_pad_2d_bf16, "UVA conjugate pad 2d fetching bf16 pairs"); - m.def("launch_transpose_bf16", &launch_transpose_bf16, "UVA all-to-all transpose reading bf16 pairs"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("physicsnemo_irfft_bf16_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(key: str, shape: tuple, dtype: torch.dtype, device: torch.device, group: dist.ProcessGroup): - global _symm_cache - if key in _symm_cache: - buf, hdl, ptrs = _symm_cache[key] - if buf.shape == shape and buf.dtype == dtype: - return buf, hdl, ptrs - - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - _symm_cache[key] = (buf, hdl, ptrs) - return buf, hdl, ptrs - -@torch.no_grad() -def solution( - x: torch.Tensor, - s: Optional[Sequence[int]], - dim: Sequence[int], - norm: str = "ortho", - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - if x.dtype != torch.complex64: - x = x.to(torch.complex64) - - dim0, dim1 = int(dim[0]) % x.ndim, int(dim[1]) % x.ndim - - if s is not None: - first_dim_size = int(s[0]) - last_dim_size = int(s[1]) - else: - first_dim_size = int(x.shape[dim0]) - last_dim_size = int(2 * (x.shape[dim1] - 1)) - - # Permute spatial dimensions to the innermost boundary for contiguous 3D C++ processing - perms = [i for i in range(x.ndim) if i not in (dim0, dim1)] + [dim0, dim1] - x_perm = x.permute(perms).contiguous() - - N0_local = x_perm.shape[-2] - N1_half = x_perm.shape[-1] - N0 = N0_local * world_size - N1 = last_dim_size - N1_local = N1 // world_size - B = x_perm.numel() // (N0_local * N1_half) - - ext = _get_ext() - - # Step 1: Push input shard as bfloat16 to Symmetric Memory - x_symm_shape = (B, N0_local, N1_half, 2) - x_symm, x_hdl, x_ptrs = _get_symm_state("x", x_symm_shape, torch.bfloat16, x.device, group) - ext.launch_cast_complex_to_bf162(x_perm, x_symm) - x_hdl.barrier(channel=0) - - # Step 2: Hermitian Symmetry Rebuilding (Direct remote fetch, fully bypasses gather & flips) - x_pad_perm = torch.empty((B, N0_local, N1), dtype=torch.complex64, device=x.device) - ext.launch_conj_pad_2d_bf16(x_ptrs, x_pad_perm, B, N0_local, N1_half, N1, rank, world_size) - - # Step 3: Complex-to-complex IFFT on dim1 - x1_perm = torch.fft.ifft(x_pad_perm, n=N1, dim=-1, norm=norm) - - # Step 4: Push transformed chunks as bfloat16 back to Symmetric Memory - x1_symm_shape = (B, N0_local, N1, 2) - x1_symm, x1_hdl, x1_ptrs = _get_symm_state("x1", x1_symm_shape, torch.bfloat16, x.device, group) - ext.launch_cast_complex_to_bf162(x1_perm, x1_symm) - x1_hdl.barrier(channel=0) - - # Step 5: All-to-all spatial transpose via UVA reads - x1_tran_perm = torch.empty((B, N0, N1_local), dtype=torch.complex64, device=x.device) - ext.launch_transpose_bf16(x1_ptrs, x1_tran_perm, B, N0_local, N1, N1_local, rank, world_size) - - # Step 6: Complex-to-complex IFFT on dim0 and extraction of reals - x2_perm = torch.fft.ifft(x1_tran_perm, n=first_dim_size, dim=-2, norm=norm) - out_perm = torch.real(x2_perm).contiguous() - - # Step 7: Inverse permutation to restore original spatial dimensions positions - inv_perms = [0] * x.ndim - for i, p in enumerate(perms): - inv_perms[p] = i - - out_shape_permuted = [x.shape[i] for i in perms[:-2]] + [first_dim_size, N1_local] - out_perm = out_perm.view(out_shape_permuted) - - return out_perm.permute(inv_perms).contiguous() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/62_gsplat_3d_gaussian_splatting_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/62_gsplat_3d_gaussian_splatting_cuda.py deleted file mode 100755 index 5e4b9a5..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/62_gsplat_3d_gaussian_splatting_cuda.py +++ /dev/null @@ -1,569 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Tuple -from torch import Tensor - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -__device__ __forceinline__ void quat_scale_to_covar(const float* quat, const float* scale, float covar[3][3]) { - float w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - float norm = sqrtf(w*w + x*x + y*y + z*z); - w /= norm; x /= norm; y /= norm; z /= norm; - - float R[3][3]; - R[0][0] = 1.0f - 2.0f*(y*y + z*z); R[0][1] = 2.0f*(x*y - w*z); R[0][2] = 2.0f*(x*z + w*y); - R[1][0] = 2.0f*(x*y + w*z); R[1][1] = 1.0f - 2.0f*(x*x + z*z); R[1][2] = 2.0f*(y*z - w*x); - R[2][0] = 2.0f*(x*z - w*y); R[2][1] = 2.0f*(y*z + w*x); R[2][2] = 1.0f - 2.0f*(x*x + y*y); - - float M[3][3]; - for (int r = 0; r < 3; ++r) { - for (int c = 0; c < 3; ++c) { - M[r][c] = R[r][c] * scale[c]; - } - } - - for (int r = 0; r < 3; ++r) { - for (int c = 0; c < 3; ++c) { - float sum = 0.0f; - for (int k = 0; k < 3; ++k) { - sum += M[r][k] * M[c][k]; - } - covar[r][c] = sum; - } - } -} - -__global__ void compute_valid_kernel( - const float* __restrict__ means, const float* __restrict__ quats, const float* __restrict__ scales, - const float* __restrict__ viewmats, const float* __restrict__ Ks, - int width, int height, - int N_local, int C_total, - float eps2d, float near_plane, float far_plane, - int* __restrict__ valid -) { - int n = blockIdx.x * blockDim.x + threadIdx.x; - if (n >= N_local) return; - - float mean[3] = {means[n*3], means[n*3+1], means[n*3+2]}; - float quat[4] = {quats[n*4], quats[n*4+1], quats[n*4+2], quats[n*4+3]}; - float scale[3] = {scales[n*3], scales[n*3+1], scales[n*3+2]}; - - float covar[3][3]; - quat_scale_to_covar(quat, scale, covar); - - for (int c = 0; c < C_total; ++c) { - float R[3][3], t[3], K[3][3]; - for(int i=0; i<3; ++i) { - for(int j=0; j<3; ++j) { - R[i][j] = viewmats[c * 16 + i * 4 + j]; - K[i][j] = Ks[c * 9 + i * 3 + j]; - } - t[i] = viewmats[c * 16 + i * 4 + 3]; - } - - float depth = R[2][0]*mean[0] + R[2][1]*mean[1] + R[2][2]*mean[2] + t[2]; - bool is_valid = true; - - if (depth <= near_plane || depth >= far_plane) { - is_valid = false; - } else { - float mean_c[3]; - for(int i=0; i<3; ++i) { - mean_c[i] = R[i][0]*mean[0] + R[i][1]*mean[1] + R[i][2]*mean[2] + t[i]; - } - - float cov_c[3][3]; - float tmp[3][3]; - for(int i=0; i<3; ++i) { - for(int j=0; j<3; ++j) { - tmp[i][j] = R[i][0]*covar[0][j] + R[i][1]*covar[1][j] + R[i][2]*covar[2][j]; - } - } - for(int i=0; i<3; ++i) { - for(int j=0; j<3; ++j) { - cov_c[i][j] = tmp[i][0]*R[j][0] + tmp[i][1]*R[j][1] + tmp[i][2]*R[j][2]; - } - } - - float tx = mean_c[0], ty = mean_c[1], tz = depth; - float tz2 = tz * tz; - - float fx = K[0][0]; float fy = K[1][1]; - float cx = K[0][2]; float cy = K[1][2]; - - float tan_fovx = 0.5f * width / fx; - float tan_fovy = 0.5f * height / fy; - float lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; - float lim_x_neg = cx / fx + 0.3f * tan_fovx; - float lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; - float lim_y_neg = cy / fy + 0.3f * tan_fovy; - - float clamp_x = tx / tz; - clamp_x = fmaxf(-lim_x_neg, fminf(clamp_x, lim_x_pos)); - tx = tz * clamp_x; - - float clamp_y = ty / tz; - clamp_y = fmaxf(-lim_y_neg, fminf(clamp_y, lim_y_pos)); - ty = tz * clamp_y; - - float J[2][3] = { - {fx / tz, 0.0f, -fx * tx / tz2}, - {0.0f, fy / tz, -fy * ty / tz2} - }; - - float cov2d[2][2]; - float tmp2[2][3]; - for(int i=0; i<2; ++i) { - for(int j=0; j<3; ++j) { - tmp2[i][j] = J[i][0]*cov_c[0][j] + J[i][1]*cov_c[1][j] + J[i][2]*cov_c[2][j]; - } - } - for(int i=0; i<2; ++i) { - for(int j=0; j<2; ++j) { - cov2d[i][j] = tmp2[i][0]*J[j][0] + tmp2[i][1]*J[j][1] + tmp2[i][2]*J[j][2]; - } - } - - float mean2d[2]; - mean2d[0] = (K[0][0]*mean_c[0] + K[0][1]*mean_c[1] + K[0][2]*mean_c[2]) / tz; - mean2d[1] = (K[1][0]*mean_c[0] + K[1][1]*mean_c[1] + K[1][2]*mean_c[2]) / tz; - - cov2d[0][0] += eps2d; - cov2d[1][1] += eps2d; - - int radii[2]; - radii[0] = (int)ceilf(3.33f * sqrtf(cov2d[0][0])); - radii[1] = (int)ceilf(3.33f * sqrtf(cov2d[1][1])); - - if (mean2d[0] + radii[0] <= 0 || mean2d[0] - radii[0] >= width || - mean2d[1] + radii[1] <= 0 || mean2d[1] - radii[1] >= height || - radii[0] <= 0 || radii[1] <= 0) { - is_valid = false; - } - } - valid[c * N_local + n] = is_valid ? 1 : 0; - } -} - -template -__global__ void project_and_push_kernel( - const float* __restrict__ means, const float* __restrict__ quats, const float* __restrict__ scales, - const float* __restrict__ opacities, const ColorT* __restrict__ colors, - const float* __restrict__ viewmats, const float* __restrict__ Ks, - const int* __restrict__ cam_dst_ranks, const int* __restrict__ cam_local_ids, - int width, int height, - const int* __restrict__ scan, const int* __restrict__ c_start, - int N_local, int C_total, int D_colors, - float eps2d, float near_plane, float far_plane, - int my_rank, int global_gaussian_offset, - const int* __restrict__ peer_recv_offsets, - const int64_t* __restrict__ peer_camera_ids_ptrs, - const int64_t* __restrict__ peer_gaussian_ids_ptrs, - const int64_t* __restrict__ peer_radii_ptrs, - const int64_t* __restrict__ peer_means2d_ptrs, - const int64_t* __restrict__ peer_depths_ptrs, - const int64_t* __restrict__ peer_conics_ptrs, - const int64_t* __restrict__ peer_opac_ptrs, - const int64_t* __restrict__ peer_colors_ptrs, - int world_size -) { - int n = blockIdx.x * blockDim.x + threadIdx.x; - if (n >= N_local) return; - - float mean[3] = {means[n*3], means[n*3+1], means[n*3+2]}; - float quat[4] = {quats[n*4], quats[n*4+1], quats[n*4+2], quats[n*4+3]}; - float scale[3] = {scales[n*3], scales[n*3+1], scales[n*3+2]}; - float opac = opacities[n]; - - float covar[3][3]; - quat_scale_to_covar(quat, scale, covar); - - for (int c = 0; c < C_total; ++c) { - int idx_1d = c * N_local + n; - int prev = (idx_1d == 0) ? 0 : scan[idx_1d - 1]; - if (scan[idx_1d] > prev) { - int local_idx = scan[idx_1d] - 1; - - float R[3][3], t[3], K[3][3]; - for(int i=0; i<3; ++i) { - for(int j=0; j<3; ++j) { - R[i][j] = viewmats[c * 16 + i * 4 + j]; - K[i][j] = Ks[c * 9 + i * 3 + j]; - } - t[i] = viewmats[c * 16 + i * 4 + 3]; - } - float fx = K[0][0]; float fy = K[1][1]; - float cx = K[0][2]; float cy = K[1][2]; - - float depth = R[2][0]*mean[0] + R[2][1]*mean[1] + R[2][2]*mean[2] + t[2]; - - float mean_c[3]; - for(int i=0; i<3; ++i) { - mean_c[i] = R[i][0]*mean[0] + R[i][1]*mean[1] + R[i][2]*mean[2] + t[i]; - } - - float cov_c[3][3]; - float tmp[3][3]; - for(int i=0; i<3; ++i) { - for(int j=0; j<3; ++j) { - tmp[i][j] = R[i][0]*covar[0][j] + R[i][1]*covar[1][j] + R[i][2]*covar[2][j]; - } - } - for(int i=0; i<3; ++i) { - for(int j=0; j<3; ++j) { - cov_c[i][j] = tmp[i][0]*R[j][0] + tmp[i][1]*R[j][1] + tmp[i][2]*R[j][2]; - } - } - - float tx = mean_c[0], ty = mean_c[1], tz = depth; - float tz2 = tz * tz; - - float tan_fovx = 0.5f * width / fx; - float tan_fovy = 0.5f * height / fy; - float lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; - float lim_x_neg = cx / fx + 0.3f * tan_fovx; - float lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; - float lim_y_neg = cy / fy + 0.3f * tan_fovy; - - float clamp_x = tx / tz; - clamp_x = fmaxf(-lim_x_neg, fminf(clamp_x, lim_x_pos)); - tx = tz * clamp_x; - - float clamp_y = ty / tz; - clamp_y = fmaxf(-lim_y_neg, fminf(clamp_y, lim_y_pos)); - ty = tz * clamp_y; - - float J[2][3] = { - {fx / tz, 0.0f, -fx * tx / tz2}, - {0.0f, fy / tz, -fy * ty / tz2} - }; - - float cov2d[2][2]; - float tmp2[2][3]; - for(int i=0; i<2; ++i) { - for(int j=0; j<3; ++j) { - tmp2[i][j] = J[i][0]*cov_c[0][j] + J[i][1]*cov_c[1][j] + J[i][2]*cov_c[2][j]; - } - } - for(int i=0; i<2; ++i) { - for(int j=0; j<2; ++j) { - cov2d[i][j] = tmp2[i][0]*J[j][0] + tmp2[i][1]*J[j][1] + tmp2[i][2]*J[j][2]; - } - } - - float mean2d[2]; - mean2d[0] = (K[0][0]*mean_c[0] + K[0][1]*mean_c[1] + K[0][2]*mean_c[2]) / tz; - mean2d[1] = (K[1][0]*mean_c[0] + K[1][1]*mean_c[1] + K[1][2]*mean_c[2]) / tz; - - cov2d[0][0] += eps2d; - cov2d[1][1] += eps2d; - - float det = cov2d[0][0]*cov2d[1][1] - cov2d[0][1]*cov2d[1][0]; - if (det < 1e-10f) det = 1e-10f; - - float conics[3]; - conics[0] = cov2d[1][1] / det; - conics[1] = -(cov2d[0][1] + cov2d[1][0]) / 2.0f / det; - conics[2] = cov2d[0][0] / det; - - int radii[2]; - radii[0] = (int)ceilf(3.33f * sqrtf(cov2d[0][0])); - radii[1] = (int)ceilf(3.33f * sqrtf(cov2d[1][1])); - - int dst_rank = cam_dst_ranks[c]; - int rank_c_start = c_start[dst_rank]; - int local_start = (rank_c_start == 0) ? 0 : scan[rank_c_start * N_local - 1]; - - int offset = peer_recv_offsets[dst_rank * world_size + my_rank]; - int peer_idx = offset + (local_idx - local_start); - - int* peer_camera_ids = (int*)peer_camera_ids_ptrs[dst_rank]; - int* peer_gaussian_ids = (int*)peer_gaussian_ids_ptrs[dst_rank]; - int* peer_radii = (int*)peer_radii_ptrs[dst_rank]; - float* peer_means2d = (float*)peer_means2d_ptrs[dst_rank]; - float* peer_depths = (float*)peer_depths_ptrs[dst_rank]; - float* peer_conics = (float*)peer_conics_ptrs[dst_rank]; - float* peer_opacities = (float*)peer_opacities_ptrs[dst_rank]; - ColorT* peer_colors = (ColorT*)peer_colors_ptrs[dst_rank]; - - peer_camera_ids[peer_idx] = cam_local_ids[c]; - peer_gaussian_ids[peer_idx] = global_gaussian_offset + n; - - peer_radii[peer_idx*2] = radii[0]; - peer_radii[peer_idx*2+1] = radii[1]; - - peer_means2d[peer_idx*2] = mean2d[0]; - peer_means2d[peer_idx*2+1] = mean2d[1]; - - peer_depths[peer_idx] = depth; - - peer_conics[peer_idx*3] = conics[0]; - peer_conics[peer_idx*3+1] = conics[1]; - peer_conics[peer_idx*3+2] = conics[2]; - - peer_opacities[peer_idx] = opac; - - for(int d=0; d>>( - means.data_ptr(), quats.data_ptr(), scales.data_ptr(), - viewmats.data_ptr(), Ks.data_ptr(), - width, height, N_local, C_total, - eps2d, near_plane, far_plane, valid.data_ptr() - ); -} - -void launch_project_and_push( - torch::Tensor means, torch::Tensor quats, torch::Tensor scales, - torch::Tensor opacities, torch::Tensor colors, - torch::Tensor viewmats, torch::Tensor Ks, - torch::Tensor cam_dst_ranks, torch::Tensor cam_local_ids, - int width, int height, - torch::Tensor scan, torch::Tensor c_start, - int N_local, int C_total, int D_colors, - float eps2d, float near_plane, float far_plane, - int my_rank, int global_gaussian_offset, - torch::Tensor peer_recv_offsets, - torch::Tensor peer_cam_ptrs, torch::Tensor peer_gauss_ptrs, - torch::Tensor peer_radii_ptrs, torch::Tensor peer_means2d_ptrs, - torch::Tensor peer_depths_ptrs, torch::Tensor peer_conics_ptrs, - torch::Tensor peer_opac_ptrs, torch::Tensor peer_colors_ptrs, - int world_size -) { - int threads = 256; - int blocks = (N_local + threads - 1) / threads; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - #define LAUNCH_PUSH(COLOR_T) \ - project_and_push_kernel<<>>( \ - means.data_ptr(), quats.data_ptr(), scales.data_ptr(), \ - opacities.data_ptr(), (const COLOR_T*)colors.data_ptr(), \ - viewmats.data_ptr(), Ks.data_ptr(), \ - cam_dst_ranks.data_ptr(), cam_local_ids.data_ptr(), \ - width, height, \ - scan.data_ptr(), c_start.data_ptr(), \ - N_local, C_total, D_colors, \ - eps2d, near_plane, far_plane, \ - my_rank, global_gaussian_offset, \ - peer_recv_offsets.data_ptr(), \ - peer_cam_ptrs.data_ptr(), peer_gauss_ptrs.data_ptr(), \ - peer_radii_ptrs.data_ptr(), peer_means2d_ptrs.data_ptr(), \ - peer_depths_ptrs.data_ptr(), peer_conics_ptrs.data_ptr(), \ - peer_opac_ptrs.data_ptr(), peer_colors_ptrs.data_ptr(), \ - world_size \ - ) - - if (colors.scalar_type() == at::ScalarType::BFloat16) { - LAUNCH_PUSH(__nv_bfloat16); - } else { - LAUNCH_PUSH(float); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_compute_valid", &launch_compute_valid, "Fused projection pass 1: validity and counting"); - m.def("launch_project_and_push", &launch_project_and_push, "Fused projection pass 2: push to symmetric peer memory"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gsplat_dist_fused_ext", CUDA_SRC) - return _ext - -def _all_gather_tensor_list(world_size: int, tensor_list: list[Tensor]) -> list[Tensor]: - if world_size == 1: return tensor_list - N = len(tensor_list[0]) - data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1) - sizes = [t.numel() // N for t in tensor_list] - collected = [torch.empty_like(data) for _ in range(world_size)] - dist.all_gather(collected, data) - collected = torch.cat(collected, dim=0) - out_tensor_tuple = torch.split(collected, sizes, dim=-1) - out_tensor_list = [] - for out_tensor, tensor in zip(out_tensor_tuple, tensor_list): - out_tensor = out_tensor.view(-1, *tensor.shape[1:]) - out_tensor_list.append(out_tensor) - return out_tensor_list - -@torch.no_grad() -def solution( - means: Tensor, - quats: Tensor, - scales: Tensor, - opacities: Tensor, - colors: Tensor, - viewmats: Tensor, - Ks: Tensor, - image_width: int, - image_height: int, - eps2d: float = 0.3, - near_plane: float = 0.01, - far_plane: float = 1e10, - camera_model: str = "pinhole", -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - - assert dist.is_initialized(), "torch.distributed must be initialized" - world_rank = dist.get_rank() - world_size = dist.get_world_size() - device = means.device - - N = means.shape[0] - C = viewmats.shape[0] - D = colors.shape[1] - - # Quick pre-gather global shapes & configurations - N_t = torch.tensor([N], dtype=torch.int32, device=device) - N_world = _all_gather_tensor_list(world_size, [N_t])[0].flatten().tolist() - C_world = [C] * world_size - global_gaussian_offset = sum(N_world[:world_rank]) - C_total = sum(C_world) - - c_start = [sum(C_world[:r]) for r in range(world_size)] - c_start_t = torch.tensor(c_start, dtype=torch.int32, device=device) - - all_viewmats = _all_gather_tensor_list(world_size, [viewmats])[0].float().contiguous() - all_Ks = _all_gather_tensor_list(world_size, [Ks])[0].float().contiguous() - - cam_dst_ranks = torch.zeros(C_total, dtype=torch.int32, device=device) - cam_local_ids = torch.zeros(C_total, dtype=torch.int32, device=device) - for r in range(world_size): - start, end = sum(C_world[:r]), sum(C_world[:r+1]) - cam_dst_ranks[start:end] = r - cam_local_ids[start:end] = torch.arange(C_world[r], dtype=torch.int32, device=device) - - means_f32 = means.float().contiguous() - quats_f32 = quats.float().contiguous() - scales_f32 = scales.float().contiguous() - opacities_f32 = opacities.float().contiguous() - colors_c = colors.contiguous() - - ext = _get_ext() - - # Pass 1: Local lightweight projection to determine precise required memory per chunk - valid = torch.zeros(C_total * N, dtype=torch.int32, device=device) - ext.launch_compute_valid( - means_f32, quats_f32, scales_f32, all_viewmats, all_Ks, - image_width, image_height, N, C_total, - eps2d, near_plane, far_plane, valid - ) - - # Deterministic offsets (maintains ideal grouping and avoids expensive atomic reductions) - scan = torch.cumsum(valid, dim=0) - - send_counts = torch.zeros(world_size, dtype=torch.int32, device=device) - for r in range(world_size): - end_idx = sum(C_world[:r+1]) * N - 1 - start_idx = sum(C_world[:r]) * N - 1 - end_val = scan[end_idx].item() if end_idx >= 0 else 0 - start_val = scan[start_idx].item() if start_idx >= 0 else 0 - send_counts[r] = end_val - start_val - - # Setup the memory structures on peers dynamically - recv_counts = torch.zeros(world_size, dtype=torch.int32, device=device) - dist.all_to_all_single(recv_counts, send_counts) - - recv_offsets = torch.cumsum(recv_counts, dim=0) - recv_counts - recv_total = recv_counts.sum().item() - - all_recv_offsets = [torch.zeros(world_size, dtype=torch.int32, device=device) for _ in range(world_size)] - dist.all_gather(all_recv_offsets, recv_offsets) - peer_recv_offsets = torch.stack(all_recv_offsets) - - all_recv_totals = torch.zeros(world_size, dtype=torch.int32, device=device) - dist.all_gather_into_tensor(all_recv_totals, torch.tensor([recv_total], dtype=torch.int32, device=device)) - - # Single massive symmetric memory allocation split into typed views - bytes_per_elem = 44 + D * colors_c.element_size() - buf = symm_mem.empty(recv_total * bytes_per_elem, dtype=torch.uint8, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - def slice_buffer(base_ptr_int, r_total): - ptrs = {} - curr = base_ptr_int - ptrs['cam'] = curr; curr += r_total * 4 - ptrs['gauss'] = curr; curr += r_total * 4 - ptrs['radii'] = curr; curr += r_total * 8 - ptrs['means2d'] = curr; curr += r_total * 8 - ptrs['depths'] = curr; curr += r_total * 4 - ptrs['conics'] = curr; curr += r_total * 12 - ptrs['opac'] = curr; curr += r_total * 4 - ptrs['colors'] = curr; curr += r_total * D * colors_c.element_size() - return ptrs - - peer_cam_ptrs, peer_gauss_ptrs, peer_radii_ptrs = [], [], [] - peer_means2d_ptrs, peer_depths_ptrs, peer_conics_ptrs = [], [], [] - peer_opac_ptrs, peer_colors_ptrs = [], [] - - for p in range(world_size): - ptrs = slice_buffer(hdl.buffer_ptrs[p], all_recv_totals[p].item()) - peer_cam_ptrs.append(ptrs['cam']); peer_gauss_ptrs.append(ptrs['gauss']) - peer_radii_ptrs.append(ptrs['radii']); peer_means2d_ptrs.append(ptrs['means2d']) - peer_depths_ptrs.append(ptrs['depths']); peer_conics_ptrs.append(ptrs['conics']) - peer_opac_ptrs.append(ptrs['opac']); peer_colors_ptrs.append(ptrs['colors']) - - # Pass 2: Recompute projection and synchronously stream the results over NVLink straight to the receiver - ext.launch_project_and_push( - means_f32, quats_f32, scales_f32, opacities_f32, colors_c, - all_viewmats, all_Ks, cam_dst_ranks, cam_local_ids, - image_width, image_height, - scan, c_start_t, N, C_total, D, eps2d, near_plane, far_plane, - world_rank, global_gaussian_offset, peer_recv_offsets.flatten(), - torch.tensor(peer_cam_ptrs, dtype=torch.int64, device=device), - torch.tensor(peer_gauss_ptrs, dtype=torch.int64, device=device), - torch.tensor(peer_radii_ptrs, dtype=torch.int64, device=device), - torch.tensor(peer_means2d_ptrs, dtype=torch.int64, device=device), - torch.tensor(peer_depths_ptrs, dtype=torch.int64, device=device), - torch.tensor(peer_conics_ptrs, dtype=torch.int64, device=device), - torch.tensor(peer_opac_ptrs, dtype=torch.int64, device=device), - torch.tensor(peer_colors_ptrs, dtype=torch.int64, device=device), - world_size - ) - - hdl.barrier(channel=0) - - # Finally slice out local incoming views - curr = 0 - cam_ids_out = buf[curr : curr + recv_total * 4].view(torch.int32); curr += recv_total * 4 - gauss_ids_out = buf[curr : curr + recv_total * 4].view(torch.int32); curr += recv_total * 4 - radii_out = buf[curr : curr + recv_total * 8].view(torch.int32).view(recv_total, 2); curr += recv_total * 8 - means2d_out = buf[curr : curr + recv_total * 8].view(torch.float32).view(recv_total, 2); curr += recv_total * 8 - depths_out = buf[curr : curr + recv_total * 4].view(torch.float32); curr += recv_total * 4 - conics_out = buf[curr : curr + recv_total * 12].view(torch.float32).view(recv_total, 3); curr += recv_total * 12 - opacities_out = buf[curr : curr + recv_total * 4].view(torch.float32); curr += recv_total * 4 - colors_out = buf[curr : curr + recv_total * D * colors_c.element_size()].view(colors_c.dtype).view(recv_total, D) - - return ( - cam_ids_out, gauss_ids_out, radii_out, - means2d_out, depths_out, conics_out, - opacities_out, colors_out - ) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/63_torchharmonics_spherical_convolution_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/63_torchharmonics_spherical_convolution_cuda.py deleted file mode 100755 index 7bb49fa..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/63_torchharmonics_spherical_convolution_cuda.py +++ /dev/null @@ -1,464 +0,0 @@ -""" -Optimized DISCO spherical convolution replacing PyTorch collectives and loops with custom CUDA. - -Strategy: -1. **Fused All-to-All Push**: The azimuth all-to-all is implemented as a direct peer-to-peer - UVA push kernel, eliminating `torch.split`, `contiguous`, and `torch.cat` overheads. -2. **Fused Shifted SpMM**: The expensive `torch.roll` and loop of `torch.bmm` are folded - into a single CSR sparse matrix-vector multiplication kernel. -3. **Lock-Free Reduce-Scatter**: The polar all-reduce + split is replaced by a lock-free - UVA push kernel into peer buffers, followed by a local reduction kernel. -4. **BMM Layout Fusion**: The final azimuth scatter pushes data directly into the permuted - layout required for the grouped channel-mixing BMM. -""" - -from typing import List, Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void a2a_azimuth_fwd_kernel( - const __nv_bfloat16* __restrict__ x, - const long long* __restrict__ dest_ptrs, - const int* __restrict__ az_global_ranks, - const int* __restrict__ C_offsets, - const int* __restrict__ C_splits, - int B, int C, int nlat_in, int my_lon_in, - int nlon_in, int my_lon_offset, int az_size -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = B * C * nlat_in * my_lon_in; - if (idx >= total) return; - - int lon = idx % my_lon_in; - int tmp = idx / my_lon_in; - int lat = tmp % nlat_in; - tmp = tmp / nlat_in; - int c = tmp % C; - int b = tmp / C; - - int dest_az_rank = 0; - for (int i = 0; i < az_size; ++i) { - if (c >= C_offsets[i] && c < C_offsets[i] + C_splits[i]) { - dest_az_rank = i; - break; - } - } - - int c_local = c - C_offsets[dest_az_rank]; - int dest_global_rank = az_global_ranks[dest_az_rank]; - __nv_bfloat16* dest_ptr = (__nv_bfloat16*)dest_ptrs[dest_global_rank]; - - int dest_C = C_splits[dest_az_rank]; - int dest_idx = ((b * dest_C + c_local) * nlat_in + lat) * nlon_in + (my_lon_offset + lon); - dest_ptr[dest_idx] = x[idx]; -} - -__global__ void spmm_shift_csr_kernel( - const int* __restrict__ crow_indices, - const int* __restrict__ col_indices, - const float* __restrict__ values, - const __nv_bfloat16* __restrict__ x, - float* __restrict__ y, - int R, int N, int nlat_in, int nlon_in, int nlon_out, int pscale, int kernel_size, int nlat_out -) { - int r = blockIdx.x * blockDim.x + threadIdx.x; - int n_idx = blockIdx.y * blockDim.y + threadIdx.y; - int pout = blockIdx.z; - - if (r >= R || n_idx >= N) return; - - int row_start = crow_indices[r]; - int row_end = crow_indices[r+1]; - - float sum = 0.0f; - for (int nz = row_start; nz < row_end; ++nz) { - int in_idx = col_indices[nz]; - float val = values[nz]; - - int lat_in = in_idx / nlon_in; - int lon_in = in_idx % nlon_in; - int lon_shifted = (lon_in + pout * pscale) % nlon_in; - - float x_val = __bfloat162float(x[(n_idx * nlat_in + lat_in) * nlon_in + lon_shifted]); - sum += val * x_val; - } - - int k = r / nlat_out; - int lat_out = r % nlat_out; - int y_idx = ((n_idx * kernel_size + k) * nlat_out + lat_out) * nlon_out + pout; - y[y_idx] = sum; -} - -__global__ void push_rs_chunk_kernel( - const float* __restrict__ local_y, - const long long* __restrict__ dest_ptrs, - const int* __restrict__ polar_global_ranks, - const int* __restrict__ nlat_out_offsets, - const int* __restrict__ nlat_out_splits, - int N, int kernel_size, int nlat_out, int nlon_out, - int my_polar_rank, int polar_size -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = N * kernel_size * nlat_out * nlon_out; - if (idx >= total) return; - - int lon = idx % nlon_out; - int tmp = idx / nlon_out; - int lat = tmp % nlat_out; - tmp = tmp / nlat_out; - int k = tmp % kernel_size; - int n = tmp / kernel_size; - - int dest_polar_rank = 0; - for (int i = 0; i < polar_size; ++i) { - if (lat >= nlat_out_offsets[i] && lat < nlat_out_offsets[i] + nlat_out_splits[i]) { - dest_polar_rank = i; - break; - } - } - - int dest_global_rank = polar_global_ranks[dest_polar_rank]; - float* dest_ptr = (float*)dest_ptrs[dest_global_rank]; - - int lat_local = lat - nlat_out_offsets[dest_polar_rank]; - int dest_nlat = nlat_out_splits[dest_polar_rank]; - - int remote_idx = (((my_polar_rank * N + n) * kernel_size + k) * dest_nlat + lat_local) * nlon_out + lon; - dest_ptr[remote_idx] = local_y[idx]; -} - -__global__ void reduce_rs_chunk_kernel( - const float* __restrict__ local_buf, - __nv_bfloat16* __restrict__ out, - int polar_size, int chunk_size -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= chunk_size) return; - - float sum = 0.0f; - for (int p = 0; p < polar_size; ++p) { - sum += local_buf[p * chunk_size + idx]; - } - out[idx] = __float2bfloat16(sum); -} - -__global__ void a2a_azimuth_bwd_kernel( - const __nv_bfloat16* __restrict__ local_x, - const long long* __restrict__ dest_ptrs, - const int* __restrict__ az_global_ranks, - const int* __restrict__ lon_out_offsets, - const int* __restrict__ lon_out_splits, - int B, int C_local, int kernel_size, int nlat_out_local, int nlon_out, - int my_C_offset, int az_size, int full_C, int groups, int groupsize -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = B * C_local * kernel_size * nlat_out_local * nlon_out; - if (idx >= total) return; - - int lon = idx % nlon_out; - int tmp = idx / nlon_out; - int lat = tmp % nlat_out_local; - tmp = tmp / nlat_out_local; - int k = tmp % kernel_size; - tmp = tmp / kernel_size; - int c = tmp % C_local; - int b = tmp / C_local; - - int dest_az_rank = 0; - for (int i = 0; i < az_size; ++i) { - if (lon >= lon_out_offsets[i] && lon < lon_out_offsets[i] + lon_out_splits[i]) { - dest_az_rank = i; - break; - } - } - - int lon_local = lon - lon_out_offsets[dest_az_rank]; - int dest_global_rank = az_global_ranks[dest_az_rank]; - __nv_bfloat16* dest_ptr = (__nv_bfloat16*)dest_ptrs[dest_global_rank]; - - int dest_nlon = lon_out_splits[dest_az_rank]; - int c_global = my_C_offset + c; - - int g = c_global / groupsize; - int c_in_g = c_global % groupsize; - int dest_HW = nlat_out_local * dest_nlon; - int hw = lat * dest_nlon + lon_local; - - int dest_idx = (g * (groupsize * kernel_size) + (c_in_g * kernel_size + k)) * (B * dest_HW) + (b * dest_HW + hw); - dest_ptr[dest_idx] = local_x[idx]; -} - -void launch_a2a_fwd( - torch::Tensor x, torch::Tensor dest_ptrs, torch::Tensor az_global_ranks, - torch::Tensor C_offsets, torch::Tensor C_splits, - int B, int C, int nlat_in, int my_lon_in, int nlon_in, int my_lon_offset, int az_size -) { - int total = B * C * nlat_in * my_lon_in; - if (total == 0) return; - int threads = 256; - int blocks = (total + threads - 1) / threads; - a2a_azimuth_fwd_kernel<<>>( - (__nv_bfloat16*)x.data_ptr(), - (const long long*)dest_ptrs.data_ptr(), - az_global_ranks.data_ptr(), - C_offsets.data_ptr(), - C_splits.data_ptr(), - B, C, nlat_in, my_lon_in, nlon_in, my_lon_offset, az_size - ); -} - -void launch_spmm( - torch::Tensor crow, torch::Tensor col, torch::Tensor values, - torch::Tensor x, torch::Tensor y, - int R, int N, int nlat_in, int nlon_in, int nlon_out, int pscale, int kernel_size, int nlat_out -) { - if (R == 0 || N == 0) return; - dim3 threads(128, 1, 1); - dim3 blocks((R + threads.x - 1) / threads.x, N, nlon_out); - spmm_shift_csr_kernel<<>>( - crow.data_ptr(), - col.data_ptr(), - values.data_ptr(), - (__nv_bfloat16*)x.data_ptr(), - y.data_ptr(), - R, N, nlat_in, nlon_in, nlon_out, pscale, kernel_size, nlat_out - ); -} - -void launch_push_rs( - torch::Tensor local_y, torch::Tensor dest_ptrs, torch::Tensor polar_global_ranks, - torch::Tensor nlat_out_offsets, torch::Tensor nlat_out_splits, - int N, int kernel_size, int nlat_out, int nlon_out, - int my_polar_rank, int polar_size -) { - int total = N * kernel_size * nlat_out * nlon_out; - if (total == 0) return; - int threads = 256; - int blocks = (total + threads - 1) / threads; - push_rs_chunk_kernel<<>>( - local_y.data_ptr(), - (const long long*)dest_ptrs.data_ptr(), - polar_global_ranks.data_ptr(), - nlat_out_offsets.data_ptr(), - nlat_out_splits.data_ptr(), - N, kernel_size, nlat_out, nlon_out, - my_polar_rank, polar_size - ); -} - -void launch_reduce_rs( - torch::Tensor local_buf, torch::Tensor out, - int polar_size, int chunk_size -) { - if (chunk_size == 0) return; - int threads = 256; - int blocks = (chunk_size + threads - 1) / threads; - reduce_rs_chunk_kernel<<>>( - local_buf.data_ptr(), - (__nv_bfloat16*)out.data_ptr(), - polar_size, chunk_size - ); -} - -void launch_a2a_bwd( - torch::Tensor local_x, torch::Tensor dest_ptrs, torch::Tensor az_global_ranks, - torch::Tensor lon_out_offsets, torch::Tensor lon_out_splits, - int B, int C_local, int kernel_size, int nlat_out_local, int nlon_out, - int my_C_offset, int az_size, int full_C, int groups, int groupsize -) { - int total = B * C_local * kernel_size * nlat_out_local * nlon_out; - if (total == 0) return; - int threads = 256; - int blocks = (total + threads - 1) / threads; - a2a_azimuth_bwd_kernel<<>>( - (__nv_bfloat16*)local_x.data_ptr(), - (const long long*)dest_ptrs.data_ptr(), - az_global_ranks.data_ptr(), - lon_out_offsets.data_ptr(), - lon_out_splits.data_ptr(), - B, C_local, kernel_size, nlat_out_local, nlon_out, - my_C_offset, az_size, full_C, groups, groupsize - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_a2a_fwd", &launch_a2a_fwd); - m.def("launch_spmm", &launch_spmm); - m.def("launch_push_rs", &launch_push_rs); - m.def("launch_reduce_rs", &launch_reduce_rs); - m.def("launch_a2a_bwd", &launch_a2a_bwd); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("disco_s2_fused_ext", CUDA_SRC) - return _ext - -def _compute_split_shapes(size: int, num_chunks: int) -> List[int]: - if num_chunks == 1: - return [size] - chunk_size = (size + num_chunks - 1) // num_chunks - last_chunk_size = max(0, size - chunk_size * (num_chunks - 1)) - if last_chunk_size == 0: - chunk_size = size // num_chunks - last_chunk_size = size - chunk_size * (num_chunks - 1) - return [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size] - - -_symm_cache = {} -def _get_symm_mem(shape, dtype): - key = (shape, dtype) - if key not in _symm_cache: - buf = symm_mem.empty(shape, dtype=dtype, device='cuda') - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device='cuda') - _symm_cache[key] = (buf, ptrs) - return _symm_cache[key] - -_psi_cache = {} - - -@torch.no_grad() -def solution( - x: torch.Tensor, - psi: torch.Tensor, - weight: torch.Tensor, - groups: int, - nlon_out: int, - nlon_in: int, - azimuth_group: Optional[dist.ProcessGroup] = None, - polar_group: Optional[dist.ProcessGroup] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - azimuth_group = azimuth_group or dist.group.WORLD - polar_group = polar_group or dist.group.WORLD - - az_size = dist.get_world_size(group=azimuth_group) - az_rank = dist.get_rank(group=azimuth_group) - polar_size = dist.get_world_size(group=polar_group) - polar_rank = dist.get_rank(group=polar_group) - - B, full_C, nlat_in, my_lon_in = x.shape - kernel_size, nlat_out, _ = psi.shape - pscale = nlon_in // nlon_out - - # Meta constants & shapes - C_splits = _compute_split_shapes(full_C, az_size) - C_offsets = [sum(C_splits[:i]) for i in range(az_size)] - my_C = C_splits[az_rank] - - lon_in_splits = _compute_split_shapes(nlon_in, az_size) - lon_in_offsets = [sum(lon_in_splits[:i]) for i in range(az_size)] - - nlat_out_splits = _compute_split_shapes(nlat_out, polar_size) - nlat_out_offsets = [sum(nlat_out_splits[:i]) for i in range(polar_size)] - my_nlat_out = nlat_out_splits[polar_rank] - - lon_out_splits = _compute_split_shapes(nlon_out, az_size) - lon_out_offsets = [sum(lon_out_splits[:i]) for i in range(az_size)] - my_lon_out = lon_out_splits[az_rank] - - def _to_t(lst): return torch.tensor(lst, dtype=torch.int32, device='cuda') - - az_global_ranks = _to_t([dist.get_global_rank(azimuth_group, i) for i in range(az_size)]) - polar_global_ranks = _to_t([dist.get_global_rank(polar_group, i) for i in range(polar_size)]) - - C_splits_t = _to_t(C_splits) - C_offsets_t = _to_t(C_offsets) - lon_out_splits_t = _to_t(lon_out_splits) - lon_out_offsets_t = _to_t(lon_out_offsets) - nlat_out_splits_t = _to_t(nlat_out_splits) - nlat_out_offsets_t = _to_t(nlat_out_offsets) - - # Convert psi to CSR once - psi_ptr = psi.data_ptr() - if psi_ptr not in _psi_cache: - if psi.is_sparse: - psi_coo = psi.coalesce() - idx, vals = psi_coo.indices(), psi_coo.values() - else: - idx = psi.nonzero(as_tuple=False).t().contiguous() - vals = psi[idx[0], idx[1], idx[2]] - - psi_csr = torch.sparse_coo_tensor( - torch.stack([idx[0] * nlat_out + idx[1], idx[2]]), vals.float(), - size=(kernel_size * nlat_out, nlat_in * nlon_in) - ).coalesce().to_sparse_csr() - _psi_cache[psi_ptr] = (psi_csr.crow_indices().int(), psi_csr.col_indices().int(), psi_csr.values()) - - crow, col, vals = _psi_cache[psi_ptr] - - # --- Step 1: Azimuth FWD A2A Push --- - symm_x_az, ptrs_x_az = _get_symm_mem((B, my_C, nlat_in, nlon_in), torch.bfloat16) - dist.barrier() - if az_size > 1: - _get_ext().launch_a2a_fwd( - x, ptrs_x_az, az_global_ranks, C_offsets_t, C_splits_t, - B, full_C, nlat_in, my_lon_in, nlon_in, lon_in_offsets[az_rank], az_size - ) - dist.barrier() - curr_x = symm_x_az - else: - curr_x = x - - # --- Step 2: Fused SpMM Shift Contraction --- - y_partial = torch.empty((B * my_C, kernel_size, nlat_out, nlon_out), dtype=torch.float32, device='cuda') - _get_ext().launch_spmm( - crow, col, vals, curr_x, y_partial, - kernel_size * nlat_out, B * my_C, nlat_in, nlon_in, nlon_out, pscale, kernel_size, nlat_out - ) - - # --- Step 3 & 4: Lock-Free Polar Reduce-Scatter --- - symm_y_polar, ptrs_y_polar = _get_symm_mem((polar_size, B * my_C, kernel_size, my_nlat_out, nlon_out), torch.float32) - dist.barrier() - if polar_size > 1: - _get_ext().launch_push_rs( - y_partial, ptrs_y_polar, polar_global_ranks, nlat_out_offsets_t, nlat_out_splits_t, - B * my_C, kernel_size, nlat_out, nlon_out, polar_rank, polar_size - ) - dist.barrier() - y_local = torch.empty((B * my_C, kernel_size, my_nlat_out, nlon_out), dtype=torch.bfloat16, device='cuda') - _get_ext().launch_reduce_rs(symm_y_polar, y_local, polar_size, y_local.numel()) - else: - y_local = y_partial.to(torch.bfloat16) - - # --- Step 5: Azimuth BWD A2A Push (Directly to BMM layout) --- - groupsize = full_C // groups - symm_out_az, ptrs_out_az = _get_symm_mem((groups, groupsize * kernel_size, B * my_nlat_out * my_lon_out), torch.bfloat16) - dist.barrier() - - if az_size > 1: - _get_ext().launch_a2a_bwd( - y_local, ptrs_out_az, az_global_ranks, lon_out_offsets_t, lon_out_splits_t, - B, my_C, kernel_size, my_nlat_out, nlon_out, C_offsets[az_rank], az_size, full_C, groups, groupsize - ) - dist.barrier() - bmm_in = symm_out_az - else: - # If no azimuth distribution, emulate BMM layout preparation locally - bmm_in = y_local.view(B, groups, groupsize, kernel_size, my_nlat_out, my_lon_out) - bmm_in = bmm_in.permute(1, 2, 3, 0, 4, 5).reshape(groups, groupsize * kernel_size, B * my_nlat_out * my_lon_out) - - # --- Step 6 & 7: Grouped Channel Mixing & Bias --- - weight_reshaped = weight.reshape(groups, -1, weight.shape[1] * weight.shape[2]).to(torch.bfloat16) - out = torch.bmm(weight_reshaped, bmm_in) - - out = out.view(groups, -1, B, my_nlat_out, my_lon_out) - out = out.permute(2, 0, 1, 3, 4).reshape(B, -1, my_nlat_out, my_lon_out) - - if bias is not None: - out = out + bias.view(1, -1, 1, 1).to(out.dtype) - - return out.contiguous() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/64_deepmd_kalman_filter_optimizer_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/64_deepmd_kalman_filter_optimizer_cuda.py deleted file mode 100755 index 29c8b20..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/64_deepmd_kalman_filter_optimizer_cuda.py +++ /dev/null @@ -1,398 +0,0 @@ -""" -Strategy: -1. **Device-side communication**: Replaced `dist.all_reduce` for the scalar `tmp` and `dist.all_gather` for parameter updates with direct UVA pointer accesses from `torch.distributed._symmetric_memory`. -2. **Fused Compute & Overlap**: Grouped computations into batched CUDA kernels. Kernel 1 fuses GEMV (`P_i @ H_i`) with dot product, directly using atomic additions on a symmetric memory scalar. Kernel 2 reads all peers' scalars directly to compute the denominator, applies parameter and covariance updates, and natively writes updated weights to a symmetric buffer. -3. **Optimized Shape Gather**: `all_gather_object` is extremely slow and only needed once (shapes stay constant across iterations); this layout is cached so hot-path calls only do memory accesses. -4. **Gather offload**: A third CUDA kernel uses UVA reads to efficiently aggregate peer data into a contiguous output buffer, skipping PyTorch allocations and NCCL collective latency. -""" - -from typing import List, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void kernel1( - const int64_t* __restrict__ P_ptrs, - const int64_t* __restrict__ H_ptrs, - __nv_bfloat16* __restrict__ K_flat, - const int* __restrict__ N_array, - const int* __restrict__ Offsets_array, - float lam, - float* __restrict__ symm_tmp -) { - int block_idx = blockIdx.x; - int n = N_array[block_idx]; - int offset = Offsets_array[block_idx]; - - const __nv_bfloat16* P = (const __nv_bfloat16*)P_ptrs[block_idx]; - const __nv_bfloat16* H = (const __nv_bfloat16*)H_ptrs[block_idx]; - __nv_bfloat16* K = K_flat + offset; - - int tid = threadIdx.x; - int threads = blockDim.x; - - float local_dot = 0.0f; - - for (int row = tid; row < n; row += threads) { - float sum = 0.0f; - for (int col = 0; col < n; col++) { - sum += __bfloat162float(P[row * n + col]) * __bfloat162float(H[col]); - } - K[row] = __float2bfloat16(sum); - local_dot += sum * __bfloat162float(H[row]); - } - - static __shared__ float shared_dot[32]; - int lane = tid % 32; - int wid = tid / 32; - - #pragma unroll - for (int offset_shfl = 16; offset_shfl > 0; offset_shfl /= 2) { - local_dot += __shfl_down_sync(0xffffffff, local_dot, offset_shfl); - } - - if (lane == 0) { - shared_dot[wid] = local_dot; - } - __syncthreads(); - - if (tid < 32) { - float val = (tid < (threads + 31) / 32) ? shared_dot[tid] : 0.0f; - #pragma unroll - for (int offset_shfl = 16; offset_shfl > 0; offset_shfl /= 2) { - val += __shfl_down_sync(0xffffffff, val, offset_shfl); - } - if (tid == 0) { - atomicAdd(symm_tmp, val + lam); - } - } -} - -__global__ void kernel2( - const int64_t* __restrict__ symm_tmp_ptrs, - int world_size, - const int64_t* __restrict__ P_ptrs, - const int64_t* __restrict__ W_ptrs, - const __nv_bfloat16* __restrict__ K_flat, - __nv_bfloat16* __restrict__ symm_weights, - const int* __restrict__ N_array, - const int* __restrict__ Offsets_array, - const __nv_bfloat16* __restrict__ err_ptr, - float lam -) { - __shared__ float A; - __shared__ float err; - - int block_idx = blockIdx.x; - int n = N_array[block_idx]; - int offset = Offsets_array[block_idx]; - - if (threadIdx.x == 0) { - float global_tmp = 0.0f; - for (int r = 0; r < world_size; r++) { - float* peer_tmp = (float*)symm_tmp_ptrs[r]; - global_tmp += *peer_tmp; - } - A = 1.0f / global_tmp; - err = __bfloat162float(*err_ptr); - } - __syncthreads(); - - float local_A = A; - float local_err = err; - - __nv_bfloat16* P = (__nv_bfloat16*)P_ptrs[block_idx]; - __nv_bfloat16* W = (__nv_bfloat16*)W_ptrs[block_idx]; - const __nv_bfloat16* K = K_flat + offset; - __nv_bfloat16* W_out = symm_weights + offset; - - int tid = threadIdx.x; - int threads = blockDim.x; - - for (int row = tid; row < n; row += threads) { - float w_val = __bfloat162float(W[row]); - float k_val = __bfloat162float(K[row]); - float new_w = w_val + local_A * local_err * k_val; - W[row] = __float2bfloat16(new_w); - W_out[row] = __float2bfloat16(new_w); - } - - int total_elements = n * n; - for (int idx = tid; idx < total_elements; idx += threads) { - int row = idx / n; - int col = idx % n; - float p_val = __bfloat162float(P[idx]); - float kr = __bfloat162float(K[row]); - float kc = __bfloat162float(K[col]); - - float new_p = (p_val - local_A * kr * kc) / lam; - P[idx] = __float2bfloat16(new_p); - } -} - -__global__ void gather_kernel( - const int64_t* __restrict__ symm_weights_ptrs, - __nv_bfloat16* __restrict__ out, - const int* __restrict__ rank_offsets, - const int* __restrict__ rank_sizes, - int world_size -) { - int block_idx = blockIdx.x; - if (block_idx >= world_size) return; - - int offset = rank_offsets[block_idx]; - int size = rank_sizes[block_idx]; - const __nv_bfloat16* src = (const __nv_bfloat16*)symm_weights_ptrs[block_idx]; - __nv_bfloat16* dst = out + offset; - - int tid = threadIdx.x; - int threads = blockDim.x; - - for (int i = tid; i < size; i += threads) { - dst[i] = src[i]; - } -} - -void launch_kernel1( - torch::Tensor P_ptrs, torch::Tensor H_ptrs, torch::Tensor K_flat, - torch::Tensor N_array, torch::Tensor Offsets_array, float lam, - torch::Tensor symm_tmp, int weights_num -) { - int threads = 256; - int blocks = weights_num; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - kernel1<<>>( - P_ptrs.data_ptr(), - H_ptrs.data_ptr(), - (__nv_bfloat16*)K_flat.data_ptr(), - N_array.data_ptr(), - Offsets_array.data_ptr(), - lam, - symm_tmp.data_ptr() - ); -} - -void launch_kernel2( - torch::Tensor symm_tmp_ptrs, int world_size, - torch::Tensor P_ptrs, torch::Tensor W_ptrs, torch::Tensor K_flat, - torch::Tensor symm_weights, torch::Tensor N_array, torch::Tensor Offsets_array, - torch::Tensor err_tensor, float lam, int weights_num -) { - int threads = 256; - int blocks = weights_num; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - kernel2<<>>( - symm_tmp_ptrs.data_ptr(), - world_size, - P_ptrs.data_ptr(), - W_ptrs.data_ptr(), - (__nv_bfloat16*)K_flat.data_ptr(), - (__nv_bfloat16*)symm_weights.data_ptr(), - N_array.data_ptr(), - Offsets_array.data_ptr(), - (__nv_bfloat16*)err_tensor.data_ptr(), - lam - ); -} - -void launch_gather_kernel( - torch::Tensor symm_weights_ptrs, torch::Tensor out, - torch::Tensor rank_offsets, torch::Tensor rank_sizes, - int world_size -) { - int threads = 1024; - int blocks = world_size; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - gather_kernel<<>>( - symm_weights_ptrs.data_ptr(), - (__nv_bfloat16*)out.data_ptr(), - rank_offsets.data_ptr(), - rank_sizes.data_ptr(), - world_size - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_kernel1", &launch_kernel1, "Kalman blockwise kernel 1"); - m.def("launch_kernel2", &launch_kernel2, "Kalman blockwise kernel 2"); - m.def("launch_gather_kernel", &launch_gather_kernel, "Kalman blockwise gather"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("deepmd_kalman_opt_ext", CUDA_SRC) - return _ext - -_cache = {} - -def _get_resources(H, weights, P): - if not dist.is_initialized(): - world_size = 1 - else: - world_size = dist.get_world_size() - - device = weights[0].device - local_shapes = tuple(w.shape[0] for w in weights) - cache_key = (world_size, local_shapes) - - if cache_key in _cache: - return _cache[cache_key] - - weights_num = len(weights) - if weights_num > 0: - N_array = torch.tensor(local_shapes, dtype=torch.int32, device=device) - Offsets_array = torch.cat([ - torch.tensor([0], device=device, dtype=torch.int32), - torch.cumsum(N_array, dim=0)[:-1].to(torch.int32) - ]) - else: - N_array = torch.empty(0, dtype=torch.int32, device=device) - Offsets_array = torch.empty(0, dtype=torch.int32, device=device) - - total_local_weights = sum(local_shapes) - - if world_size > 1: - shape_list = [None for _ in range(world_size)] - dist.all_gather_object(shape_list, list(local_shapes)) - else: - shape_list = [list(local_shapes)] - - all_split_sizes = [] - rank_sizes = [] - for shapes in shape_list: - all_split_sizes.extend(shapes) - rank_sizes.append(sum(shapes)) - - total_world_weights = sum(rank_sizes) - - symm_tmp_buf = symm_mem.empty((1,), dtype=torch.float32, device=device) - if world_size > 1: - hdl_tmp = symm_mem.rendezvous(symm_tmp_buf, dist.group.WORLD) - symm_tmp_ptrs = torch.tensor(hdl_tmp.buffer_ptrs, dtype=torch.int64, device=device) - else: - hdl_tmp = None - symm_tmp_ptrs = torch.tensor([symm_tmp_buf.data_ptr()], dtype=torch.int64, device=device) - - if total_local_weights > 0: - symm_weights_buf = symm_mem.empty((total_local_weights,), dtype=torch.bfloat16, device=device) - else: - symm_weights_buf = torch.empty((0,), dtype=torch.bfloat16, device=device) - - if world_size > 1 and total_local_weights > 0: - hdl_weights = symm_mem.rendezvous(symm_weights_buf, dist.group.WORLD) - symm_weights_ptrs = torch.tensor(hdl_weights.buffer_ptrs, dtype=torch.int64, device=device) - else: - hdl_weights = None - symm_weights_ptrs = torch.tensor([symm_weights_buf.data_ptr() if total_local_weights > 0 else 0], dtype=torch.int64, device=device) - - rank_offsets = [0] - for s in rank_sizes[:-1]: - rank_offsets.append(rank_offsets[-1] + s) - - rank_sizes_tensor = torch.tensor(rank_sizes, dtype=torch.int32, device=device) - rank_offsets_tensor = torch.tensor(rank_offsets, dtype=torch.int32, device=device) - - K_flat = torch.empty(total_local_weights, dtype=torch.bfloat16, device=device) - out_gathered = torch.empty(total_world_weights, dtype=torch.bfloat16, device=device) - - res = { - "N_array": N_array, - "Offsets_array": Offsets_array, - "all_split_sizes": all_split_sizes, - "symm_tmp_buf": symm_tmp_buf, - "hdl_tmp": hdl_tmp, - "symm_tmp_ptrs": symm_tmp_ptrs, - "symm_weights_buf": symm_weights_buf, - "hdl_weights": hdl_weights, - "symm_weights_ptrs": symm_weights_ptrs, - "rank_sizes_tensor": rank_sizes_tensor, - "rank_offsets_tensor": rank_offsets_tensor, - "K_flat": K_flat, - "out_gathered": out_gathered, - "world_size": world_size, - } - _cache[cache_key] = res - return res - -@torch.no_grad() -def solution( - H: List[torch.Tensor], - error: torch.Tensor, - weights: List[torch.Tensor], - P: List[torch.Tensor], - kalman_lambda: float, - kalman_nue: float = 0.9987, -) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: - - weights_num = len(weights) - lam_val = float(kalman_lambda) - lam_next_val = kalman_nue * lam_val + 1.0 - kalman_nue - - if weights_num == 0: - if dist.is_initialized() and dist.get_world_size() > 1: - device = error.device - shape_list = [None for _ in range(dist.get_world_size())] - dist.all_gather_object(shape_list, []) - return [], P, torch.tensor(lam_next_val, dtype=torch.bfloat16, device=device) - return [], P, torch.tensor(lam_next_val, dtype=torch.bfloat16, device=error.device if error is not None else torch.device("cuda")) - - device = weights[0].device - res = _get_resources(H, weights, P) - - P_ptrs_dev = torch.tensor([p.data_ptr() for p in P], dtype=torch.int64, device=device) - H_ptrs_dev = torch.tensor([h.data_ptr() for h in H], dtype=torch.int64, device=device) - W_ptrs_dev = torch.tensor([w.data_ptr() for w in weights], dtype=torch.int64, device=device) - - res["symm_tmp_buf"].zero_() - err_dev = error.to(device=device, dtype=torch.bfloat16) - - ext = _get_ext() - - ext.launch_kernel1( - P_ptrs_dev, H_ptrs_dev, res["K_flat"], - res["N_array"], res["Offsets_array"], lam_val, - res["symm_tmp_buf"], weights_num - ) - - if res["hdl_tmp"] is not None: - res["hdl_tmp"].barrier(channel=0) - - ext.launch_kernel2( - res["symm_tmp_ptrs"], res["world_size"], - P_ptrs_dev, W_ptrs_dev, res["K_flat"], - res["symm_weights_buf"], res["N_array"], res["Offsets_array"], - err_dev, lam_val, weights_num - ) - - if res["hdl_weights"] is not None: - res["hdl_weights"].barrier(channel=0) - ext.launch_gather_kernel( - res["symm_weights_ptrs"], res["out_gathered"], - res["rank_offsets_tensor"], res["rank_sizes_tensor"], - res["world_size"] - ) - out = res["out_gathered"] - else: - out = res["symm_weights_buf"] - - if len(res["all_split_sizes"]) > 0: - gathered_tensors = torch.split(out, res["all_split_sizes"]) - weights_out = [t.view(-1, 1) for t in gathered_tensors] - else: - weights_out = [] - - lam_next_tensor = torch.tensor(lam_next_val, dtype=weights[0].dtype, device=device) - - return weights_out, P, lam_next_tensor \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/65_gnn_neighbor_sampling_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/65_gnn_neighbor_sampling_cuda.py deleted file mode 100755 index b27bd0b..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/65_gnn_neighbor_sampling_cuda.py +++ /dev/null @@ -1,514 +0,0 @@ -""" -Strategy: -- Replaced opaque CPU/Python collectives (`dist.all_to_all_single`) and costly sorts with custom CUDA device-side communication. -- We utilize `torch.distributed._symmetric_memory` to allocate UVA buffer pools for NVLink P2P access. -- Instead of packing, sending, and reorganizing data per hop, the query rank writes queries directly into peer memory. - The serving rank samples directly to the requesting rank's final reply buffer via calculated offsets. -- CPU-side deduplication via `np.unique` (which bottlenecks due to syncs and transfers) is fully replaced - with a multi-pass custom CUDA atomic hash table that perfectly preserves the index-order of new elements - at hundreds of GB/s, matching PyTorch's reference semantics identically. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import List, Optional, Tuple -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -struct PCG32 { - uint64_t state; - uint64_t inc; - __device__ uint32_t next() { - uint64_t oldstate = state; - state = oldstate * 6364136223846793005ULL + inc; - uint32_t xorshifted = ((oldstate >> 18u) ^ oldstate) >> 27u; - uint32_t rot = oldstate >> 59u; - return (xorshifted >> rot) | (xorshifted << ((-rot) & 31)); - } - __device__ uint32_t bound(uint32_t range) { - uint32_t x = next(); - uint64_t m = uint64_t(x) * uint64_t(range); - uint32_t l = uint32_t(m); - if (l < range) { - uint32_t t = -range; - if (t >= range) { - t -= range; - if (t >= range) t %= range; - } - while (l < t) { - x = next(); - m = uint64_t(x) * uint64_t(range); - l = uint32_t(m); - } - } - return m >> 32; - } -}; - -__global__ void push_counts_kernel( - const int64_t* __restrict__ query_counts, - const int64_t* __restrict__ ptrs, - int rank, int world_size, int64_t o_mbox_counts) -{ - int j = threadIdx.x; - if (j < world_size) { - int64_t* remote_mbox = (int64_t*)ptrs[j]; - remote_mbox[o_mbox_counts + rank] = query_counts[j]; - } -} - -__global__ void init_offsets_kernel( - int64_t* __restrict__ current_offsets, - const int64_t* __restrict__ ptrs, - int rank, int world_size, int64_t o_mbox_offsets) -{ - int j = threadIdx.x; - if (j < world_size) { - int64_t* remote_mbox_offsets = (int64_t*)ptrs[j]; - current_offsets[j] = remote_mbox_offsets[o_mbox_offsets + rank]; - } -} - -__global__ void push_queries_kernel( - const int64_t* __restrict__ src, - const int32_t* __restrict__ owners, - int64_t S, - const int64_t* __restrict__ ptrs, - int64_t* __restrict__ current_offsets, - int rank, - int64_t o_q_src_rank, int64_t o_q_orig_idx, int64_t o_q_node) -{ - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < S) { - int32_t owner = owners[idx]; - unsigned long long* offset_ptr = (unsigned long long*)¤t_offsets[owner]; - int64_t offset = atomicAdd(offset_ptr, 1ULL); - - int64_t* remote_buf = (int64_t*)ptrs[owner]; - remote_buf[o_q_src_rank + offset] = rank; - remote_buf[o_q_orig_idx + offset] = idx; - remote_buf[o_q_node + offset] = src[idx]; - } -} - -__global__ void process_queries_pass1_kernel( - int64_t total_queries, - const int64_t* __restrict__ symm_buf, - const int64_t* __restrict__ local_adj_row_ptr, - int64_t fanout, bool replace, - const int64_t* __restrict__ ptrs, - int64_t o_q_src_rank, int64_t o_q_orig_idx, int64_t o_q_node, int64_t o_r_counts) -{ - int64_t q = blockIdx.x * blockDim.x + threadIdx.x; - if (q < total_queries) { - int64_t src_rank = symm_buf[o_q_src_rank + q]; - int64_t orig_idx = symm_buf[o_q_orig_idx + q]; - int64_t node = symm_buf[o_q_node + q]; - - int64_t start = local_adj_row_ptr[node]; - int64_t end = local_adj_row_ptr[node+1]; - int64_t degree = end - start; - - int64_t c = 0; - if (degree > 0) { - if (fanout < 0) c = degree; - else if (replace) c = fanout; - else c = (fanout < degree) ? fanout : degree; - } - - int64_t* remote_buf = (int64_t*)ptrs[src_rank]; - remote_buf[o_r_counts + orig_idx] = c; - } -} - -__global__ void process_queries_pass2_kernel( - int64_t total_queries, - const int64_t* __restrict__ symm_buf, - const int64_t* __restrict__ local_adj_row_ptr, - const int64_t* __restrict__ local_adj_col, - int64_t fanout, bool replace, uint64_t seed, - const int64_t* __restrict__ ptrs, - int64_t o_q_src_rank, int64_t o_q_orig_idx, int64_t o_q_node, - int64_t o_r_offsets, int64_t o_r_nodes, int64_t o_r_edges) -{ - int64_t q = blockIdx.x * blockDim.x + threadIdx.x; - if (q < total_queries) { - int64_t src_rank = symm_buf[o_q_src_rank + q]; - int64_t orig_idx = symm_buf[o_q_orig_idx + q]; - int64_t node = symm_buf[o_q_node + q]; - - int64_t start = local_adj_row_ptr[node]; - int64_t end = local_adj_row_ptr[node+1]; - int64_t degree = end - start; - - int64_t c = 0; - if (degree > 0) { - if (fanout < 0) c = degree; - else if (replace) c = fanout; - else c = (fanout < degree) ? fanout : degree; - } - - if (c > 0) { - int64_t* remote_buf = (int64_t*)ptrs[src_rank]; - int64_t offset = remote_buf[o_r_offsets + orig_idx]; - - PCG32 rng; - rng.state = seed + q; - rng.inc = (q << 1) | 1; - rng.next(); - - if (c == degree && !replace) { - for (int64_t i = 0; i < c; ++i) { - remote_buf[o_r_nodes + offset + i] = local_adj_col[start + i]; - remote_buf[o_r_edges + offset + i] = start + i; - } - } else if (replace) { - for (int64_t i = 0; i < c; ++i) { - int64_t r = rng.bound((uint32_t)degree); - remote_buf[o_r_nodes + offset + i] = local_adj_col[start + r]; - remote_buf[o_r_edges + offset + i] = start + r; - } - } else { - int local_sel[256]; - for (int64_t i = 0; i < c; ++i) { - int64_t r; - if (i < 256) { - bool duplicate; - do { - r = rng.bound((uint32_t)degree); - duplicate = false; - for (int64_t j = 0; j < i; ++j) { - if (local_sel[j] == r) { duplicate = true; break; } - } - } while (duplicate); - local_sel[i] = r; - } else { - r = rng.bound((uint32_t)degree); - } - remote_buf[o_r_nodes + offset + i] = local_adj_col[start + r]; - remote_buf[o_r_edges + offset + i] = start + r; - } - } - } - } -} - -__global__ void hash_insert_history_kernel( - const int64_t* __restrict__ history, int64_t history_size, - int64_t* __restrict__ keys, int64_t* __restrict__ values, int64_t table_size) -{ - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < history_size) { - int64_t node = history[idx]; - int64_t slot = (node * 11400714819323198485ULL) % table_size; - while (true) { - int64_t old = atomicCAS((unsigned long long*)&keys[slot], -1ULL, (unsigned long long)node); - if (old == -1 || old == node) { - values[slot] = -2; - break; - } - slot = (slot + 1) % table_size; - } - } -} - -__global__ void hash_insert_new_kernel( - const int64_t* __restrict__ new_nodes, int64_t new_size, - int64_t* __restrict__ keys, int64_t* __restrict__ values, int64_t table_size) -{ - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < new_size) { - int64_t node = new_nodes[idx]; - int64_t slot = (node * 11400714819323198485ULL) % table_size; - while (true) { - int64_t old = atomicCAS((unsigned long long*)&keys[slot], -1ULL, (unsigned long long)node); - if (old == -1 || old == node) { - atomicMin((unsigned long long*)&values[slot], (unsigned long long)idx); - break; - } - slot = (slot + 1) % table_size; - } - } -} - -__global__ void hash_check_kernel( - const int64_t* __restrict__ new_nodes, int64_t new_size, - const int64_t* __restrict__ keys, const int64_t* __restrict__ values, int64_t table_size, - int64_t* __restrict__ mask) -{ - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < new_size) { - int64_t node = new_nodes[idx]; - int64_t slot = (node * 11400714819323198485ULL) % table_size; - while (true) { - if (keys[slot] == node) { - mask[idx] = (values[slot] == idx) ? 1 : 0; - break; - } - slot = (slot + 1) % table_size; - } - } -} - -__global__ void hash_extract_kernel( - const int64_t* __restrict__ new_nodes, int64_t new_size, - const int64_t* __restrict__ mask, const int64_t* __restrict__ mask_sum, - int64_t* __restrict__ out_src) -{ - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < new_size) { - if (mask[idx] == 1) { - out_src[mask_sum[idx] - 1] = new_nodes[idx]; - } - } -} - -void launch_push_counts(torch::Tensor query_counts, torch::Tensor ptrs, int rank, int world_size, int64_t o_mbox_counts) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - push_counts_kernel<<<1, world_size, 0, stream>>>(query_counts.data_ptr(), ptrs.data_ptr(), rank, world_size, o_mbox_counts); -} - -void launch_push_queries(torch::Tensor src, torch::Tensor owners, int64_t S, torch::Tensor ptrs, torch::Tensor current_offsets, int rank, int world_size, int64_t o_mbox_offsets, int64_t o_q_src_rank, int64_t o_q_orig_idx, int64_t o_q_node) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - init_offsets_kernel<<<1, world_size, 0, stream>>>(current_offsets.data_ptr(), ptrs.data_ptr(), rank, world_size, o_mbox_offsets); - int threads = 256; - int blocks = (S + threads - 1) / threads; - if (blocks > 0) { - push_queries_kernel<<>>(src.data_ptr(), owners.data_ptr(), S, ptrs.data_ptr(), current_offsets.data_ptr(), rank, o_q_src_rank, o_q_orig_idx, o_q_node); - } -} - -void launch_pass1(int64_t total_queries, torch::Tensor symm_buf, torch::Tensor local_adj_row_ptr, int64_t fanout, bool replace, torch::Tensor ptrs, int64_t o_q_src_rank, int64_t o_q_orig_idx, int64_t o_q_node, int64_t o_r_counts) { - if (total_queries == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - int blocks = (total_queries + threads - 1) / threads; - process_queries_pass1_kernel<<>>(total_queries, symm_buf.data_ptr(), local_adj_row_ptr.data_ptr(), fanout, replace, ptrs.data_ptr(), o_q_src_rank, o_q_orig_idx, o_q_node, o_r_counts); -} - -void launch_pass2(int64_t total_queries, torch::Tensor symm_buf, torch::Tensor local_adj_row_ptr, torch::Tensor local_adj_col, int64_t fanout, bool replace, uint64_t seed, torch::Tensor ptrs, int64_t o_q_src_rank, int64_t o_q_orig_idx, int64_t o_q_node, int64_t o_r_offsets, int64_t o_r_nodes, int64_t o_r_edges) { - if (total_queries == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - int blocks = (total_queries + threads - 1) / threads; - process_queries_pass2_kernel<<>>(total_queries, symm_buf.data_ptr(), local_adj_row_ptr.data_ptr(), local_adj_col.data_ptr(), fanout, replace, seed, ptrs.data_ptr(), o_q_src_rank, o_q_orig_idx, o_q_node, o_r_offsets, o_r_nodes, o_r_edges); -} - -void launch_dedup_history(torch::Tensor history, torch::Tensor keys, torch::Tensor values, int64_t table_size) { - if (history.numel() == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - int blocks = (history.numel() + threads - 1) / threads; - hash_insert_history_kernel<<>>(history.data_ptr(), history.numel(), keys.data_ptr(), values.data_ptr(), table_size); -} - -void launch_dedup_new(torch::Tensor new_nodes, torch::Tensor keys, torch::Tensor values, int64_t table_size) { - if (new_nodes.numel() == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - int blocks = (new_nodes.numel() + threads - 1) / threads; - hash_insert_new_kernel<<>>(new_nodes.data_ptr(), new_nodes.numel(), keys.data_ptr(), values.data_ptr(), table_size); -} - -void launch_dedup_check(torch::Tensor new_nodes, torch::Tensor keys, torch::Tensor values, int64_t table_size, torch::Tensor mask) { - if (new_nodes.numel() == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - int blocks = (new_nodes.numel() + threads - 1) / threads; - hash_check_kernel<<>>(new_nodes.data_ptr(), new_nodes.numel(), keys.data_ptr(), values.data_ptr(), table_size, mask.data_ptr()); -} - -void launch_dedup_extract(torch::Tensor new_nodes, torch::Tensor mask, torch::Tensor mask_sum, torch::Tensor out_src) { - if (new_nodes.numel() == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - int blocks = (new_nodes.numel() + threads - 1) / threads; - hash_extract_kernel<<>>(new_nodes.data_ptr(), new_nodes.numel(), mask.data_ptr(), mask_sum.data_ptr(), out_src.data_ptr()); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push_counts", &launch_push_counts); - m.def("launch_push_queries", &launch_push_queries); - m.def("launch_pass1", &launch_pass1); - m.def("launch_pass2", &launch_pass2); - m.def("launch_dedup_history", &launch_dedup_history); - m.def("launch_dedup_new", &launch_dedup_new); - m.def("launch_dedup_check", &launch_dedup_check); - m.def("launch_dedup_extract", &launch_dedup_extract); -} -''' - -_ext = None -def get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gnn_symm_sample_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def get_symm_state(world_size: int, device: torch.device, group: dist.ProcessGroup): - key = (world_size, device) - if key in _symm_cache: - return _symm_cache[key] - - MAX_Q = 10_000_000 - MAX_S = 10_000_000 - MAX_R = 50_000_000 - - O_MBOX_COUNTS = 0 - O_MBOX_OFFSETS = O_MBOX_COUNTS + world_size - O_Q_SRC_RANK = O_MBOX_OFFSETS + world_size - O_Q_ORIG_IDX = O_Q_SRC_RANK + MAX_Q - O_Q_NODE = O_Q_ORIG_IDX + MAX_Q - O_R_COUNTS = O_Q_NODE + MAX_Q - O_R_OFFSETS = O_R_COUNTS + MAX_S - O_R_NODES = O_R_OFFSETS + MAX_S - O_R_EDGES = O_R_NODES + MAX_R - TOTAL_SYMM_SIZE = O_R_EDGES + MAX_R - - buf = symm_mem.empty(TOTAL_SYMM_SIZE, dtype=torch.int64, device=device) - hdl = symm_mem.rendezvous(buf, group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - offsets = { - 'O_MBOX_COUNTS': O_MBOX_COUNTS, - 'O_MBOX_OFFSETS': O_MBOX_OFFSETS, - 'O_Q_SRC_RANK': O_Q_SRC_RANK, - 'O_Q_ORIG_IDX': O_Q_ORIG_IDX, - 'O_Q_NODE': O_Q_NODE, - 'O_R_COUNTS': O_R_COUNTS, - 'O_R_OFFSETS': O_R_OFFSETS, - 'O_R_NODES': O_R_NODES, - 'O_R_EDGES': O_R_EDGES, - } - - _symm_cache[key] = (buf, hdl, ptrs_tensor, offsets) - return _symm_cache[key] - - -def _relabel_neighborhood(node, dst_with_dupl, node_with_dupl): - if node_with_dupl.numel() == 0: - return node.new_empty(0), node.new_empty(0) - assoc = torch.full((int(node.max().item()) + 1,), -1, dtype=torch.long, device=node.device) - assoc[node] = torch.arange(node.numel(), device=node.device) - row = assoc[node_with_dupl] - col = assoc[dst_with_dupl] - return row, col - - -@torch.no_grad() -def solution( - seed_nodes: torch.Tensor, - fanouts: List[int], - local_adj_row_ptr: torch.Tensor, - local_adj_col: torch.Tensor, - node_to_rank: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, - replace: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - device = seed_nodes.device - - ext = get_ext() - symm_buf, hdl, ptrs_tensor, o = get_symm_state(world_size, device, group) - current_offsets = torch.empty(world_size, dtype=torch.long, device=device) - - seed = seed_nodes.to(dtype=torch.long, device=device) - src = seed.clone() - node = src.clone() - node_with_dupl = [seed.new_empty(0)] - dst_with_dupl = [seed.new_empty(0)] - edge = [seed.new_empty(0)] - - for fanout in fanouts: - if src.numel() == 0: - break - - S = src.numel() - owners = node_to_rank[src].to(torch.int32) - - # 1. P2P write query counts to peers - query_counts = torch.bincount(owners, minlength=world_size) - ext.launch_push_counts(query_counts, ptrs_tensor, rank, world_size, o['O_MBOX_COUNTS']) - dist.barrier(group) - - # 2. Local read mailbox, prefix sum to find placement offsets - my_recv_counts = symm_buf[o['O_MBOX_COUNTS'] : o['O_MBOX_COUNTS'] + world_size] - my_query_offsets = torch.cat([torch.zeros(1, dtype=torch.long, device=device), my_recv_counts.cumsum(0)[:-1]]) - symm_buf[o['O_MBOX_OFFSETS'] : o['O_MBOX_OFFSETS'] + world_size] = my_query_offsets - total_queries = int(my_recv_counts.sum().item()) - dist.barrier(group) - - # 3. P2P Push Queries into destination bins - ext.launch_push_queries( - src, owners, S, ptrs_tensor, current_offsets, rank, world_size, - o['O_MBOX_OFFSETS'], o['O_Q_SRC_RANK'], o['O_Q_ORIG_IDX'], o['O_Q_NODE'] - ) - dist.barrier(group) - - # 4. Pass 1: Compute reply counts, P2P write sizes directly to origin's reply array buffer - ext.launch_pass1( - total_queries, symm_buf, local_adj_row_ptr, int(fanout), replace, ptrs_tensor, - o['O_Q_SRC_RANK'], o['O_Q_ORIG_IDX'], o['O_Q_NODE'], o['O_R_COUNTS'] - ) - dist.barrier(group) - - # 5. Local prefix sum for reply mapping - my_reply_counts = symm_buf[o['O_R_COUNTS'] : o['O_R_COUNTS'] + S] - my_reply_offsets = torch.cat([torch.zeros(1, dtype=torch.long, device=device), my_reply_counts.cumsum(0)[:-1]]) - total_replies = int((my_reply_counts.sum() if S > 0 else 0).item()) - symm_buf[o['O_R_OFFSETS'] : o['O_R_OFFSETS'] + S] = my_reply_offsets - dist.barrier(group) - - # 6. Pass 2: Actually sample edges and push into exactly-mapped destination offsets - pass2_seed = int(torch.randint(0, 2**30, (1,)).item()) - ext.launch_pass2( - total_queries, symm_buf, local_adj_row_ptr, local_adj_col, int(fanout), replace, pass2_seed, ptrs_tensor, - o['O_Q_SRC_RANK'], o['O_Q_ORIG_IDX'], o['O_Q_NODE'], o['O_R_OFFSETS'], o['O_R_NODES'], o['O_R_EDGES'] - ) - dist.barrier(group) - - # 7. Collect results - out_node = symm_buf[o['O_R_NODES'] : o['O_R_NODES'] + total_replies] - out_edge = symm_buf[o['O_R_EDGES'] : o['O_R_EDGES'] + total_replies] - out_dst = torch.repeat_interleave(src, my_reply_counts) - - if out_node.numel() == 0: - break - - # 8. Hash-Based Device Deduplication vs np.unique equivalent - table_size = (node.numel() + out_node.numel()) * 2 + 1024 - keys = torch.full((table_size,), -1, dtype=torch.long, device=device) - values = torch.full((table_size,), 2**62, dtype=torch.long, device=device) - - ext.launch_dedup_history(node, keys, values, table_size) - ext.launch_dedup_new(out_node, keys, values, table_size) - - mask = torch.zeros(out_node.numel(), dtype=torch.long, device=device) - ext.launch_dedup_check(out_node, keys, values, table_size, mask) - - mask_sum = mask.cumsum(dim=0) - out_src = torch.empty(out_node.numel(), dtype=torch.long, device=device) - ext.launch_dedup_extract(out_node, mask, mask_sum, out_src) - - total_new = int(mask_sum[-1].item()) if mask.numel() > 0 else 0 - src = out_src[:total_new] - node = torch.cat([node, src]) - - node_with_dupl.append(out_node) - dst_with_dupl.append(out_dst) - edge.append(out_edge) - - node_dupl = torch.cat(node_with_dupl) - dst_dupl = torch.cat(dst_with_dupl) - row, col = _relabel_neighborhood(node, dst_dupl, node_dupl) - return node, row, col, torch.cat(edge) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/66_gnn_feature_exchange_all2all_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/66_gnn_feature_exchange_all2all_cuda.py deleted file mode 100755 index 7438fed..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/66_gnn_feature_exchange_all2all_cuda.py +++ /dev/null @@ -1,289 +0,0 @@ -""" -Optimized GraphBolt feature exchange using symmetric memory and custom CUDA P2P. - -Strategy: -- Persistent Symmetric Memory: We maintain persistent symmetric memory buffers for metadata - (`meta_buf`) and the output features (`out_buf`). Sizes are dynamically managed, communicating - reallocation requests via device-side `meta_buf` flags to minimize CPU/NCCL overhead. -- Fused Gather & Push: Instead of gathering features locally into an intermediate buffer and - running an all-to-all collective, a custom CUDA kernel directly gathers rows from `local_features` - using `seed_inverse_ids` and pushes them to the correct remote `out_buf`s via NVLink P2P stores. -- Compute-Comm Overlap: P2P stores are issued concurrently with index calculations and memory loads - from local HBM, fully utilizing memory/NVLink parallelism without blocking on bulk collectives. -""" - -from typing import List, Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void prefetch_and_check_kernel( - const int64_t* const* meta_ptrs, - int n, - int rank, - int64_t* local_dst_offsets, - int* global_needs_realloc -) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - int needs = 0; - for (int r = 0; r < n; ++r) { - if (meta_ptrs[r][n] == 1) { - needs = 1; - } - } - *global_needs_realloc = needs; - - // Prefetch base offsets for chunks we will push to remote peers - for (int j = 0; j < n; ++j) { - int dst = (j + rank) % n; - int k = (n - dst + rank) % n; - local_dst_offsets[j] = meta_ptrs[dst][k]; - } - } -} - -void launch_prefetch_and_check( - torch::Tensor meta_ptrs, - int n, - int rank, - torch::Tensor local_dst_offsets, - torch::Tensor global_needs_realloc -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - prefetch_and_check_kernel<<<1, 1, 0, stream>>>( - (const int64_t* const*)meta_ptrs.data_ptr(), - n, rank, - local_dst_offsets.data_ptr(), - global_needs_realloc.data_ptr() - ); -} - -template -__global__ void fused_gather_push_kernel( - const T* __restrict__ local_features, - const int64_t* __restrict__ seed_inverse_ids, - const int64_t* __restrict__ src_offsets, - const int64_t* __restrict__ local_dst_offsets, - T* const* out_ptrs, - int rank, - int n, - int64_t H_vec, - int64_t total_elements_vec -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int64_t i = idx; i < total_elements_vec; i += gridDim.x * blockDim.x) { - int64_t row = i / H_vec; - int64_t col = i % H_vec; - - int j = 0; - #pragma unroll - for (int step = 0; step < 8; ++step) { - if (step < n && row >= src_offsets[step+1]) { - j = step + 1; - } - } - - int dst = (j + rank) % n; - int64_t offset_in_chunk = row - src_offsets[j]; - int64_t dst_base = local_dst_offsets[j]; - int64_t out_row = dst_base + offset_in_chunk; - - int64_t src_row = seed_inverse_ids[row]; - T val = local_features[src_row * H_vec + col]; - - // P2P direct write to the destination rank's symmetric buffer - out_ptrs[dst][out_row * H_vec + col] = val; - } -} - -void launch_fused_gather_push( - torch::Tensor local_features, - torch::Tensor seed_inverse_ids, - torch::Tensor src_offsets, - torch::Tensor local_dst_offsets, - torch::Tensor out_ptrs, - int rank, - int n -) { - int64_t N_send = seed_inverse_ids.size(0); - int64_t H = local_features.size(1); - - if (N_send == 0) return; - - int64_t element_size = local_features.element_size(); - int64_t H_bytes = H * element_size; - TORCH_CHECK(H_bytes % 2 == 0, "Feature row size in bytes must be multiple of 2"); - int64_t H_units = H_bytes / 2; - - // Choose optimal vectorized load/store alignment based on inner dimension - int vec_size = 1; - if (H_units % 8 == 0) vec_size = 8; - else if (H_units % 4 == 0) vec_size = 4; - else if (H_units % 2 == 0) vec_size = 2; - - int64_t H_vec = H_units / vec_size; - int64_t total_elements_vec = N_send * H_vec; - - int threads = 256; - int blocks = (total_elements_vec + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const int64_t* seed_inverse_ids_ptr = seed_inverse_ids.data_ptr(); - const int64_t* src_offsets_ptr = src_offsets.data_ptr(); - const int64_t* local_dst_offsets_ptr = local_dst_offsets.data_ptr(); - const void* out_ptrs_ptr = out_ptrs.data_ptr(); - const void* local_features_ptr = local_features.data_ptr(); - - if (vec_size == 8) { - fused_gather_push_kernel<<>>( - (const uint4*)local_features_ptr, - seed_inverse_ids_ptr, - src_offsets_ptr, - local_dst_offsets_ptr, - (uint4* const*)out_ptrs_ptr, - rank, n, H_vec, total_elements_vec - ); - } else if (vec_size == 4) { - fused_gather_push_kernel<<>>( - (const uint2*)local_features_ptr, - seed_inverse_ids_ptr, - src_offsets_ptr, - local_dst_offsets_ptr, - (uint2* const*)out_ptrs_ptr, - rank, n, H_vec, total_elements_vec - ); - } else if (vec_size == 2) { - fused_gather_push_kernel<<>>( - (const uint32_t*)local_features_ptr, - seed_inverse_ids_ptr, - src_offsets_ptr, - local_dst_offsets_ptr, - (uint32_t* const*)out_ptrs_ptr, - rank, n, H_vec, total_elements_vec - ); - } else { - fused_gather_push_kernel<<>>( - (const uint16_t*)local_features_ptr, - seed_inverse_ids_ptr, - src_offsets_ptr, - local_dst_offsets_ptr, - (uint16_t* const*)out_ptrs_ptr, - rank, n, H_vec, total_elements_vec - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_prefetch_and_check", &launch_prefetch_and_check, "Prefetch metadata and check reallocation"); - m.def("launch_fused_gather_push", &launch_fused_gather_push, "Fused gather and push over symmetric memory"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gnn_feat_exchange_ext", CUDA_SRC) - return _ext - -class SymmMemState: - def __init__(self, n: int, device: torch.device): - self.n = n - self.device = device - self.meta_buf = symm_mem.empty(n + 1, dtype=torch.int64, device=device) - self.meta_hdl = symm_mem.rendezvous(self.meta_buf) - self.meta_ptrs = torch.tensor(self.meta_hdl.buffer_ptrs, dtype=torch.int64, device=device) - self.out_buf = None - self.out_hdl = None - self.out_ptrs = None - self.my_max_rows = 0 - -_state = None - -@torch.no_grad() -def solution( - local_features: torch.Tensor, - seed_inverse_ids: torch.Tensor, - counts_sent: List[int], - counts_received: List[int], - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - n = dist.get_world_size(group) - rank = dist.get_rank(group) - device = local_features.device - - global _state - if _state is None: - _state = SymmMemState(n, device) - _get_ext() - - my_rows = sum(counts_sent) - my_needs_realloc = 1 if (my_rows > _state.my_max_rows or _state.out_buf is None) else 0 - - dst_offsets = [0] * n - curr = 0 - for i in range(n): - dst_offsets[i] = curr - curr += counts_sent[i] - - meta_local = dst_offsets + [my_needs_realloc] - meta_cpu = torch.tensor(meta_local, dtype=torch.int64, pin_memory=True) - _state.meta_buf.copy_(meta_cpu, non_blocking=True) - - # Wait for all ranks to expose their destination offsets and dynamic allocation needs - _state.meta_hdl.barrier(channel=0) - - global_needs_realloc_dev = torch.empty(1, dtype=torch.int32, device=device) - local_dst_offsets = torch.empty(n, dtype=torch.int64, device=device) - - _get_ext().launch_prefetch_and_check( - _state.meta_ptrs, n, rank, local_dst_offsets, global_needs_realloc_dev - ) - - # Conditionally expand symmetric output buffer capacity without extra NCCL blocking logic - if global_needs_realloc_dev.item() == 1: - new_max = max(my_rows, int(_state.my_max_rows * 1.5)) - if new_max < 1024: - new_max = 1024 - if _state.out_buf is None or new_max > _state.my_max_rows: - _state.out_buf = symm_mem.empty((new_max, local_features.size(1)), dtype=local_features.dtype, device=device) - _state.my_max_rows = new_max - - _state.out_hdl = symm_mem.rendezvous(_state.out_buf, group=group) - _state.out_ptrs = torch.tensor(_state.out_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - src_offsets_list = [0] * (n + 1) - curr = 0 - for i in range(n): - src_offsets_list[i] = curr - curr += counts_received[i] - src_offsets_list[n] = curr - src_offsets = torch.tensor(src_offsets_list, dtype=torch.int64, device=device) - - _get_ext().launch_fused_gather_push( - local_features, - seed_inverse_ids.contiguous(), - src_offsets, - local_dst_offsets, - _state.out_ptrs, - rank, - n - ) - - # Wait for all remote pushes to our local symmetric out_buf to complete - _state.out_hdl.barrier(channel=0) - - # Return cloned sub-tensor matching GraphBolt output/mutability semantics - return _state.out_buf[:my_rows, :].clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/67_gnn_feature_exchange_all2all_backward_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/67_gnn_feature_exchange_all2all_backward_cuda.py deleted file mode 100755 index 33b7bff..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/67_gnn_feature_exchange_all2all_backward_cuda.py +++ /dev/null @@ -1,197 +0,0 @@ -from typing import List, Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void pull_scatter_add_bf16_kernel( - const int64_t* __restrict__ ptrs, - const int32_t* __restrict__ recv_offsets, - const int32_t* __restrict__ remote_offsets, - const int32_t* __restrict__ peers, - const int64_t* __restrict__ seed_inverse_ids, - __nv_bfloat16* __restrict__ grad_input, - int total_recv, - int H, - int world_size -) { - int idx = blockIdx.x * blockDim.y + threadIdx.y; - if (idx >= total_recv) return; - - int chunk_i = 0; - for (int i = 1; i < world_size; i++) { - if (idx >= recv_offsets[i]) { - chunk_i = i; - } - } - - int offset_in_chunk = idx - recv_offsets[chunk_i]; - int remote_idx = remote_offsets[chunk_i] + offset_in_chunk; - int peer = peers[chunk_i]; - - // Establish mapping to symmetric peer memory - const __nv_bfloat16* remote_row = (const __nv_bfloat16*)ptrs[peer] + remote_idx * H; - int dst_row = seed_inverse_ids[idx]; - __nv_bfloat16* dst_ptr = grad_input + dst_row * H; - - // Vectorized path for aligned even-dimension counts (doubles throughput) - if (H % 2 == 0) { - int h = threadIdx.x * 2; - int stride = blockDim.x * 2; - for (; h < H; h += stride) { - __nv_bfloat162 val = *(__nv_bfloat162*)(remote_row + h); - atomicAdd((__nv_bfloat162*)(dst_ptr + h), val); - } - } else { - int h = threadIdx.x; - int stride = blockDim.x; - for (; h < H; h += stride) { - atomicAdd(dst_ptr + h, remote_row[h]); - } - } -} - -void launch_pull_scatter_add( - torch::Tensor ptrs_tensor, - torch::Tensor recv_offsets, - torch::Tensor remote_offsets, - torch::Tensor peers, - torch::Tensor seed_inverse_ids, - torch::Tensor grad_input, - int H -) { - int total_recv = seed_inverse_ids.size(0); - int world_size = ptrs_tensor.size(0); - - dim3 block(32, 8); - dim3 grid((total_recv + block.y - 1) / block.y); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - pull_scatter_add_bf16_kernel<<>>( - ptrs_tensor.data_ptr(), - recv_offsets.data_ptr(), - remote_offsets.data_ptr(), - peers.data_ptr(), - seed_inverse_ids.data_ptr(), - reinterpret_cast<__nv_bfloat16*>(grad_input.data_ptr()), - total_recv, - H, - world_size - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_pull_scatter_add", &launch_pull_scatter_add, "Pull scatter add kernel"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gnn_pull_scatter_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(max_elements, H, dtype, device, group): - key = (H, dtype, device) - if key in _symm_cache: - c = _symm_cache[key] - if c['max_elements'] >= max_elements: - return c['buf'], c['hdl'], c['ptrs'] - - # Pad to help resist frequent re-allocations as dynamic batches shift sizes - alloc_elements = max(max_elements, 1024) - alloc_elements = (alloc_elements + 1023) // 1024 * 1024 - - buf = symm_mem.empty((alloc_elements, H), dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - - _symm_cache[key] = { - 'buf': buf, - 'hdl': hdl, - 'ptrs': ptrs, - 'max_elements': alloc_elements - } - return buf, hdl, ptrs - -@torch.no_grad() -def solution( - grad_output: torch.Tensor, - seed_inverse_ids: torch.Tensor, - seed_size: int, - counts_sent: List[int], - counts_received: List[int], - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - device = grad_output.device - dtype = grad_output.dtype - H = grad_output.shape[1] - - if rank == 0: - _get_ext() - dist.barrier(group=group) - - # Gather structural dimensions globally across all ranks - local_counts = torch.tensor(counts_sent + counts_received, dtype=torch.int32, device=device) - all_counts = torch.empty((world_size, len(local_counts)), dtype=torch.int32, device=device) - dist.all_gather_into_tensor(all_counts, local_counts, group=group) - - counts_sent_t = all_counts[:, :world_size] - max_elements = counts_sent_t.sum(dim=1).max().item() - local_sent_size = counts_sent_t[rank].sum().item() - - buf, hdl, ptrs = _get_symm_state(max_elements, H, dtype, device, group) - - # Wait strictly for reads over the last iteration to finish gracefully - hdl.barrier(channel=1) - - # Expose backwards payload into UVA symmetric region - if local_sent_size > 0: - buf[:local_sent_size].copy_(grad_output) - - # Synchronization before UVA remote pulls map onto memory mapping - hdl.barrier(channel=0) - - # Pure GPU-native prefix sums calculations to derive remote bounds explicitly - counts_received_t = all_counts[:, world_size:] - recv_offsets_local = torch.zeros(world_size + 1, dtype=torch.int32, device=device) - recv_offsets_local[1:] = torch.cumsum(counts_received_t[rank], dim=0) - - sent_offsets = torch.zeros((world_size, world_size + 1), dtype=torch.int32, device=device) - sent_offsets[:, 1:] = torch.cumsum(counts_sent_t, dim=1) - - i = torch.arange(world_size, dtype=torch.int32, device=device) - peers = (rank + i) % world_size - ks = (world_size - i) % world_size - remote_chunk_offsets = sent_offsets[peers, ks] - - grad_input = torch.zeros((seed_size, H), dtype=dtype, device=device) - total_recv = seed_inverse_ids.size(0) - - if total_recv > 0: - _get_ext().launch_pull_scatter_add( - ptrs, - recv_offsets_local, - remote_chunk_offsets, - peers, - seed_inverse_ids, - grad_input, - H - ) - - return grad_input \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/68_gnn_sparse_embedding_all2all_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/68_gnn_sparse_embedding_all2all_cuda.py deleted file mode 100755 index 36d9045..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/68_gnn_sparse_embedding_all2all_cuda.py +++ /dev/null @@ -1,264 +0,0 @@ -# Strategy: -# - Use PyTorch's native `bincount` and `argsort` for fast local partition calculation and tensor packing. -# - Overlap the metadata exchange (`all_gather_into_tensor` of `send_splits`) on a CUDA side-stream to hide communication latency behind the sorting/packing compute. -# - Maintain automatically cached and dynamically resized `torch.distributed._symmetric_memory` buffers, eliminating `rendezvous` overhead after warmup. -# - Launch custom unified CUDA kernels to push the packed indices and multi-dimensional values directly into peer receive buffers via symmetric memory UVA pointers, yielding maximum NVLink coalesced bandwidth. -# - Use device-side execution barriers (`hdl.barrier()`) for zero-overhead synchronization before and after the direct P2P writes. - -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__global__ void push_idx_kernel( - const T_idx* __restrict__ send_idx, - const int64_t* __restrict__ ptrs_idx, - const int64_t* __restrict__ send_offsets, - const int64_t* __restrict__ peer_recv_offsets, - int world_size -) { - int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; - int64_t total_elements = send_offsets[world_size]; - if (tid >= total_elements) return; - - int peer = 0; - while (peer < world_size && tid >= send_offsets[peer + 1]) { - peer++; - } - - int64_t local_offset = tid - send_offsets[peer]; - int64_t remote_offset = peer_recv_offsets[peer] + local_offset; - - T_idx* remote_idx = reinterpret_cast(ptrs_idx[peer]); - remote_idx[remote_offset] = send_idx[tid]; -} - -template -__global__ void push_val_kernel( - const T_val* __restrict__ send_val, - const int64_t* __restrict__ ptrs_val, - const int64_t* __restrict__ send_offsets, - const int64_t* __restrict__ peer_recv_offsets, - int64_t D, - int world_size -) { - int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; - int64_t total_elements = send_offsets[world_size]; - int64_t total_val_elements = total_elements * D; - if (tid >= total_val_elements) return; - - int64_t item_idx = tid / D; - int64_t d_idx = tid % D; - - int peer = 0; - while (peer < world_size && item_idx >= send_offsets[peer + 1]) { - peer++; - } - - int64_t local_offset = item_idx - send_offsets[peer]; - int64_t remote_item_offset = peer_recv_offsets[peer] + local_offset; - - int64_t remote_flat_offset = remote_item_offset * D + d_idx; - - T_val* remote_val = reinterpret_cast(ptrs_val[peer]); - remote_val[remote_flat_offset] = send_val[tid]; -} - -void launch_push( - torch::Tensor send_idx, - torch::Tensor send_val, - torch::Tensor ptrs_idx, - torch::Tensor ptrs_val, - torch::Tensor send_offsets, - torch::Tensor peer_recv_offsets, - int world_size -) { - int64_t total_elements = send_idx.numel(); - if (total_elements == 0) return; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - int threads = 256; - int blocks_idx = (total_elements + threads - 1) / threads; - - if (send_idx.scalar_type() == torch::kInt64) { - push_idx_kernel<<>>( - send_idx.data_ptr(), ptrs_idx.data_ptr(), - send_offsets.data_ptr(), peer_recv_offsets.data_ptr(), world_size); - } else if (send_idx.scalar_type() == torch::kInt32) { - push_idx_kernel<<>>( - send_idx.data_ptr(), ptrs_idx.data_ptr(), - send_offsets.data_ptr(), peer_recv_offsets.data_ptr(), world_size); - } else { - TORCH_CHECK(false, "Unsupported dtype for idx"); - } - - int64_t total_val_elements = send_val.numel(); - int64_t D = total_val_elements / total_elements; - int blocks_val = (total_val_elements + threads - 1) / threads; - - int elem_size = send_val.element_size(); - if (elem_size == 2) { - push_val_kernel<<>>( - reinterpret_cast(send_val.data_ptr()), - ptrs_val.data_ptr(), send_offsets.data_ptr(), - peer_recv_offsets.data_ptr(), D, world_size); - } else if (elem_size == 4) { - push_val_kernel<<>>( - reinterpret_cast(send_val.data_ptr()), - ptrs_val.data_ptr(), send_offsets.data_ptr(), - peer_recv_offsets.data_ptr(), D, world_size); - } else if (elem_size == 8) { - push_val_kernel<<>>( - reinterpret_cast(send_val.data_ptr()), - ptrs_val.data_ptr(), send_offsets.data_ptr(), - peer_recv_offsets.data_ptr(), D, world_size); - } else if (elem_size == 1) { - push_val_kernel<<>>( - reinterpret_cast(send_val.data_ptr()), - ptrs_val.data_ptr(), send_offsets.data_ptr(), - peer_recv_offsets.data_ptr(), D, world_size); - } else { - TORCH_CHECK(false, "Unsupported element size for values"); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push", &launch_push, "UVA Custom Push Kernel for AllToAll"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("dgl_sparse_push_uva_ext", CUDA_SRC) - return _ext - -_current_capacities = None -_current_D = None -_symm_cache = {} - -def get_symm_buffers(recv_counts: torch.Tensor, value_shape_tail: tuple, dtype_idx: torch.dtype, dtype_val: torch.dtype, device: torch.device, group: dist.ProcessGroup): - global _current_capacities, _current_D, _symm_cache - - world_size = dist.get_world_size(group) - D = 1 - for d in value_shape_tail: - D *= d - - if _current_capacities is None: - _current_capacities = torch.zeros(world_size, dtype=torch.long, device=device) - _symm_cache = {} - _current_D = D - - needs_realloc_tensor = (recv_counts > _current_capacities).any() - needs_realloc = needs_realloc_tensor.item() or (D != _current_D) - - if needs_realloc: - new_caps = torch.max(_current_capacities, (recv_counts.float() * 1.2).long()) - new_caps = torch.max(new_caps, torch.tensor(1024, dtype=torch.long, device=device)) - - _current_capacities.copy_(new_caps) - _current_D = D - - my_cap = _current_capacities[dist.get_rank(group)].item() - - buf_idx = symm_mem.empty(my_cap, dtype=dtype_idx, device=device) - hdl_idx = symm_mem.rendezvous(buf_idx, group) - - buf_val = symm_mem.empty(my_cap * D, dtype=dtype_val, device=device) - hdl_val = symm_mem.rendezvous(buf_val, group) - - _symm_cache['idx'] = (buf_idx, hdl_idx, torch.tensor(hdl_idx.buffer_ptrs, dtype=torch.int64, device=device)) - _symm_cache['val'] = (buf_val, hdl_val, torch.tensor(hdl_val.buffer_ptrs, dtype=torch.int64, device=device)) - - return _symm_cache['idx'], _symm_cache['val'] - -@torch.no_grad() -def solution( - idx: torch.Tensor, - value: torch.Tensor, - num_nodes: int, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - if world_size == 1: - return idx, value - - rank = dist.get_rank(group) - - # Pre-compile/load kernel cache, sync globally to ensure it's loaded securely - if rank == 0: - _get_ext() - dist.barrier(group=group) - _get_ext() - - # Calculate partitioned sizes by target bucket - owner = (idx % world_size).long() - send_splits = torch.bincount(owner, minlength=world_size) - - # Allocate a side stream for communication latency-hiding - gather_stream = torch.cuda.Stream() - gather_stream.wait_stream(torch.cuda.current_stream()) - - all_send_splits_flat = torch.empty(world_size * world_size, dtype=torch.long, device=idx.device) - - # Kick off metadata all-gather asynchronously - with torch.cuda.stream(gather_stream): - dist.all_gather_into_tensor(all_send_splits_flat, send_splits, group=group) - - # Simultaneously locally pack index/value via sort while peers metadata propagates - perm = torch.argsort(owner, stable=True) - send_idx = idx[perm] - send_value = value[perm] - - # Wait for completion of parallel all-gather - torch.cuda.current_stream().wait_stream(gather_stream) - - all_send_splits = all_send_splits_flat.view(world_size, world_size) - recv_counts = all_send_splits.sum(dim=0) - my_recv_count = recv_counts[rank].item() - - # Rendezvous fast-path via symm_mem cache bounds - idx_res, val_res = get_symm_buffers(recv_counts, value.shape[1:], idx.dtype, value.dtype, idx.device, group) - buf_idx, hdl_idx, ptrs_idx = idx_res - buf_val, hdl_val, ptrs_val = val_res - - peer_recv_offsets = all_send_splits[:rank, :].sum(dim=0) - send_offsets = torch.empty(world_size + 1, dtype=torch.long, device=idx.device) - send_offsets[0] = 0 - torch.cumsum(send_splits, dim=0, out=send_offsets[1:]) - - # Device-side sync: wait for previous step's peer reads to conclude before overriding buffers - hdl_idx.barrier(channel=0) - hdl_val.barrier(channel=0) - - # Unified custom push logic - coalesced writes target device-side remote pools natively - _get_ext().launch_push( - send_idx, send_value, ptrs_idx, ptrs_val, - send_offsets, peer_recv_offsets, world_size - ) - - # Device-side sync: enforce flush and wait for incoming peer pushes to land - hdl_idx.barrier(channel=1) - hdl_val.barrier(channel=1) - - D = 1 - for d in value.shape[1:]: - D *= d - - out_idx = buf_idx[:my_recv_count].clone() - out_val = buf_val[:my_recv_count * D].view(my_recv_count, *value.shape[1:]).clone() - - return out_idx, out_val \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/69_gnn_sparse_feature_fetch_projection_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/69_gnn_sparse_feature_fetch_projection_cuda.py deleted file mode 100755 index 9eb298a..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/69_gnn_sparse_feature_fetch_projection_cuda.py +++ /dev/null @@ -1,255 +0,0 @@ -""" -Strategy: -1. Eliminate `all_to_all` and `argsort` communication overhead by caching embedding shards in `symm_mem` and leveraging NVLink UVA (peer direct memory access). -2. Fetch queried embeddings directly from remote memory (using a custom CUDA read kernel) natively into the correct query order—bypassing sort/unsort completely. -3. Maximize compute-communication overlap via double-buffering. Queries are chunked: while the current chunk is being projected (GEMM) via Tensor Cores on the default stream, the next chunk's embeddings are actively fetched over NVLink on a background stream. -""" - -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void uva_fetch_kernel( - const int64_t* __restrict__ queries, - const uint64_t* __restrict__ shard_ptrs, - __nv_bfloat16* __restrict__ out, - int64_t num_queries, - int64_t shard_size, - int embed_dim, - int world_size -) { - int64_t q_idx = (int64_t)blockIdx.x * blockDim.y + threadIdx.y; - if (q_idx >= num_queries) return; - - int64_t node_id = queries[q_idx]; - int owner = node_id / shard_size; - if (owner >= world_size) owner = world_size - 1; - int64_t local_id = node_id - owner * shard_size; - - const __nv_bfloat16* src = reinterpret_cast(shard_ptrs[owner]) + local_id * embed_dim; - __nv_bfloat16* dst = out + q_idx * embed_dim; - - int d = threadIdx.x; - int stride = blockDim.x; - - if (embed_dim % 8 == 0) { - int vecs = embed_dim / 8; - const float4* src_v = reinterpret_cast(src); - float4* dst_v = reinterpret_cast(dst); - for (int i = d; i < vecs; i += stride) { - dst_v[i] = src_v[i]; - } - } else if (embed_dim % 4 == 0) { - int vecs = embed_dim / 4; - const float2* src_v = reinterpret_cast(src); - float2* dst_v = reinterpret_cast(dst); - for (int i = d; i < vecs; i += stride) { - dst_v[i] = src_v[i]; - } - } else if (embed_dim % 2 == 0) { - int vecs = embed_dim / 2; - const float* src_v = reinterpret_cast(src); - float* dst_v = reinterpret_cast(dst); - for (int i = d; i < vecs; i += stride) { - dst_v[i] = src_v[i]; - } - } else { - for (int i = d; i < embed_dim; i += stride) { - dst[i] = src[i]; - } - } -} - -void launch_uva_fetch( - torch::Tensor queries, - torch::Tensor shard_ptrs, - torch::Tensor out, - int64_t shard_size, - int world_size -) { - int64_t num_queries = queries.size(0); - if (num_queries == 0) return; - - int embed_dim = out.size(1); - - int vec_size = 1; - if (embed_dim % 8 == 0) vec_size = 8; - else if (embed_dim % 4 == 0) vec_size = 4; - else if (embed_dim % 2 == 0) vec_size = 2; - - int vecs = embed_dim / vec_size; - int tx = vecs; - if (tx > 32) tx = 32; - int ty = 256 / tx; - if (ty == 0) ty = 1; - - dim3 threads(tx, ty); - dim3 blocks((num_queries + ty - 1) / ty); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - uva_fetch_kernel<<>>( - queries.data_ptr(), - reinterpret_cast(shard_ptrs.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), - num_queries, - shard_size, - embed_dim, - world_size - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_uva_fetch", &launch_uva_fetch, "UVA fetch of remote embeddings"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gnn_sparse_fetch_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(shape, dtype, device, group): - key = (shape, dtype, device, group) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, ptrs_tensor) - _symm_cache[key] = res - return res - -@torch.no_grad() -def solution( - local_embedding_shard: torch.Tensor, - input_node_ids: torch.Tensor, - proj_matrix: torch.Tensor, - num_total_nodes: int, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - shard_size = (num_total_nodes + world_size - 1) // world_size - embed_dim = local_embedding_shard.shape[1] - device = local_embedding_shard.device - - if rank == 0: - _get_ext() - dist.barrier(group=group) - - if local_embedding_shard.dtype != torch.bfloat16: - local_embedding_shard = local_embedding_shard.to(torch.bfloat16) - if proj_matrix.dtype != torch.bfloat16: - proj_matrix = proj_matrix.to(torch.bfloat16) - if input_node_ids.dtype != torch.int64: - input_node_ids = input_node_ids.to(torch.int64) - - input_node_ids = input_node_ids.contiguous().view(-1) - proj_matrix = proj_matrix.contiguous() - - # Wait for trailing tasks from any previous iterations prior to overwriting symm buf - dist.barrier(group=group) - - buf, hdl, ptrs_tensor = _get_symm_state( - local_embedding_shard.shape, - local_embedding_shard.dtype, - device, - group - ) - buf.copy_(local_embedding_shard) - - # Assure all embedding shards are registered/copied completely before UVA fetching begins - dist.barrier(group=group) - - Q = input_node_ids.size(0) - C = 32768 # Batch fetch queries into robust cache-friendly GEMM chunks - num_chunks = (Q + C - 1) // C - - # Fast-path for small query workloads - if num_chunks <= 1: - gathered = torch.empty((Q, embed_dim), dtype=torch.bfloat16, device=device) - if Q > 0: - _get_ext().launch_uva_fetch( - input_node_ids, ptrs_tensor, gathered, shard_size, world_size - ) - return torch.matmul(gathered, proj_matrix) - - # Double-buffering path for overlapped comms / compute - out = torch.empty((Q, proj_matrix.size(1)), dtype=torch.bfloat16, device=device) - bufA = torch.empty((C, embed_dim), dtype=torch.bfloat16, device=device) - bufB = torch.empty((C, embed_dim), dtype=torch.bfloat16, device=device) - - fetch_stream = torch.cuda.Stream(device=device) - comp_stream = torch.cuda.current_stream(device=device) - - fetch_events = [torch.cuda.Event() for _ in range(num_chunks)] - comp_events = [torch.cuda.Event(), torch.cuda.Event()] - - # Signal both buffers as conceptually "ready" prior to start - comp_events[0].record(comp_stream) - comp_events[1].record(comp_stream) - - # Kick off the initial fetch payload - with torch.cuda.stream(fetch_stream): - chunk_queries = input_node_ids[0:C] - if chunk_queries.numel() > 0: - _get_ext().launch_uva_fetch( - chunk_queries, ptrs_tensor, bufA, shard_size, world_size - ) - fetch_events[0].record(fetch_stream) - - for i in range(num_chunks): - start_idx = i * C - end_idx = min(start_idx + C, Q) - current_buf = bufA if (i % 2 == 0) else bufB - current_comp_event_idx = i % 2 - - # 1. Pipeline barrier: await the active memory chunk fetch - comp_stream.wait_event(fetch_events[i]) - - # 2. TensorCore local matrix-multiplication pass - chunk_queries_len = end_idx - start_idx - if chunk_queries_len > 0: - torch.mm(current_buf[:chunk_queries_len], proj_matrix, out=out[start_idx:end_idx]) - - # Mark active buffer free for future fetches - comp_events[current_comp_event_idx].record(comp_stream) - - # 3. Schedule upcoming UVA stream fetch overlapping existing host loop & compute - if i + 1 < num_chunks: - next_start = (i + 1) * C - next_end = min(next_start + C, Q) - next_buf = bufB if (i % 2 == 0) else bufA - next_comp_event_idx = (i + 1) % 2 - - with torch.cuda.stream(fetch_stream): - fetch_stream.wait_event(comp_events[next_comp_event_idx]) - next_chunk_queries = input_node_ids[next_start:next_end] - if next_chunk_queries.numel() > 0: - _get_ext().launch_uva_fetch( - next_chunk_queries, ptrs_tensor, next_buf, shard_size, world_size - ) - fetch_events[i+1].record(fetch_stream) - - # Guard ensuring all asynchronous reads safely complete prior to returning execution - comp_stream.wait_stream(fetch_stream) - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/6_gather_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/6_gather_cuda.py deleted file mode 100755 index e672505..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/6_gather_cuda.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -Strategy: -- **Topology-Aware Direct Pull**: On Hopper's fully connected NVSwitch, a multi-step tree gather actually increases total NVLink traffic and latency steps. We instead use a flat 1-step P2P pull where the destination rank directly reads from all peers, minimizing latency and saturating the destination's RX bandwidth. -- **Symmetric Memory Staging**: All ranks stage their input chunk into a lightweight `torch.distributed._symmetric_memory` buffer, avoiding a W-sized allocation on every rank and providing stable device pointers for the destination rank. -- **Custom Vectorized CUDA Kernel**: A custom JIT-compiled kernel on the destination rank uses perfectly aligned memory accesses (up to 128-bit `uint4`) to pull from all peers in a single grid launch. This eliminates host-side `cudaMemcpyAsync` loop bottlenecks and maintains peak SM utilization. -- **Zero Local Copies on Destination**: The destination rank skips staging entirely, directly writing its own input tensor to the output buffer, cutting memory operations on the bottleneck rank to the theoretical minimum. -- **Compute-Communication Overlap**: The schedule perfectly scopes device barriers such that non-destination ranks can safely return and launch independent compute while the destination rank is still busy pulling their data. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Direct Pull Gather Kernel -// --------------------------------------------------------------------------- -// The destination rank launches this kernel to directly pull data from all -// peers' symmetric memory buffers over NVLink into a single contiguous output. -__global__ void direct_pull_gather_kernel( - const uint64_t* __restrict__ ptrs, - uint8_t* __restrict__ out, - int64_t chunk_bytes, - int dst -) { - // blockIdx.y represents the peer rank we are pulling from - int peer = blockIdx.y; - - // The destination rank skips itself as it handles its local copy directly - if (peer == dst) return; - - const uint8_t* src = (const uint8_t*)ptrs[peer]; - uint8_t* dst_ptr = out + peer * chunk_bytes; - - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - int64_t stride = (int64_t)gridDim.x * blockDim.x; - - // Dynamically use the widest memory instructions possible based on alignment - if (chunk_bytes % 16 == 0) { - int64_t chunk_16 = chunk_bytes / 16; - const uint4* src_16 = reinterpret_cast(src); - uint4* dst_16 = reinterpret_cast(dst_ptr); - for (int64_t i = idx; i < chunk_16; i += stride) { - dst_16[i] = src_16[i]; - } - } else if (chunk_bytes % 8 == 0) { - int64_t chunk_8 = chunk_bytes / 8; - const uint2* src_8 = reinterpret_cast(src); - uint2* dst_8 = reinterpret_cast(dst_ptr); - for (int64_t i = idx; i < chunk_8; i += stride) { - dst_8[i] = src_8[i]; - } - } else if (chunk_bytes % 4 == 0) { - int64_t chunk_4 = chunk_bytes / 4; - const uint32_t* src_4 = reinterpret_cast(src); - uint32_t* dst_4 = reinterpret_cast(dst_ptr); - for (int64_t i = idx; i < chunk_4; i += stride) { - dst_4[i] = src_4[i]; - } - } else if (chunk_bytes % 2 == 0) { - int64_t chunk_2 = chunk_bytes / 2; - const uint16_t* src_2 = reinterpret_cast(src); - uint16_t* dst_2 = reinterpret_cast(dst_ptr); - for (int64_t i = idx; i < chunk_2; i += stride) { - dst_2[i] = src_2[i]; - } - } else { - for (int64_t i = idx; i < chunk_bytes; i += stride) { - dst_ptr[i] = src[i]; - } - } -} - -void launch_gather_pull( - torch::Tensor ptrs_tensor, - torch::Tensor out, - int64_t chunk_bytes, - int dst -) { - int world_size = ptrs_tensor.size(0); - const uint64_t* ptrs = (const uint64_t*)ptrs_tensor.data_ptr(); - uint8_t* out_ptr = (uint8_t*)out.data_ptr(); - - // Match the host-side element count to the device-side loop stride - int64_t align_size = 1; - if (chunk_bytes % 16 == 0) align_size = 16; - else if (chunk_bytes % 8 == 0) align_size = 8; - else if (chunk_bytes % 4 == 0) align_size = 4; - else if (chunk_bytes % 2 == 0) align_size = 2; - - int64_t elements = chunk_bytes / align_size; - - int threads = 256; - int blocks_x = (elements + threads - 1) / threads; - if (blocks_x == 0) blocks_x = 1; - - // Cap blocks to prevent grid over-subscription; grid-stride loops handle the rest - if (blocks_x > 1024) blocks_x = 1024; - - // 2D Grid: X maps to data chunks, Y maps to peer ranks - dim3 blocks(blocks_x, world_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - direct_pull_gather_kernel<<>>(ptrs, out_ptr, chunk_bytes, dst); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_gather_pull", &launch_gather_pull, "Direct pull kernel for device-side gather"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gather_direct_pull_ext", CUDA_SRC) - return _ext - - -_resource_cache = {} - -def _get_resources(shape: torch.Size, dtype: torch.dtype, device: torch.device): - """ - Allocates and caches symmetric memory buffers to completely bypass - process rendezvous overheads on repeated calls. - """ - key = (shape, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, ptrs_tensor) - _resource_cache[key] = res - return res - - -@torch.no_grad() -def solution( - tensor: torch.Tensor, - dst: int = 0, -) -> torch.Tensor: - if not dist.is_initialized(): - return tensor - - rank = dist.get_rank() - world_size = dist.get_world_size() - - if rank == 0: - _get_ext() - dist.barrier() - - tensor = tensor.contiguous() - buf, hdl, ptrs_tensor = _get_resources(tensor.shape, tensor.dtype, tensor.device) - - # 1. Non-destination ranks quickly stage their input tensor into symmetric memory. - # The destination rank skips this entirely to save memory operations. - if rank != dst: - buf.copy_(tensor) - - # Synchronization ensures symmetric buffers are ready before the pull kernel fires. - hdl.barrier(channel=0) - - # 2. Destination rank pulls from all symmetric memory staging buffers. - if rank == dst: - out = torch.empty((world_size, *tensor.shape), dtype=tensor.dtype, device=tensor.device) - - # Safely copy own local tensor without invoking NVLink overhead. - out[dst].copy_(tensor) - - # Fire vectorized pull-kernel over NVLink mappings directly into the output tensor. - chunk_bytes = tensor.numel() * tensor.element_size() - _get_ext().launch_gather_pull(ptrs_tensor, out, chunk_bytes, dst) - - # Post-sync protects symmetric buffers from next-call overwrite. - hdl.barrier(channel=1) - - if rank == dst: - return out - else: - return tensor \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/70_gnn_negative_scoring_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/70_gnn_negative_scoring_cuda.py deleted file mode 100755 index d72c3f5..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/70_gnn_negative_scoring_cuda.py +++ /dev/null @@ -1,271 +0,0 @@ -""" -Strategy: -1. **Algorithmic Reduction**: The reference globally broadcasts `local_neg_scores` (`P x K` elements), computing identical row-wise rankings redundantly across all GPUs. We replace this with local rank computation, shrinking network traffic by > `K`x. -2. **Device-Side Gather & UVA**: We use symmetric memory (`symm_mem`) and direct UVA load instructions for `gather_rankings`. This skips NCCL all-gather and avoids maintaining multiple large buffers. -3. **Compute-Comm Overlap**: While the CPU computes buffer offsets synchronously (a ~5us operation), the GPU pipelines the custom warp-level reduction kernel directly into the symmetric memory buffer. -4. **Fused Math**: We evaluate PyTorch's elementwise `sigmoid` and sort sequence seamlessly within the custom CUDA kernel using warp intrinsics, dropping multiple intermediate memory roundtrips. -""" - -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// Warp-level kernel: efficiently fuses sigmoid float casting and warp-stride local ranking counting -__global__ void compute_local_rankings_warp_kernel( - const __nv_bfloat16* __restrict__ pos_scores, - const __nv_bfloat16* __restrict__ neg_scores, - int64_t* __restrict__ rankings, - int P, - int K -) { - int i = blockIdx.x * (blockDim.x / 32) + threadIdx.x / 32; - int lane = threadIdx.x % 32; - - if (i < P) { - float pos_val = __bfloat162float(pos_scores[i]); - float pos_sig = 1.0f / (1.0f + expf(-pos_val)); - float pos_sig_cmp = __bfloat162float(__float2bfloat16(pos_sig)); - - int local_count = 0; - for (int j = lane; j < K; j += 32) { - float neg_val = __bfloat162float(neg_scores[i * K + j]); - float neg_sig = 1.0f / (1.0f + expf(-neg_val)); - float neg_sig_cmp = __bfloat162float(__float2bfloat16(neg_sig)); - - // Replicates stable descent sorting logic exactly - if (neg_sig_cmp > pos_sig_cmp) { - local_count++; - } - } - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - local_count += __shfl_down_sync(0xffffffff, local_count, offset); - } - - if (lane == 0) { - rankings[i] = 1 + (int64_t)local_count; - } - } -} - -__global__ void gather_sizes_kernel( - const uint64_t* ptrs, - int64_t* sizes_out, - int world_size -) { - int r = threadIdx.x; - if (r < world_size) { - const int64_t* peer_buf = reinterpret_cast(ptrs[r]); - sizes_out[r] = peer_buf[0]; - } -} - -__global__ void gather_rankings_kernel( - const uint64_t* ptrs, - const int64_t* sizes, - const int64_t* offsets, - int64_t* global_rankings, - int world_size -) { - int r = blockIdx.y; - int64_t size = sizes[r]; - int64_t offset = offsets[r]; - const int64_t* peer_buf = reinterpret_cast(ptrs[r]); - - for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - global_rankings[offset + i] = peer_buf[i]; - } -} - -void compute_local_rankings( - torch::Tensor pos_scores, - torch::Tensor neg_scores, - torch::Tensor rankings, - int P, - int K -) { - int threads = 256; - int warps_per_block = threads / 32; - int blocks = (P + warps_per_block - 1) / warps_per_block; - if (blocks == 0) return; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - compute_local_rankings_warp_kernel<<>>( - reinterpret_cast(pos_scores.data_ptr()), - reinterpret_cast(neg_scores.data_ptr()), - rankings.data_ptr(), - P, - K - ); -} - -void gather_sizes( - torch::Tensor ptrs, - torch::Tensor sizes_out, - int world_size -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - gather_sizes_kernel<<<1, 32, 0, stream>>>( - reinterpret_cast(ptrs.data_ptr()), - sizes_out.data_ptr(), - world_size - ); -} - -void gather_rankings( - torch::Tensor ptrs, - torch::Tensor sizes, - torch::Tensor offsets, - torch::Tensor global_rankings, - int world_size -) { - int threads = 256; - int blocks_x = 256; - dim3 blocks(blocks_x, world_size); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - gather_rankings_kernel<<>>( - reinterpret_cast(ptrs.data_ptr()), - sizes.data_ptr(), - offsets.data_ptr(), - global_rankings.data_ptr(), - world_size - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("compute_local_rankings", &compute_local_rankings, "Compute local rankings"); - m.def("gather_sizes", &gather_sizes, "Gather sizes via UVA"); - m.def("gather_rankings", &gather_rankings, "Gather rankings via UVA"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("gnn_ranking_opt", CUDA_SRC) - return _ext - - -_meta_cache = {} -def get_meta_cache(device, group): - group = group or dist.group.WORLD - if group not in _meta_cache: - meta_buf = symm_mem.empty(1, dtype=torch.int64, device=device) - meta_hdl = symm_mem.rendezvous(meta_buf, group) - ptrs_tensor = torch.tensor(meta_hdl.buffer_ptrs, dtype=torch.int64, device=device) - world_size = dist.get_world_size(group) - sizes_out = torch.empty(world_size, dtype=torch.int64, device=device) - offsets_dev = torch.empty(world_size, dtype=torch.int64, device=device) - _meta_cache[group] = (meta_buf, meta_hdl, ptrs_tensor, sizes_out, offsets_dev) - return _meta_cache[group] - - -_comm_cache = {} -def get_comm_cache(min_capacity, device, group): - group = group or dist.group.WORLD - if group not in _comm_cache: - _comm_cache[group] = {"capacity": 0, "buf": None, "hdl": None, "ptrs": None} - - cache = _comm_cache[group] - if min_capacity > cache["capacity"]: - new_cap = max(min_capacity * 2, 1024) - buf = symm_mem.empty(new_cap, dtype=torch.int64, device=device) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - cache["capacity"] = new_cap - cache["buf"] = buf - cache["hdl"] = hdl - cache["ptrs"] = ptrs - - return cache["buf"], cache["hdl"], cache["ptrs"] - - -@torch.no_grad() -def solution( - local_pos_scores: torch.Tensor, - local_neg_scores: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - device = local_pos_scores.device - - P_r = local_pos_scores.shape[0] - K = local_neg_scores.shape[1] if local_neg_scores.ndim > 1 else 0 - - local_pos_scores = local_pos_scores.contiguous() - local_neg_scores = local_neg_scores.contiguous() - - if world_size == 1: - global_rankings = torch.empty(P_r, dtype=torch.int64, device=device) - if P_r > 0: - _get_ext().compute_local_rankings( - local_pos_scores, - local_neg_scores, - global_rankings, - P_r, - K - ) - return global_rankings - - ext = _get_ext() - - # 1. Share sizes via symmetric memory to avoid NCCL syncs - meta_buf, meta_hdl, meta_ptrs, sizes_out, offsets_dev = get_meta_cache(device, group) - - meta_buf[0] = P_r - meta_hdl.barrier(channel=0) - ext.gather_sizes(meta_ptrs, sizes_out, world_size) - - # Implicitly syncs CPU strictly to calculate buffer sizes & total elements safely - all_sizes = sizes_out.tolist() - total_P = sum(all_sizes) - max_P = max(all_sizes) if all_sizes else 0 - - offsets = [0] * world_size - for i in range(1, world_size): - offsets[i] = offsets[i-1] + all_sizes[i-1] - - offsets_dev.copy_(torch.tensor(offsets, dtype=torch.int64, device=device), non_blocking=True) - - # 2. Extract shared symmetric rankings buffer with dynamic capacity checking - comm_buf, comm_hdl, comm_ptrs = get_comm_cache(max_P, device, group) - - # 3. Queue local computations directly into peer-readable comm_buf - if P_r > 0: - ext.compute_local_rankings( - local_pos_scores, - local_neg_scores, - comm_buf, - P_r, - K - ) - - # 4. Enforce writes locally prior to peer retrieval - comm_hdl.barrier(channel=0) - - # 5. Overlap final continuous gather output - global_rankings = torch.empty(total_P, dtype=torch.int64, device=device) - if total_P > 0: - ext.gather_rankings( - comm_ptrs, - sizes_out, - offsets_dev, - global_rankings, - world_size - ) - - return global_rankings \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/71_torchrec_kjt_all2all_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/71_torchrec_kjt_all2all_cuda.py deleted file mode 100755 index 72f2ef7..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/71_torchrec_kjt_all2all_cuda.py +++ /dev/null @@ -1,363 +0,0 @@ -from typing import Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include - -__global__ void fused_pull_permute_kernel( - const uint64_t* __restrict__ remote_ptrs, - const int32_t* __restrict__ permuted_offsets, - const int32_t* __restrict__ remote_ranks, - const int32_t* __restrict__ remote_offsets, - int N, - int total_elements, - int element_size, - void* __restrict__ dest_buffer -) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; - - for (int i = tid; i < total_elements; i += stride) { - // Binary search to find which segment this element belongs to - int L = 0, R = N - 1; - int idx = 0; - while (L <= R) { - int mid = L + (R - L) / 2; - if (permuted_offsets[mid] <= i) { - idx = mid; - L = mid + 1; - } else { - R = mid - 1; - } - } - - int offset_in_segment = i - permuted_offsets[idx]; - int rank = remote_ranks[idx]; - int r_offset = remote_offsets[idx] + offset_in_segment; - - const char* src_base = reinterpret_cast(remote_ptrs[rank]); - char* dst_base = reinterpret_cast(dest_buffer); - - // Use naturally aligned memory loads for efficient NVLink transfers - if (element_size == 4) { - reinterpret_cast(dst_base)[i] = reinterpret_cast(src_base)[r_offset]; - } else if (element_size == 2) { - reinterpret_cast(dst_base)[i] = reinterpret_cast(src_base)[r_offset]; - } else if (element_size == 8) { - reinterpret_cast(dst_base)[i] = reinterpret_cast(src_base)[r_offset]; - } else if (element_size == 1) { - reinterpret_cast(dst_base)[i] = reinterpret_cast(src_base)[r_offset]; - } - } -} - -void fused_pull_permute( - torch::Tensor remote_ptrs, - torch::Tensor permuted_offsets, - torch::Tensor remote_ranks, - torch::Tensor remote_offsets, - int N, - int total_elements, - int element_size, - torch::Tensor dest_buffer -) { - if (total_elements == 0 || N == 0) return; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - if (blocks > 1024) blocks = 1024; // Limit blocks for grid-stride - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - fused_pull_permute_kernel<<>>( - reinterpret_cast(remote_ptrs.data_ptr()), - permuted_offsets.data_ptr(), - remote_ranks.data_ptr(), - remote_offsets.data_ptr(), - N, - total_elements, - element_size, - dest_buffer.data_ptr() - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fused_pull_permute", &fused_pull_permute, "Fused P2P pull and permute"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_p2p_permute_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_buffer(name, min_size, dtype, device): - global _symm_cache - if name in _symm_cache: - buf, hdl, ptrs = _symm_cache[name] - if buf.numel() >= min_size and buf.dtype == dtype: - return buf, hdl, ptrs - - alloc_size = min_size + 1024 - buf = symm_mem.empty(alloc_size, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=device) - _symm_cache[name] = (buf, hdl, ptrs) - return buf, hdl, ptrs - -def _lengths_per_key_vectorized(lengths: torch.Tensor, stride_per_key: List[int]) -> torch.Tensor: - N = len(stride_per_key) - if N == 0: - return torch.empty(0, dtype=lengths.dtype, device=lengths.device) - strides_tensor = torch.tensor(stride_per_key, dtype=torch.long, device=lengths.device) - indices = torch.repeat_interleave( - torch.arange(N, device=lengths.device, dtype=torch.long), - strides_tensor - ) - res = torch.zeros(N, dtype=lengths.dtype, device=lengths.device) - res.scatter_add_(0, indices, lengths) - return res - -def _sum_by_splits(values: List[int], splits: List[int]) -> List[int]: - out: List[int] = [] - offset = 0 - for split in splits: - out.append(sum(values[offset : offset + split])) - offset += split - return out - -def _get_recat( - local_split: int, - num_splits: int, - stagger: int = 1, - device: Optional[torch.device] = None, - batch_size_per_rank: Optional[List[int]] = None, -) -> Optional[torch.Tensor]: - if local_split == 0: - return None - feature_order = [ - x + num_splits // stagger * y - for x in range(num_splits // stagger) - for y in range(stagger) - ] - if batch_size_per_rank is None: - recat = [ - feature_idx + rank_idx * local_split - for feature_idx in range(local_split) - for rank_idx in feature_order - ] - else: - rank_offsets = [0] - for batch_size in batch_size_per_rank[:-1]: - rank_offsets.append(rank_offsets[-1] + local_split * batch_size) - recat = [ - rank_offsets[rank_idx] + feature_idx * batch_size_per_rank[rank_idx] + b - for feature_idx in range(local_split) - for rank_idx in feature_order - for b in range(batch_size_per_rank[rank_idx]) - ] - return torch.tensor(recat, device=device, dtype=torch.int32) - -@torch.no_grad() -def solution( - lengths: torch.Tensor, - values: torch.Tensor, - key_splits: List[int], - batch_size: int, - pg: Optional[dist.ProcessGroup] = None, - weights: Optional[torch.Tensor] = None, - stride_per_key: Optional[List[int]] = None, - stagger: int = 1, -) -> Dict[str, torch.Tensor]: - pg = pg or dist.group.WORLD - world_size = dist.get_world_size(pg) - rank = dist.get_rank(pg) - device = lengths.device - - if rank == 0: - _get_ext() - dist.barrier(group=pg) - - num_features = sum(key_splits) - variable_stride = stride_per_key is not None - if stride_per_key is None: - stride_per_key = [batch_size] * num_features - - # Completely pure-CUDA metadata preparation to bypass host-device syncs - length_per_key_tensor = _lengths_per_key_vectorized(lengths, stride_per_key) - - key_splits_tensor = torch.tensor(key_splits, dtype=torch.long, device=device) - indices = torch.repeat_interleave( - torch.arange(world_size, device=device, dtype=torch.long), - key_splits_tensor - ) - value_splits_tensor = torch.zeros(world_size, dtype=lengths.dtype, device=device) - value_splits_tensor.scatter_add_(0, indices, length_per_key_tensor) - - length_splits = _sum_by_splits(stride_per_key, key_splits) - length_splits_tensor = torch.tensor(length_splits, dtype=torch.int32, device=device) - - split_tensors = [length_splits_tensor, value_splits_tensor.to(torch.int32)] - if variable_stride: - split_tensors.append(key_splits_tensor.to(torch.int32)) - if weights is not None: - split_tensors.append(value_splits_tensor.to(torch.int32)) - if not variable_stride: - split_tensors.append(torch.full((world_size,), batch_size, dtype=torch.int32, device=device)) - - num_tensors = len(split_tensors) - meta_local = torch.stack(split_tensors, dim=1).flatten() - meta_all_flat = torch.empty(world_size, meta_local.numel(), dtype=torch.int32, device=device) - dist.all_gather_into_tensor(meta_all_flat, meta_local, group=pg) - - # meta_all[S, D, T] = Size sent from rank S to rank D for tensor T - meta_all = meta_all_flat.view(world_size, world_size, num_tensors) - - input_tensors = [lengths, values] - tensor_names = ["lengths", "values"] - if variable_stride: - input_tensors.append(torch.tensor(stride_per_key, dtype=torch.int32, device=device)) - tensor_names.append("strides") - if weights is not None: - input_tensors.append(weights) - tensor_names.append("weights") - - symm_ptrs = [] - for T, (name, tensor) in enumerate(zip(tensor_names, input_tensors)): - max_send_size = meta_all[:, :, T].sum(dim=1).max().item() - buf, hdl, ptrs = _get_symm_buffer(name, max_send_size, tensor.dtype, device) - - local_size = meta_all[rank, :, T].sum().item() - if local_size > 0: - buf[:local_size].copy_(tensor.view(-1)) - symm_ptrs.append(ptrs) - - # Barrier before P2P reading symmetric memory buffers - dist.barrier(group=pg) - - def pull_and_permute(T: int, seg_sizes: torch.Tensor, recat: Optional[torch.Tensor], dtype: torch.dtype): - recv_chunk_sizes = meta_all[:, rank, T] - send_offsets_to_D = meta_all[:, :rank, T].sum(dim=1) - - N = seg_sizes.numel() - total_elements = seg_sizes.sum().item() - dest_buffer = torch.empty(total_elements, dtype=dtype, device=device) - if total_elements == 0: - return dest_buffer - - chunk_offsets = torch.zeros(world_size + 1, dtype=torch.int32, device=device) - chunk_offsets[1:] = torch.cumsum(recv_chunk_sizes, dim=0) - - unpermuted_offsets = torch.zeros(N, dtype=torch.int32, device=device) - if N > 1: - unpermuted_offsets[1:] = torch.cumsum(seg_sizes[:-1], dim=0) - - remote_ranks = torch.bucketize(unpermuted_offsets, chunk_offsets[1:], right=True).to(torch.int32) - chunk_offset = unpermuted_offsets - chunk_offsets[remote_ranks] - remote_offsets = send_offsets_to_D[remote_ranks] + chunk_offset - - if recat is not None: - remote_ranks_permuted = remote_ranks[recat].contiguous() - remote_offsets_permuted = remote_offsets[recat].contiguous() - seg_sizes_permuted = seg_sizes[recat] - else: - remote_ranks_permuted = remote_ranks.contiguous() - remote_offsets_permuted = remote_offsets.contiguous() - seg_sizes_permuted = seg_sizes - - permuted_offsets = torch.zeros(N + 1, dtype=torch.int32, device=device) - permuted_offsets[1:] = torch.cumsum(seg_sizes_permuted, dim=0) - - _get_ext().fused_pull_permute( - symm_ptrs[T], permuted_offsets, remote_ranks_permuted, remote_offsets_permuted, - N, total_elements, dest_buffer.element_size(), dest_buffer - ) - return dest_buffer - - # Pull structural metadata flatly first - active_ranks_lengths = (meta_all[:, rank, 0] > 0).nonzero(as_tuple=True)[0] - recv_lengths_unpermuted = pull_and_permute( - 0, meta_all[active_ranks_lengths, rank, 0], None, lengths.dtype - ) - - if variable_stride: - active_ranks_strides = (meta_all[:, rank, 2] > 0).nonzero(as_tuple=True)[0] - recv_strides_unpermuted = pull_and_permute( - 2, meta_all[active_ranks_strides, rank, 2], None, torch.int32 - ) - - # Derive segment descriptors for payloads using local metadata - local_split = key_splits[rank] - if variable_stride: - recv_strides_list = recv_strides_unpermuted.tolist() - seg_sizes_lengths = recv_strides_unpermuted - seg_sizes_values = _lengths_per_key_vectorized(recv_lengths_unpermuted, recv_strides_list).to(torch.int32) - recat = _get_recat(local_split, world_size, stagger, device=device) - else: - stride_per_rank = meta_all[:, rank, -1].tolist() - single_batch_per_rank = all(s == stride_per_rank[0] for s in stride_per_rank) - - if single_batch_per_rank: - B = stride_per_rank[0] - if B > 0: - N = recv_lengths_unpermuted.numel() // B - seg_sizes_lengths = torch.full((N,), B, dtype=torch.int32, device=device) - lengths_2d = recv_lengths_unpermuted.view(N, B) - seg_sizes_values = lengths_2d.sum(dim=1).to(torch.int32) - else: - seg_sizes_lengths = torch.empty(0, dtype=torch.int32, device=device) - seg_sizes_values = torch.empty(0, dtype=torch.int32, device=device) - recat = _get_recat(local_split, world_size, stagger, device=device) - else: - N = recv_lengths_unpermuted.numel() - seg_sizes_lengths = torch.ones(N, dtype=torch.int32, device=device) - seg_sizes_values = recv_lengths_unpermuted.to(torch.int32) - recat = _get_recat(local_split, world_size, stagger, device=device, batch_size_per_rank=stride_per_rank) - - # Fused Permuted P2P Copies for payload tensors - recv_lengths = pull_and_permute(0, seg_sizes_lengths, recat, lengths.dtype) - recv_values = pull_and_permute(1, seg_sizes_values, recat, values.dtype) - - recv_weights = None - if weights is not None: - T_weights = 3 if variable_stride else 2 - recv_weights = pull_and_permute(T_weights, seg_sizes_values, recat, weights.dtype) - - # Barrier before function return to preserve symm_mem valid lifetimes - dist.barrier(group=pg) - - if variable_stride: - recv_strides_permuted = recv_strides_unpermuted - if recat is not None: - recv_strides_permuted = recv_strides_unpermuted[recat] - stride_per_key_per_rank = recv_strides_permuted.view(world_size, local_split).T - if stagger > 1: - order = torch.arange(world_size, device=device).view(stagger, -1).T.reshape(-1) - stride_per_key_per_rank = stride_per_key_per_rank[:, order] - - result = { - "lengths": recv_lengths, - "values": recv_values, - "stride_per_key_per_rank": stride_per_key_per_rank.to(torch.long), - } - else: - result = { - "lengths": recv_lengths, - "values": recv_values, - "stride": torch.tensor(sum(stride_per_rank), device=device, dtype=torch.long), - "stride_per_rank": torch.tensor(stride_per_rank, device=device, dtype=torch.long), - } - - if recv_weights is not None: - result["weights"] = recv_weights - - return result \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/72_hyena_conv1d_boundary_exchange_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/72_hyena_conv1d_boundary_exchange_cuda.py deleted file mode 100755 index a526a37..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/72_hyena_conv1d_boundary_exchange_cuda.py +++ /dev/null @@ -1,285 +0,0 @@ -""" -Strategy: -1. **Eliminate intermediate reshaping and padding**: The stock PyTorch path performs expensive contiguous copies to carve out boundary chunks, pack them into an overlapped buffer, run `F.conv1d`, and then permute/reshape the output back into zigzag format. We replace this entire sequence with a single fused causal depthwise 1D convolution kernel. -2. **Device-Side Communication (Symmetric Memory & UVA)**: We use `torch.distributed._symmetric_memory` to allocate a dedicated device buffer for the halo bounds (`chunk_a` and `chunk_b`). We use a lightweight kernel to pack the boundary slices. After a blockwise barrier, each rank fetches its required overlapping regions directly from its peers' symmetric memory via direct UVA pointers. -3. **Compute-Communication Overlap**: The causal depthwise convolution kernel is parallelized into blocks over sequence chunks (tiles of size `T=1024`). Only the threads working on the leading boundary of a sequence chunk actually dereference the peer UVA pointer. The latency of these remote loads is seamlessly hidden by the GPU warp scheduler automatically executing warps from the vast majority of other sequence tiles that perform strictly local HBM loads. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Pack the boundary halo overlapping parts (chunk_a and chunk_b) -// into a contiguous symmetric memory buffer. -// --------------------------------------------------------------------------- -__global__ void pack_symm_buf_kernel( - const __nv_bfloat16* __restrict__ x, - __nv_bfloat16* __restrict__ symm_buf, - int B, int H, int S, int K -) { - int64_t total_elements = (int64_t)2 * B * H * (K - 1); - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - - if (idx < total_elements) { - int k_idx = idx % (K - 1); - int64_t temp = idx / (K - 1); - int bh = temp % (B * H); - int chunk = temp / (B * H); - - int x_idx; - if (chunk == 0) { - x_idx = S - (K - 1) + k_idx; - } else { - x_idx = 2 * S - (K - 1) + k_idx; - } - symm_buf[idx] = x[bh * (2 * S) + x_idx]; - } -} - -// --------------------------------------------------------------------------- -// Fused causal depthwise convolution over the zigzag layout. -// Automatically pulls halo padding from peer UVA pointers (ptr_prev_a, ptr_next_b) -// and writes directly into the complex output zigzag permuted shape. -// --------------------------------------------------------------------------- -__global__ void causal_depthwise_conv1d_kernel( - const __nv_bfloat16* __restrict__ x, - const __nv_bfloat16* __restrict__ weight, - const __nv_bfloat16* __restrict__ ptr_prev_a, - const __nv_bfloat16* __restrict__ ptr_next_b, - __nv_bfloat16* __restrict__ out, - int B, int H, int S, int K, int T, int grid_x -) { - extern __shared__ __nv_bfloat16 smem[]; - __nv_bfloat16* smem_in = smem; // Size: T + K - 1 - __nv_bfloat16* smem_w = smem + (T + K - 1); // Size: K - - // Use a 1D grid to circumvent the 65535 limit on gridDim.y - int64_t block_idx = blockIdx.x; - int chunk = block_idx % 2; - int64_t temp = block_idx / 2; - int bh = temp % (B * H); - int tile_idx = temp / (B * H); - - int h = bh % H; - int tid = threadIdx.x; - - int out_start = tile_idx * T; - if (out_start >= S) return; - int out_end = out_start + T; - if (out_end > S) out_end = S; - int out_len = out_end - out_start; - - // Load kernel weights into Shared Memory - for (int i = tid; i < K; i += blockDim.x) { - smem_w[i] = weight[h * K + i]; - } - - // Load input sequence + padding block into Shared Memory - int load_len = out_len + K - 1; - for (int i = tid; i < load_len; i += blockDim.x) { - int logical_idx = out_start + i; - __nv_bfloat16 val; - - if (logical_idx < K - 1) { - if (chunk == 0) { - if (ptr_prev_a != nullptr) { - // Fetch directly from rank - 1 peer via UVA - val = ptr_prev_a[bh * (K - 1) + logical_idx]; - } else { - val = __float2bfloat16(0.0f); - } - } else { - if (ptr_next_b != nullptr) { - // Fetch directly from rank + 1 peer via UVA - val = ptr_next_b[bh * (K - 1) + logical_idx]; - } else { - val = __float2bfloat16(0.0f); - } - } - } else { - // Fetch from local input tensor - int x_idx = logical_idx - (K - 1); - if (chunk == 0) { - val = x[bh * (2 * S) + x_idx]; - } else { - val = x[bh * (2 * S) + S + x_idx]; - } - } - smem_in[i] = val; - } - - __syncthreads(); - - // Compute causal depthwise 1D convolution and scatter seamlessly into the final layout - for (int i = tid; i < out_len; i += blockDim.x) { - float sum = 0.0f; - - #pragma unroll(4) - for (int k = 0; k < K; ++k) { - sum += __bfloat162float(smem_in[i + k]) * __bfloat162float(smem_w[k]); - } - - int out_x_idx = out_start + i; - if (chunk == 0) { - out[bh * (2 * S) + out_x_idx] = __float2bfloat16(sum); - } else { - out[bh * (2 * S) + S + out_x_idx] = __float2bfloat16(sum); - } - } -} - -void launch_pack_symm_buf( - torch::Tensor x, - torch::Tensor symm_buf, - int B, int H, int S, int K -) { - TORCH_CHECK(x.dtype() == torch::kBFloat16, "Must be BF16"); - int64_t total_elements = (int64_t)2 * B * H * (K - 1); - if (total_elements == 0) return; - - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - pack_symm_buf_kernel<<>>( - (const __nv_bfloat16*)x.data_ptr(), - (__nv_bfloat16*)symm_buf.data_ptr(), - B, H, S, K - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -void launch_causal_depthwise_conv1d( - torch::Tensor x, - torch::Tensor weight, - int64_t ptr_prev_a, - int64_t ptr_next_b, - torch::Tensor out, - int B, int H, int S, int K, int T -) { - TORCH_CHECK(x.dtype() == torch::kBFloat16, "Must be BF16"); - int grid_x = (S + T - 1) / T; - int64_t total_blocks = (int64_t)grid_x * B * H * 2; - int threads = 256; - - size_t smem_size = (T + 2 * K - 1) * sizeof(__nv_bfloat16); - - if (smem_size > 49152) { - cudaFuncSetAttribute( - causal_depthwise_conv1d_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size - ); - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - causal_depthwise_conv1d_kernel<<>>( - (const __nv_bfloat16*)x.data_ptr(), - (const __nv_bfloat16*)weight.data_ptr(), - (const __nv_bfloat16*)ptr_prev_a, - (const __nv_bfloat16*)ptr_next_b, - (__nv_bfloat16*)out.data_ptr(), - B, H, S, K, T, grid_x - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_pack_symm_buf", &launch_pack_symm_buf, "Pack bounds to symm memory"); - m.def("launch_causal_depthwise_conv1d", &launch_causal_depthwise_conv1d, "Fused halo exchange and grouped depthwise conv1d"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("hyena_conv1d_fused_symm", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(size: int, dtype: torch.dtype, device: torch.device, group: dist.ProcessGroup): - key = (size, dtype, group) - if key in _symm_cache: - return _symm_cache[key] - - buf = symm_mem.empty(size, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - _symm_cache[key] = (buf, hdl) - return buf, hdl - -@torch.no_grad() -def solution( - x: torch.Tensor, - weight: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - """ - Per-rank Hyena causal depthwise conv1d over zigzag CP chunks, leveraging UVA direct memory - access and fusing boundary communication perfectly behind independent local convolutions. - """ - group = group or dist.group.WORLD - group_ranks = dist.get_process_group_ranks(group) - group_rank = dist.get_rank(group) - group_world_size = len(group_ranks) - - batch, hidden, local_seq = x.shape - S = local_seq // 2 - K = weight.shape[-1] - pad_size = K - 1 - - x = x.contiguous() - weight = weight.contiguous() - out = torch.empty_like(x) - - ext = _get_ext() - - ptr_prev_a = 0 - ptr_next_b = 0 - - if pad_size > 0: - # Buffer size logically hosts (chunk_a + chunk_b) tightly packed - symm_size = 2 * batch * hidden * pad_size - buf, hdl = _get_symm_state(symm_size, x.dtype, x.device, group) - - # Stage our local communication bounds into the globally addressable buffer - ext.launch_pack_symm_buf(x, buf, batch, hidden, S, K) - hdl.barrier(channel=0) - - # Map Rank - 1 peer pointer for chunk_a - if group_rank > 0: - prev_g_rank = group_ranks[group_rank - 1] - ptr_prev_a = int(hdl.buffer_ptrs[prev_g_rank]) - - # Map Rank + 1 peer pointer for chunk_b - if group_rank < group_world_size - 1: - next_g_rank = group_ranks[group_rank + 1] - # Offset into peer buffer to specifically target chunk_b - chunk_size_bytes = batch * hidden * pad_size * x.element_size() - ptr_next_b = int(hdl.buffer_ptrs[next_g_rank]) + chunk_size_bytes - else: - # Re-read locally created chunk_a mimicking PyTorch reference clone().contiguous() - ptr_next_b = int(hdl.buffer_ptrs[group_ranks[group_rank]]) - - # Tile block setup for balancing shared memory load. - T = 1024 - - # Fire the fused depthwise convolution kernel which seamlessly streams border inputs - # through the resolved symmetric memory direct pointers into shared memory, and - # scatters final output values accurately avoiding PyTorch permute/reshape. - ext.launch_causal_depthwise_conv1d( - x, weight, ptr_prev_a, ptr_next_b, out, batch, hidden, S, K, T - ) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/73_hyena_forward_cp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/73_hyena_forward_cp_cuda.py deleted file mode 100755 index e69de29..0000000 diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/74_vocab_parallel_cross_entropy_loss_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/74_vocab_parallel_cross_entropy_loss_cuda.py deleted file mode 100755 index 5f84607..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/74_vocab_parallel_cross_entropy_loss_cuda.py +++ /dev/null @@ -1,352 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional, Tuple -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -struct Vec2; - -template <> -struct Vec2<__nv_bfloat16> { - using type = __nv_bfloat162; - static __device__ __forceinline__ float2 to_float2(type v) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __bfloat1622float2(v); -#else - return {__bfloat162float(v.x), __bfloat162float(v.y)}; -#endif - } - static __device__ __forceinline__ type from_float2(float2 f) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __floats2bfloat162_rn(f.x, f.y); -#else - type v; - v.x = __float2bfloat16(f.x); - v.y = __float2bfloat16(f.y); - return v; -#endif - } - static __device__ __forceinline__ float to_float(__nv_bfloat16 v) { - return __bfloat162float(v); - } - static __device__ __forceinline__ __nv_bfloat16 from_float(float f) { - return __float2bfloat16(f); - } -}; - -template <> -struct Vec2 { - using type = float2; - static __device__ __forceinline__ float2 to_float2(type v) { return v; } - static __device__ __forceinline__ type from_float2(float2 f) { return f; } - static __device__ __forceinline__ float to_float(float v) { return v; } - static __device__ __forceinline__ float from_float(float f) { return f; } -}; - -__device__ __forceinline__ float block_reduce_max(float val, float* shared) { - int lane = threadIdx.x % 32; - int wid = threadIdx.x / 32; - for (int offset = 16; offset > 0; offset /= 2) { - val = max(val, __shfl_down_sync(0xffffffff, val, offset)); - } - if (lane == 0) shared[wid] = val; - __syncthreads(); - val = (threadIdx.x < (blockDim.x + 31) / 32) ? shared[lane] : -1e20f; - for (int offset = 16; offset > 0; offset /= 2) { - val = max(val, __shfl_down_sync(0xffffffff, val, offset)); - } - __syncthreads(); - return val; -} - -__device__ __forceinline__ float block_reduce_sum(float val, float* shared) { - int lane = threadIdx.x % 32; - int wid = threadIdx.x / 32; - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (lane == 0) shared[wid] = val; - __syncthreads(); - val = (threadIdx.x < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - __syncthreads(); - return val; -} - -template -__global__ void kernel_local_max( - const T* __restrict__ logits, - float* __restrict__ sym_M_local, - int N, int P -) { - using V2 = typename Vec2::type; - __shared__ float shared_reduce[32]; - - for (int i = blockIdx.x; i < N; i += gridDim.x) { - float local_max = -1e20f; - if (P % 2 == 0) { - int P2 = P / 2; - const V2* logits2 = (const V2*)(logits + i * P); - for (int j = threadIdx.x; j < P2; j += blockDim.x) { - float2 fvals = Vec2::to_float2(logits2[j]); - local_max = max(local_max, fvals.x); - local_max = max(local_max, fvals.y); - } - } else { - for (int j = threadIdx.x; j < P; j += blockDim.x) { - float val = Vec2::to_float(logits[i * P + j]); - local_max = max(local_max, val); - } - } - local_max = block_reduce_max(local_max, shared_reduce); - if (threadIdx.x == 0) { - sym_M_local[i] = local_max; - } - } -} - -template -__global__ void kernel_shift_and_local_sums( - T* __restrict__ logits, - const long long* __restrict__ sym_ptrs, - const int64_t* __restrict__ target, - float* __restrict__ sym_S_local, - float* __restrict__ sym_t_local, - int N, int P, int vocab_start, int vocab_end, int world_size -) { - using V2 = typename Vec2::type; - __shared__ float shared_reduce[32]; - __shared__ float shared_M_global; - - for (int i = blockIdx.x; i < N; i += gridDim.x) { - if (threadIdx.x == 0) { - float m = -1e20f; - for (int r = 0; r < world_size; ++r) { - const float* remote_M = (const float*)sym_ptrs[r]; - m = max(m, remote_M[i]); - } - shared_M_global = m; - } - __syncthreads(); - float M_global = shared_M_global; - - int target_idx = target[i] - vocab_start; - float local_sum_exp = 0.0f; - float local_t = 0.0f; - - if (P % 2 == 0) { - int P2 = P / 2; - V2* logits2 = (V2*)(logits + i * P); - for (int j = threadIdx.x; j < P2; j += blockDim.x) { - V2 vals = logits2[j]; - float2 fvals = Vec2::to_float2(vals); - fvals.x -= M_global; - fvals.y -= M_global; - logits2[j] = Vec2::from_float2(fvals); - - local_sum_exp += expf(fvals.x) + expf(fvals.y); - if (j * 2 == target_idx) local_t = fvals.x; - else if (j * 2 + 1 == target_idx) local_t = fvals.y; - } - } else { - for (int j = threadIdx.x; j < P; j += blockDim.x) { - float val = Vec2::to_float(logits[i * P + j]); - val -= M_global; - logits[i * P + j] = Vec2::from_float(val); - local_sum_exp += expf(val); - if (j == target_idx) local_t = val; - } - } - - float total_sum_exp = block_reduce_sum(local_sum_exp, shared_reduce); - float total_t = block_reduce_sum(local_t, shared_reduce); - - if (threadIdx.x == 0) { - sym_S_local[i] = total_sum_exp; - sym_t_local[i] = total_t; - } - __syncthreads(); - } -} - -template -__global__ void kernel_global_sums_and_loss( - const long long* __restrict__ sym_ptrs, - T* __restrict__ loss, - int N, int world_size -) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += gridDim.x * blockDim.x) { - float S_global = 0.0f; - float t_global = 0.0f; - for (int r = 0; r < world_size; ++r) { - const float* remote_S = ((const float*)sym_ptrs[r]) + N; - const float* remote_t = ((const float*)sym_ptrs[r]) + 2 * N; - S_global += remote_S[i]; - t_global += remote_t[i]; - } - float final_loss = logf(S_global) - t_global; - loss[i] = Vec2::from_float(final_loss); - } -} - -void launch_kernel_local_max( - torch::Tensor logits, - int64_t sym_M_local, - int N, int P, int dtype_enum -) { - int threads = 256; - int blocks = std::min(N, 65535); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (dtype_enum == 0) { - kernel_local_max<__nv_bfloat16><<>>( - (__nv_bfloat16*)logits.data_ptr(), (float*)sym_M_local, N, P); - } else { - kernel_local_max<<>>( - logits.data_ptr(), (float*)sym_M_local, N, P); - } -} - -void launch_kernel_shift_and_local_sums( - torch::Tensor logits, - torch::Tensor sym_ptrs, - torch::Tensor target, - int64_t sym_S_local, - int64_t sym_t_local, - int N, int P, int vocab_start, int vocab_end, int world_size, int dtype_enum -) { - int threads = 256; - int blocks = std::min(N, 65535); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const long long* d_sym_ptrs = (const long long*)sym_ptrs.data_ptr(); - const int64_t* d_target = target.data_ptr(); - - if (dtype_enum == 0) { - kernel_shift_and_local_sums<__nv_bfloat16><<>>( - (__nv_bfloat16*)logits.data_ptr(), d_sym_ptrs, d_target, - (float*)sym_S_local, (float*)sym_t_local, N, P, vocab_start, vocab_end, world_size); - } else { - kernel_shift_and_local_sums<<>>( - logits.data_ptr(), d_sym_ptrs, d_target, - (float*)sym_S_local, (float*)sym_t_local, N, P, vocab_start, vocab_end, world_size); - } -} - -void launch_kernel_global_sums_and_loss( - torch::Tensor sym_ptrs, - torch::Tensor loss, - int N, int world_size, int dtype_enum -) { - int threads = 256; - int blocks = (N + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const long long* d_sym_ptrs = (const long long*)sym_ptrs.data_ptr(); - - if (dtype_enum == 0) { - kernel_global_sums_and_loss<__nv_bfloat16><<>>( - d_sym_ptrs, (__nv_bfloat16*)loss.data_ptr(), N, world_size); - } else { - kernel_global_sums_and_loss<<>>( - d_sym_ptrs, loss.data_ptr(), N, world_size); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_kernel_local_max", &launch_kernel_local_max); - m.def("launch_kernel_shift_and_local_sums", &launch_kernel_shift_and_local_sums); - m.def("launch_kernel_global_sums_and_loss", &launch_kernel_global_sums_and_loss); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_vocab_ce_ce", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(N: int, device: torch.device, group): - global _symm_cache - key = (device,) - if key in _symm_cache: - c = _symm_cache[key] - if c["N"] >= N: - return c["buf"], c["hdl"], c["ptrs"] - - buf = symm_mem.empty(3 * N, device=device, dtype=torch.float32) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - _symm_cache[key] = {"N": N, "buf": buf, "hdl": hdl, "ptrs": ptrs} - return buf, hdl, ptrs - -@torch.no_grad() -def solution( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - - assert vocab_parallel_logits.is_contiguous(), "vocab_parallel_logits must be contiguous" - - dtype = vocab_parallel_logits.dtype - if dtype == torch.bfloat16: - dtype_enum = 0 - elif dtype == torch.float32: - dtype_enum = 1 - else: - raise ValueError("Only BF16 and F32 are supported") - - logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.shape[-1]) - target_1d = target.reshape(-1).contiguous() - - N, P = logits_2d.shape - vocab_start = rank * P - vocab_end = vocab_start + P - - # 3xN symmetric staging buffer - buf, hdl, sym_ptrs = _get_symm_state(N, vocab_parallel_logits.device, group) - - sym_M_local = int(hdl.buffer_ptrs[rank]) - sym_S_local = sym_M_local + N * 4 - sym_t_local = sym_M_local + 2 * N * 4 - - # 1. Local block-reduced maximum - _get_ext().launch_kernel_local_max( - logits_2d, sym_M_local, N, P, dtype_enum - ) - - # Fast device-side stream barrier - hdl.barrier(channel=0) - - # 2. Inplace Logits shift + Compute Partial Sum-Exps - _get_ext().launch_kernel_shift_and_local_sums( - logits_2d, sym_ptrs, target_1d, - sym_S_local, sym_t_local, - N, P, vocab_start, vocab_end, world_size, dtype_enum - ) - - # Eager peer synchronization for metrics - hdl.barrier(channel=1) - - # 3. Pull globals & formulate standard Loss - loss_1d = torch.empty_like(target_1d, dtype=dtype) - _get_ext().launch_kernel_global_sums_and_loss( - sym_ptrs, loss_1d, N, world_size, dtype_enum - ) - - return loss_1d.view(target.shape) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/75_fla_kimi_delta_attention_cp_tp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/75_fla_kimi_delta_attention_cp_tp_cuda.py deleted file mode 100755 index 1ec9917..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/75_fla_kimi_delta_attention_cp_tp_cuda.py +++ /dev/null @@ -1,388 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Block-level utilities -// --------------------------------------------------------------------------- - -__inline__ __device__ float warp_reduce_sum(float val) { - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) - val += __shfl_down_sync(0xffffffff, val, offset); - return val; -} - -__inline__ __device__ float block_reduce_sum(float val, float* shared_mem) { - int lane = threadIdx.x % 32; - int wid = threadIdx.x / 32; - val = warp_reduce_sum(val); - if (lane == 0) shared_mem[wid] = val; - __syncthreads(); - - float sum = (threadIdx.x < (blockDim.x + 31) / 32) ? shared_mem[lane] : 0.0f; - sum = warp_reduce_sum(sum); - - if (threadIdx.x == 0) shared_mem[0] = sum; - __syncthreads(); - return shared_mem[0]; -} - -// --------------------------------------------------------------------------- -// KDA CP Forward Kernel -// --------------------------------------------------------------------------- - -__global__ void kda_forward_kernel( - const int64_t* __restrict__ cp_ptrs, // Pointers to cp_buf of each CP rank - const __nv_bfloat16* __restrict__ q_ptr, - const __nv_bfloat16* __restrict__ a_log_ptr, - const __nv_bfloat16* __restrict__ dt_bias_ptr, - __nv_bfloat16* __restrict__ out_ptr, - int B, int T_local, int H, int K, int V, int cp_rank -) { - extern __shared__ float smem[]; - float* S = smem; // [K * V] - float* sh_k_float = smem + K * V; // [K] - float* sh_q_float = sh_k_float + K; // [K] - float* sh_decay = sh_q_float + K; // [K] - float* sh_reduce = sh_decay + K; // [32] - - int b = blockIdx.x / H; - int h = blockIdx.x % H; - int tx = threadIdx.x; - - if (tx < V) { - for (int i = 0; i < K; ++i) { - S[i * V + tx] = 0.0f; - } - } - __syncthreads(); - - float a_scale_val = expf(__bfloat162float(a_log_ptr[h])); - float dt_b = (tx < K) ? __bfloat162float(dt_bias_ptr[h * K + tx]) : 0.0f; - int stride_last = 2 * K + V + 1; - - for (int r = 0; r <= cp_rank; ++r) { - const __nv_bfloat16* peer_buf = (const __nv_bfloat16*)cp_ptrs[r]; - - for (int t = 0; t < T_local; ++t) { - int64_t offset = ((int64_t)(b * T_local + t) * H + h) * stride_last; - const __nv_bfloat16* step_ptr = peer_buf + offset; - - float k_val = (tx < K) ? __bfloat162float(step_ptr[tx]) : 0.0f; - float v_val = (tx < V) ? __bfloat162float(step_ptr[K + tx]) : 0.0f; - float g_val = (tx < K) ? __bfloat162float(step_ptr[K + V + tx]) : 0.0f; - float beta_val = __bfloat162float(step_ptr[K + V + K]); - - float k_sq = (tx < K) ? k_val * k_val : 0.0f; - float norm_sq_k = block_reduce_sum(k_sq, sh_reduce); - float norm_k = sqrtf(norm_sq_k); - if (norm_k < 1e-12f) norm_k = 1e-12f; - if (tx < K) sh_k_float[tx] = k_val / norm_k; - - if (r == cp_rank) { - int64_t q_offset = ((int64_t)(b * T_local + t) * H + h) * K; - float q_val = (tx < K) ? __bfloat162float(q_ptr[q_offset + tx]) : 0.0f; - float q_sq = (tx < K) ? q_val * q_val : 0.0f; - float norm_sq_q = block_reduce_sum(q_sq, sh_reduce); - float norm_q = sqrtf(norm_sq_q); - if (norm_q < 1e-12f) norm_q = 1e-12f; - if (tx < K) sh_q_float[tx] = (q_val / norm_q) * (1.0f / sqrtf((float)K)); - } - - if (tx < K) { - float exponent = a_scale_val * (g_val + dt_b); - float sig = 1.0f / (1.0f + expf(-exponent)); - sh_decay[tx] = expf(-5.0f * sig); - } - float beta_sig = 1.0f / (1.0f + expf(-beta_val)); - - __syncthreads(); - - if (tx < V) { - float proj = 0.0f; - #pragma unroll 4 - for (int i = 0; i < K; ++i) { - proj += sh_k_float[i] * S[i * V + tx]; - } - float update = (v_val - proj) * beta_sig; - - float out_val = 0.0f; - #pragma unroll 4 - for (int i = 0; i < K; ++i) { - float new_s = sh_decay[i] * S[i * V + tx] + sh_k_float[i] * update; - S[i * V + tx] = new_s; - if (r == cp_rank) { - out_val += sh_q_float[i] * new_s; - } - } - - if (r == cp_rank) { - int64_t out_offset = ((int64_t)(b * T_local + t) * H + h) * V; - out_ptr[out_offset + tx] = __float2bfloat16(out_val); - } - } - __syncthreads(); - } - } -} - -void launch_kda_forward( - torch::Tensor cp_ptrs_tensor, - torch::Tensor q, torch::Tensor a_log, torch::Tensor dt_bias, torch::Tensor out, int cp_rank -) { - int B = q.size(0); - int T_local = q.size(1); - int H = q.size(2); - int K = q.size(3); - int V = out.size(3); - - int threads = std::max(K, V); - threads = ((threads + 31) / 32) * 32; - - int blocks = B * H; - int smem_size = (K * V + 3 * K + 32) * sizeof(float); - - cudaFuncSetAttribute(kda_forward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 227000); - const int64_t* cp_ptrs = cp_ptrs_tensor.data_ptr(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - kda_forward_kernel<<>>( - cp_ptrs, - (__nv_bfloat16*)q.data_ptr(), - (__nv_bfloat16*)a_log.data_ptr(), - (__nv_bfloat16*)dt_bias.data_ptr(), - (__nv_bfloat16*)out.data_ptr(), - B, T_local, H, K, V, cp_rank - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -// --------------------------------------------------------------------------- -// Multimem TP all-reduce Kernel -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void send_signal_relaxed(uint32_t* addr) { - uint32_t tmp; do { asm volatile("atom.global.relaxed.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 0u); -} -__device__ __forceinline__ void wait_signal_relaxed(uint32_t* addr) { - uint32_t tmp; do { asm volatile("atom.global.sys.relaxed.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 1u); -} -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; do { asm volatile("atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 0u); -} -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; do { asm volatile("atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(addr) : "memory"); } while (tmp != 1u); -} -__device__ void blockwise_barrier_relaxed(const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint32_t* send_addr = reinterpret_cast(signal_pad_ptrs[flat_tid] + block_id * world_size + rank); - uint32_t* wait_addr = reinterpret_cast(signal_pad_ptrs[rank] + block_id * world_size + flat_tid); - send_signal_relaxed(send_addr); wait_signal_relaxed(wait_addr); -} -__device__ void blockwise_barrier_acq_rel(const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint32_t* send_addr = reinterpret_cast(signal_pad_ptrs[flat_tid] + block_id * world_size + rank); - uint32_t* wait_addr = reinterpret_cast(signal_pad_ptrs[rank] + block_id * world_size + flat_tid); - send_signal_acq_rel(send_addr); wait_signal_acq_rel(wait_addr); -} - -__device__ __forceinline__ void multimem_ld_reduce_bf16x4(const uint64_t* addr, uint32_t& r0, uint32_t& r1, uint32_t& r2, uint32_t& r3) { - asm volatile("multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "l"(addr) : "memory"); -} -__device__ __forceinline__ void multimem_st_bf16x4(const uint64_t* addr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" : : "l"(addr), "r"(x), "r"(y), "r"(z), "r"(w) : "memory"); -} - -__global__ void multimem_allreduce_bf16_kernel( - uint64_t multicast_base, const uint64_t* __restrict__ signal_pad_ptrs, - int64_t numel_128, int world_size, int rank, int block_stride -) { - const uint64_t block_id = static_cast(blockIdx.x); - blockwise_barrier_relaxed(signal_pad_ptrs, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = (numel_128 + world_size - 1) / world_size; - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = block_id * block_stride; block_start < numel_per_rank; block_start += num_programs * block_stride) { - const int64_t offsets = block_start + tid; - if (offsets >= numel_per_rank) continue; - uint64_t* ptrs = reinterpret_cast(multicast_base) + (rank * numel_per_rank + offsets) * 2; - uint32_t x, y, z, w; - multimem_ld_reduce_bf16x4(ptrs, x, y, z, w); - multimem_st_bf16x4(ptrs, x, y, z, w); - } - __syncthreads(); - blockwise_barrier_acq_rel(signal_pad_ptrs, block_id, rank, world_size); -} - -void launch_multimem_allreduce_bf16( - uint64_t multicast_ptr, torch::Tensor signal_pad_ptrs_tensor, - int64_t numel, int world_size, int rank, int num_blocks, int block_size, int block_stride -) { - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - multimem_allreduce_bf16_kernel<<>>( - multicast_ptr, d_signal, numel, world_size, rank, block_stride); -} - -__global__ void allreduce_bf16_kernel( - const long long* __restrict__ ptrs, __nv_bfloat16* __restrict__ out, int world_size, int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - sum += __bfloat162float(((const __nv_bfloat16*)ptrs[r])[idx]); - } - out[idx] = __float2bfloat16(sum); - } -} - -void launch_allreduce(torch::Tensor ptrs_tensor, torch::Tensor out, int64_t n) { - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - int threads = 512; - int blocks = std::min((int)((n + threads - 1) / threads), 65535); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - allreduce_bf16_kernel<<>>( - d_ptrs, (__nv_bfloat16*)out.data_ptr(), world_size, n); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_kda_forward", &launch_kda_forward, "KDA Forward Context Kern"); - m.def("launch_multimem_allreduce_bf16", &launch_multimem_allreduce_bf16, "TP multimem all-reduce"); - m.def("launch_allreduce", &launch_allreduce, "TP peer-pointer all-reduce"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("kda_cp_tp_ext", CUDA_SRC) - return _ext - -_cp_cache = {} -def _get_cp_resources(shape, dtype, device, group): - key = (shape, dtype, device, group) - if key in _cp_cache: - return _cp_cache[key] - buf = symm_mem.empty(shape, device=device, dtype=dtype, group=group) - hdl = symm_mem.rendezvous(buf, group=group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - res = (buf, hdl, ptrs_tensor) - _cp_cache[key] = res - return res - -_tp_cache = {} -def _get_tp_resources(shape, dtype, device, group): - key = (shape, dtype, device, group) - if key in _tp_cache: - return _tp_cache[key] - buf = symm_mem.empty(shape, device=device, dtype=dtype, group=group) - hdl = symm_mem.rendezvous(buf, group=group) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - res = (buf, hdl, ptrs_tensor) - _tp_cache[key] = res - return res - -def _multimem_launch_config(numel: int, world_size: int) -> tuple[int, int, int]: - numel_per_thread = 8 - num_threads = (numel // numel_per_thread + world_size - 1) // world_size - if num_threads < 1024: - block_size = 1 - while block_size < num_threads: block_size *= 2 - num_blocks = 1 - else: - block_size = 1024 - num_blocks = min((num_threads + 1023) // 1024, 4) - return num_blocks, block_size, block_size - -@torch.no_grad() -def solution( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, - beta: torch.Tensor, a_log: torch.Tensor, dt_bias: torch.Tensor, - cp_group: Optional[dist.ProcessGroup] = None, - tp_group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - q, k, v, g = q.contiguous(), k.contiguous(), v.contiguous(), g.contiguous() - beta, a_log, dt_bias = beta.contiguous(), a_log.contiguous(), dt_bias.contiguous() - - assert q.dtype == torch.bfloat16, "Hardware bindings exclusively optimized for bfloat16" - B, T_local, H_local, K = q.shape - V = v.shape[-1] - - cp_group = cp_group or dist.group.WORLD - cp_size = dist.get_world_size(group=cp_group) - cp_rank = dist.get_rank(group=cp_group) - - # 1. CP Gathering via Symmetric Memory - stride_last = 2 * K + V + 1 - cp_buf, cp_hdl, cp_ptrs_tensor = _get_cp_resources( - (B, T_local, H_local, stride_last), q.dtype, q.device, cp_group - ) - - cp_buf[..., :K].copy_(k) - cp_buf[..., K:K+V].copy_(v) - cp_buf[..., K+V:2*K+V].copy_(g) - cp_buf[..., 2*K+V:].copy_(beta.unsqueeze(-1)) - - if cp_size > 1: - cp_hdl.barrier(channel=0) - - # 2. Extract specific TP Resources natively resolving buffer placements - tp_active = tp_group is not None and dist.get_world_size(tp_group) > 1 - if tp_active: - tp_size = dist.get_world_size(tp_group) - tp_rank = dist.get_rank(tp_group) - out_buf, tp_hdl, tp_ptrs_tensor = _get_tp_resources( - (B, T_local, H_local, V), q.dtype, q.device, tp_group - ) - else: - out_buf = torch.empty((B, T_local, H_local, V), dtype=q.dtype, device=q.device) - - # 3. Custom KDA sequential recurrent kernel executing up to exactly `cp_rank` - _get_ext().launch_kda_forward( - cp_ptrs_tensor, q, a_log, dt_bias, out_buf, cp_rank - ) - - # 4. In-switch TP All-reduce - if tp_active: - tp_hdl.barrier(channel=0) - n = out_buf.numel() - - # Condition check for Multimem 16-byte hardware constraints - if n % 8 == 0: - num_blocks, block_size, block_stride = _multimem_launch_config(n, tp_size) - dist.barrier(group=tp_group) # Explicit sync required before multicast manipulation - - _get_ext().launch_multimem_allreduce_bf16( - int(tp_hdl.multicast_ptr), tp_hdl.signal_pad_ptrs_dev, - n // 8, tp_size, tp_rank, num_blocks, block_size, block_stride - ) - return out_buf.clone() - else: - final_out = torch.empty_like(out_buf) - _get_ext().launch_allreduce(tp_ptrs_tensor, final_out, n) - return final_out - - return out_buf \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/76_fla_gated_deltanet_cp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/76_fla_gated_deltanet_cp_cuda.py deleted file mode 100755 index 3fa5f9d..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/76_fla_gated_deltanet_cp_cuda.py +++ /dev/null @@ -1,327 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void delta_recurrent_kernel( - const __nv_bfloat16* __restrict__ q, - const __nv_bfloat16* __restrict__ k, - const __nv_bfloat16* __restrict__ v, - const __nv_bfloat16* __restrict__ gate, - const __nv_bfloat16* __restrict__ beta, - const int64_t* __restrict__ a_log_ptrs, - const int64_t* __restrict__ dt_bias_ptrs, - __nv_bfloat16* __restrict__ out, - float* __restrict__ local_state, - uint32_t* __restrict__ local_flags, - float* __restrict__ peer_state, - uint32_t* __restrict__ peer_flags, - int rank, - int world_size, - int B, int T, int QH, int VH, int K, int V -) { - int b_h = blockIdx.x; - if (b_h >= B * VH) return; - int b = b_h / VH; - int hv = b_h % VH; - int qh = hv / (VH / QH); - int tx = threadIdx.x; - - extern __shared__ float smem[]; - float* smem_state = smem; // size K * V - float* sq = smem + K * V; // size K - float* sk = smem + K * V + K; // size K - float* r_q = smem + K * V + 2 * K; // size blockDim.x - float* r_k = smem + K * V + 2 * K + blockDim.x; // size blockDim.x - - // 1) Pipeline Wait: Wait for rank-1 to signal and write the initial state directly into our memory - if (rank > 0 && local_flags != nullptr && local_state != nullptr) { - if (tx == 0) { - volatile uint32_t* flag_ptr = &local_flags[b_h]; - while (*flag_ptr == 0) { - // busy wait for P2P flag - } - } - __syncthreads(); - for (int i = tx; i < K * V; i += blockDim.x) { - smem_state[i] = local_state[b_h * K * V + i]; - } - } else { - for (int i = tx; i < K * V; i += blockDim.x) { - smem_state[i] = 0.0f; - } - } - __syncthreads(); - - // 2) UVA read of all-gathered 1D tensors (avoid explicit NCCL) - float a_scale, bias; - if (a_log_ptrs != nullptr && dt_bias_ptrs != nullptr) { - int chunk_size = VH / world_size; - int owner = hv / chunk_size; - int local_offset = hv % chunk_size; - const __nv_bfloat16* a_log_owner = (const __nv_bfloat16*)a_log_ptrs[owner]; - const __nv_bfloat16* dt_bias_owner = (const __nv_bfloat16*)dt_bias_ptrs[owner]; - a_scale = expf(__bfloat162float(a_log_owner[local_offset])); - bias = __bfloat162float(dt_bias_owner[local_offset]); - } else { - // Fallback for isolated single-rank execution - const __nv_bfloat16* a_log_ptr = (const __nv_bfloat16*)a_log_ptrs; - const __nv_bfloat16* dt_bias_ptr = (const __nv_bfloat16*)dt_bias_ptrs; - a_scale = expf(__bfloat162float(a_log_ptr[hv])); - bias = __bfloat162float(dt_bias_ptr[hv]); - } - - float scale_q_k = 1.0f / sqrtf((float)K); - - // 3) Process chunk elements with fused norms - for (int t = 0; t < T; ++t) { - float local_q_sq = 0.0f; - float local_k_sq = 0.0f; - for (int i = tx; i < K; i += blockDim.x) { - float q_val = __bfloat162float(q[b * T * QH * K + t * QH * K + qh * K + i]); - float k_val = __bfloat162float(k[b * T * QH * K + t * QH * K + qh * K + i]); - sq[i] = q_val; - sk[i] = k_val; - local_q_sq += q_val * q_val; - local_k_sq += k_val * k_val; - } - - r_q[tx] = local_q_sq; - r_k[tx] = local_k_sq; - __syncthreads(); - - // Warp reduction for PyTorch equivalent F.normalize L2 norm - if (tx == 0) { - float sum_q = 0.0f, sum_k = 0.0f; - for (int i = 0; i < blockDim.x; ++i) { - sum_q += r_q[i]; - sum_k += r_k[i]; - } - float norm_q = sqrtf(sum_q); - float norm_k = sqrtf(sum_k); - r_q[0] = 1.0f / (norm_q < 1e-6f ? 1e-6f : norm_q); - r_k[0] = 1.0f / (norm_k < 1e-6f ? 1e-6f : norm_k); - } - __syncthreads(); - - float q_scale = r_q[0] * scale_q_k; - float k_scale = r_k[0]; - - for (int i = tx; i < K; i += blockDim.x) { - sq[i] *= q_scale; - sk[i] *= k_scale; - } - __syncthreads(); - - float gate_t = __bfloat162float(gate[b * T * VH + t * VH + hv]); - float beta_t = __bfloat162float(beta[b * T * VH + t * VH + hv]); - float sp = (gate_t + bias > 20.0f) ? (gate_t + bias) : logf(1.0f + expf(gate_t + bias)); - float decay_t = expf(-a_scale * sp); - - for (int v_idx = tx; v_idx < V; v_idx += blockDim.x) { - float v_val = __bfloat162float(v[b * T * VH * V + t * VH * V + hv * V + v_idx]); - - float proj = 0.0f; - #pragma unroll 4 - for (int k_idx = 0; k_idx < K; ++k_idx) { - proj += sk[k_idx] * (smem_state[k_idx * V + v_idx] * decay_t); - } - - float upd = (v_val - proj) * beta_t; - - float out_val = 0.0f; - #pragma unroll 4 - for (int k_idx = 0; k_idx < K; ++k_idx) { - float s = smem_state[k_idx * V + v_idx] * decay_t + sk[k_idx] * upd; - smem_state[k_idx * V + v_idx] = s; - out_val += sq[k_idx] * s; - } - - out[b * T * VH * V + t * VH * V + hv * V + v_idx] = __float2bfloat16(out_val); - } - __syncthreads(); - } - - // 4) Pipeline Trigger: Write final state to rank+1 using P2P pointers + sync flag - if (rank < world_size - 1 && peer_state != nullptr && peer_flags != nullptr) { - for (int i = tx; i < K * V; i += blockDim.x) { - peer_state[b_h * K * V + i] = smem_state[i]; - } - __threadfence_system(); // ensure state arrives before flag - __syncthreads(); - if (tx == 0) { - atomicExch(&peer_flags[b_h], 1); - } - } -} - -void launch_delta_recurrent( - torch::Tensor q, - torch::Tensor k, - torch::Tensor v, - torch::Tensor gate, - torch::Tensor beta, - torch::Tensor a_log_ptrs, - torch::Tensor dt_bias_ptrs, - torch::Tensor out, - torch::Tensor local_state, - torch::Tensor local_flags, - int64_t peer_state_ptr, - int64_t peer_flags_ptr, - int rank, - int world_size -) { - int B = q.size(0); - int T = q.size(1); - int QH = q.size(2); - int K = q.size(3); - int VH = v.size(2); - int V = v.size(3); - - int threads = 256; - int blocks = B * VH; - int smem_size = (K * V + 2 * K + 2 * threads) * sizeof(float); - - if (smem_size > 49152) { - // Boost limits up dynamically for H100 - cudaFuncSetAttribute(delta_recurrent_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - float* l_state = local_state.numel() > 0 ? local_state.data_ptr() : nullptr; - uint32_t* l_flags = local_flags.numel() > 0 ? (uint32_t*)local_flags.data_ptr() : nullptr; - float* p_state = peer_state_ptr ? reinterpret_cast(peer_state_ptr) : nullptr; - uint32_t* p_flags = peer_flags_ptr ? reinterpret_cast(peer_flags_ptr) : nullptr; - - const int64_t* a_ptrs = a_log_ptrs.defined() ? a_log_ptrs.data_ptr() : nullptr; - const int64_t* d_ptrs = dt_bias_ptrs.defined() ? dt_bias_ptrs.data_ptr() : nullptr; - - delta_recurrent_kernel<<>>( - (__nv_bfloat16*)q.data_ptr(), - (__nv_bfloat16*)k.data_ptr(), - (__nv_bfloat16*)v.data_ptr(), - (__nv_bfloat16*)gate.data_ptr(), - (__nv_bfloat16*)beta.data_ptr(), - a_ptrs, d_ptrs, - (__nv_bfloat16*)out.data_ptr(), - l_state, l_flags, p_state, p_flags, - rank, world_size, - B, T, QH, VH, K, V - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_delta_recurrent", &launch_delta_recurrent, "DeltaNet CP kernel"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("deltanet_recurrent_cp_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_state(B, VH, K, V, dtype, device, group_id, group): - key = (B, VH, K, V, dtype, device, group_id) - if key in _symm_cache: - return _symm_cache[key] - - # [batch_heads, K, V] buffer acting as receiver mailbox locally on each rank - buf_state = symm_mem.empty((B * VH, K, V), dtype=torch.float32, device=device) - hdl_state = symm_mem.rendezvous(buf_state, group) - - # Flags enabling pipelining without NCCL overhead - buf_flags = symm_mem.empty((B * VH,), dtype=torch.int32, device=device) - hdl_flags = symm_mem.rendezvous(buf_flags, group) - - world_size = dist.get_world_size(group) - chunk_size = VH // world_size - - # Used for UVA all-gather of static parameters without collective overhead on hot path - buf_a = symm_mem.empty((chunk_size,), dtype=dtype, device=device) - hdl_a = symm_mem.rendezvous(buf_a, group) - - buf_dt = symm_mem.empty((chunk_size,), dtype=dtype, device=device) - hdl_dt = symm_mem.rendezvous(buf_dt, group) - - res = (buf_state, hdl_state, buf_flags, hdl_flags, buf_a, hdl_a, buf_dt, hdl_dt) - _symm_cache[key] = res - return res - -@torch.no_grad() -def solution( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - gate: torch.Tensor, - beta: torch.Tensor, - a_log: torch.Tensor, - dt_bias: torch.Tensor, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - - B, T, QH, K = q.shape - VH = v.size(2) - V = v.size(3) - - out = torch.empty((B, T, VH, V), dtype=q.dtype, device=q.device) - - is_dist = dist.is_initialized() - world_size = dist.get_world_size(group) if is_dist else 1 - rank = dist.get_rank(group) if is_dist else 0 - - if world_size > 1: - buf_state, hdl_state, buf_flags, hdl_flags, buf_a, hdl_a, buf_dt, hdl_dt = _get_symm_state( - B, VH, K, V, a_log.dtype, q.device, id(group), group - ) - - # Clear flags locally ensuring we don't accidentally match old triggers - buf_flags.zero_() - hdl_flags.barrier(channel=0) - - # Drop 1D tensors to symmetric memory pool for zero-NCCL UVA all-gathers in the kernel - buf_a.copy_(a_log) - buf_dt.copy_(dt_bias) - hdl_a.barrier(channel=0) - hdl_dt.barrier(channel=0) - - peer_state_ptr = int(hdl_state.buffer_ptrs[rank + 1]) if rank < world_size - 1 else 0 - peer_flags_ptr = int(hdl_flags.buffer_ptrs[rank + 1]) if rank < world_size - 1 else 0 - - a_ptrs = torch.tensor(hdl_a.buffer_ptrs, dtype=torch.int64, device=q.device) - d_ptrs = torch.tensor(hdl_dt.buffer_ptrs, dtype=torch.int64, device=q.device) - - local_state = buf_state - local_flags = buf_flags - else: - peer_state_ptr = 0 - peer_flags_ptr = 0 - a_ptrs = torch.tensor([a_log.data_ptr()], dtype=torch.int64, device=q.device) - d_ptrs = torch.tensor([dt_bias.data_ptr()], dtype=torch.int64, device=q.device) - local_state = torch.empty(0, device=q.device) - local_flags = torch.empty(0, device=q.device) - - _get_ext().launch_delta_recurrent( - q.contiguous(), k.contiguous(), v.contiguous(), - gate.contiguous(), beta.contiguous(), - a_ptrs, d_ptrs, out, - local_state, local_flags, peer_state_ptr, peer_flags_ptr, - rank, world_size - ) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/77_opensora_conv3d_allreduce_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/77_opensora_conv3d_allreduce_cuda.py deleted file mode 100755 index f10210a..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/77_opensora_conv3d_allreduce_cuda.py +++ /dev/null @@ -1,463 +0,0 @@ -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -_CONV3D_NUMEL_LIMIT = 2**31 - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Blockwise barrier for Multimem -// --------------------------------------------------------------------------- - -__device__ __forceinline__ void send_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.relaxed.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.sys.relaxed.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 1u); -} - -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) - : "l"(addr) - : "memory"); - } while (tmp != 1u); -} - -__device__ void blockwise_barrier_relaxed( - const uint64_t* __restrict__ signal_pad_ptrs, - uint64_t block_id, - int rank, - int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) { - return; - } - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast( - remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast( - local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_relaxed(send_addr); - wait_signal_relaxed(wait_addr); -} - -__device__ void blockwise_barrier_acq_rel( - const uint64_t* __restrict__ signal_pad_ptrs, - uint64_t block_id, - int rank, - int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) { - return; - } - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast( - remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast( - local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_acq_rel(send_addr); - wait_signal_acq_rel(wait_addr); -} - -__device__ __forceinline__ void multimem_ld_reduce_bf16x4( - const uint64_t* addr, - uint32_t& r0, - uint32_t& r1, - uint32_t& r2, - uint32_t& r3 -) { - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "l"(addr) - : "memory"); -} - -__device__ __forceinline__ void multimem_st_bf16x4( - const uint64_t* addr, - uint32_t x, - uint32_t y, - uint32_t z, - uint32_t w -) { - asm volatile( - "multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" - : - : "l"(addr), "r"(x), "r"(y), "r"(z), "r"(w) - : "memory"); -} - -__global__ void multimem_allreduce_bf16_kernel( - uint64_t multicast_base, - const uint64_t* __restrict__ signal_pad_ptrs, - int64_t numel_128, - int world_size, - int rank, - int block_stride -) { - const uint64_t block_id = static_cast(blockIdx.x); - blockwise_barrier_relaxed(signal_pad_ptrs, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = - (numel_128 + (int64_t)world_size - 1) / (int64_t)world_size; - - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = (int64_t)block_id * (int64_t)block_stride; - block_start < numel_per_rank; - block_start += (int64_t)num_programs * (int64_t)block_stride) - { - const int64_t offsets = block_start + (int64_t)tid; - if (offsets >= numel_per_rank) { - continue; - } - const int64_t idx = (int64_t)rank * numel_per_rank + offsets; - uint64_t* ptrs = reinterpret_cast(multicast_base) + idx * 2; - uint32_t x, y, z, w; - multimem_ld_reduce_bf16x4(ptrs, x, y, z, w); - multimem_st_bf16x4(ptrs, x, y, z, w); - } - - __syncthreads(); - blockwise_barrier_acq_rel(signal_pad_ptrs, block_id, rank, world_size); -} - -// --------------------------------------------------------------------------- -// Fallback Peer-Pointer AllReduce Kernel -// --------------------------------------------------------------------------- - -__global__ void allreduce_bf16_kernel( - const long long* __restrict__ ptrs, - __nv_bfloat16* __restrict__ out, - int world_size, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const __nv_bfloat16* src = (const __nv_bfloat16*)ptrs[r]; - sum += __bfloat162float(src[idx]); - } - out[idx] = __float2bfloat16(sum); - } -} - -__global__ void allreduce_f32_kernel( - const long long* __restrict__ ptrs, - float* __restrict__ out, - int world_size, - int64_t n -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const float* src = (const float*)ptrs[r]; - sum += src[idx]; - } - out[idx] = sum; - } -} - -void launch_multimem_allreduce_bf16( - uint64_t multicast_ptr, - torch::Tensor signal_pad_ptrs_tensor, - int64_t numel_128, - int world_size, - int rank, - int num_blocks, - int block_size, - int block_stride -) { - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - multimem_allreduce_bf16_kernel<<>>( - multicast_ptr, - d_signal, - numel_128, - world_size, - rank, - block_stride); -} - -void launch_allreduce( - torch::Tensor ptrs_tensor, - torch::Tensor out, - int64_t n, - int dtype_enum -) { - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - - int threads = 512; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (dtype_enum == 0) { - allreduce_bf16_kernel<<>>( - d_ptrs, (__nv_bfloat16*)out.data_ptr(), world_size, n); - } else { - allreduce_f32_kernel<<>>( - d_ptrs, out.data_ptr(), world_size, n); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_allreduce_bf16", &launch_multimem_allreduce_bf16, "Multimem all-reduce on symmetric multicast pointer"); - m.def("launch_allreduce", &launch_allreduce, "Custom P2P all-reduce fallback"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("opensora_conv3d_allreduce_ext", CUDA_SRC) - return _ext - - -WARP_SIZE = 32 -MAX_NUM_BLOCKS = 4 -MAX_BLOCK_SIZE = 1024 -BYTES_PER_THREAD = 16 - -def _multimem_launch_config(numel: int, world_size: int) -> tuple[int, int, int]: - numel_per_thread = BYTES_PER_THREAD // 2 # bf16 - num_threads = (numel // numel_per_thread + world_size - 1) // world_size - if num_threads < MAX_BLOCK_SIZE: - block_size = 1 - while block_size < num_threads: - block_size *= 2 - num_blocks = 1 - else: - block_size = MAX_BLOCK_SIZE - num_blocks = min( - (num_threads + MAX_BLOCK_SIZE - 1) // MAX_BLOCK_SIZE, - MAX_NUM_BLOCKS, - ) - return num_blocks, block_size, block_size - - -_resource_cache = {} - -def _get_resources(shape, dtype, device, group): - key = (shape, dtype, device, group) - if key in _resource_cache: - return _resource_cache[key] - - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - - out = torch.empty(shape, device=device, dtype=dtype) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, out, ptrs_tensor) - _resource_cache[key] = res - return res - - -def _to_3tuple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]: - return (value, value, value) if isinstance(value, int) else value - - -def _ceil_to_divisible(n: int, dividend: int) -> int: - return math.ceil(dividend / (dividend // n)) - - -def _output_shape( - input_shape: torch.Size, - out_channels: int, - kernel_size: Tuple[int, int, int], - stride: Tuple[int, int, int], - padding: Tuple[int, int, int], - dilation: Tuple[int, int, int], -) -> List[int]: - shape = [input_shape[0], out_channels] - for idx, size in enumerate(input_shape[-3:]): - out = (size + 2 * padding[idx] - dilation[idx] * (kernel_size[idx] - 1) - 1) - shape.append(math.floor(out / stride[idx] + 1)) - return shape - - -def _chunk_count(numel: int, channels: int, limit: int) -> int: - chunks = math.ceil(numel / limit) - return _ceil_to_divisible(chunks, channels) - - -def _channel_chunk_conv3d( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - stride: Tuple[int, int, int], - padding: Tuple[int, int, int], - dilation: Tuple[int, int, int], - groups: int, - numel_limit: int, -) -> torch.Tensor: - out_channels, in_channels = weight.shape[:2] - output_shape = _output_shape( - x.shape, - out_channels, - tuple(weight.shape[2:]), - stride, - padding, - dilation, - ) - in_chunks = _chunk_count(x.numel(), in_channels, numel_limit) - out_chunks = _chunk_count(math.prod(output_shape), out_channels, numel_limit) - if in_chunks == 1 and out_chunks == 1: - return F.conv3d(x, weight, bias, stride, padding, dilation, groups) - - x_chunks = x.chunk(in_chunks, dim=1) - weight_out_chunks = weight.chunk(out_chunks, dim=0) - bias_chunks = bias.chunk(out_chunks) if bias is not None else [None] * out_chunks - outputs: List[torch.Tensor] = [] - for weight_chunk, bias_chunk in zip(weight_out_chunks, bias_chunks): - partial_sum: Optional[torch.Tensor] = None - for x_chunk, w_chunk in zip(x_chunks, weight_chunk.chunk(in_chunks, dim=1)): - partial = F.conv3d( - x_chunk, - w_chunk, - None, - stride, - padding, - dilation, - groups, - ).float() - partial_sum = partial if partial_sum is None else partial_sum + partial - if partial_sum is None: - raise RuntimeError("conv3d chunking produced no partial outputs") - out = partial_sum.to(dtype=x.dtype) - if bias_chunk is not None: - out = out + bias_chunk.view(1, -1, 1, 1, 1) - outputs.append(out) - return torch.cat(outputs, dim=1) - - -@torch.no_grad() -def solution( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - stride: Union[int, Tuple[int, int, int]], - padding: Union[int, Tuple[int, int, int]], - dilation: Union[int, Tuple[int, int, int]], - groups: int = 1, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - - # Compute native cuDNN Conv3D local output - local_out = _channel_chunk_conv3d( - input, - weight, - None, # Handled directly after Custom AllReduce - _to_3tuple(stride), - _to_3tuple(padding), - _to_3tuple(dilation), - groups, - _CONV3D_NUMEL_LIMIT, - ).contiguous() - - n = local_out.numel() - dtype = local_out.dtype - - # Grab caching elements & copy for shared access - buf, hdl, out, ptrs_tensor = _get_resources(local_out.shape, dtype, local_out.device, group) - buf.copy_(local_out) - - if dtype == torch.bfloat16: - numel_per_thread = BYTES_PER_THREAD // local_out.element_size() - if n % numel_per_thread != 0: - hdl.barrier(channel=0) - _get_ext().launch_allreduce(ptrs_tensor, out, n, 0) - reduced_tensor = out - else: - numel_128 = n // numel_per_thread - num_blocks, block_size, block_stride = _multimem_launch_config(n, hdl.world_size) - - # Ensure native copy has completed prior to remote device access - dist.barrier(group=group) - - multicast_ptr = int(hdl.multicast_ptr) - signal_dev = hdl.signal_pad_ptrs_dev - - _get_ext().launch_multimem_allreduce_bf16( - multicast_ptr, - signal_dev, - numel_128, - hdl.world_size, - hdl.rank, - num_blocks, - block_size, - block_stride, - ) - reduced_tensor = buf - else: - hdl.barrier(channel=0) - dtype_enum = 1 if dtype == torch.float32 else 0 - _get_ext().launch_allreduce(ptrs_tensor, out, n, dtype_enum) - reduced_tensor = out - - # Fused element-wise operations with implicit out-of-place allocation - if bias is not None: - return reduced_tensor + bias.view(1, -1, 1, 1, 1) - else: - return reduced_tensor.clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/78_magi1_cso_async_attention_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/78_magi1_cso_async_attention_cuda.py deleted file mode 100755 index 92b9d69..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/78_magi1_cso_async_attention_cuda.py +++ /dev/null @@ -1,431 +0,0 @@ -""" -Strategy: -- **Device-side P2P Communication**: Bypasses `all_to_all_single` intermediate steps by pushing and pulling data directly to/from peers via `torch.distributed._symmetric_memory` and NVLink pointers. -- **Upfront KV Fetch**: Uses a single custom kernel to asynchronously construct the fully redistributed KV tensor locally from peer memory. -- **Overlap & Pipelining**: Hides sequence gathering and output scattering behind SDPA computation. We double-buffer Queries and use three independent streams (Input Comm, Compute, Output Comm). This ensures `Q_{d+1}` is fetched and `Out_{d-1}` is scattered while SDPA processes `Q_d`. -- **Vectorized Copy**: Employs `uint4` (128-bit) vectorized loads/stores where the inner head dimension allows, maximizing 16-bit (bf16/fp16) memory bandwidth utilization. -""" - -from typing import List, Optional - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// KV Gather Kernel -// --------------------------------------------------------------------------- - -__global__ void gather_kv_kernel_uint4( - const int64_t* __restrict__ symm_kv_ptrs, - uint4* __restrict__ local_kv, - int D, int S, int W, int Hkv, int Hd2_vec, int clip_token_nums, int rank -) { - int64_t total_elements = (int64_t)D * clip_token_nums * Hkv * Hd2_vec; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) return; - - int hd_idx = idx % Hd2_vec; - int h_idx = (idx / Hd2_vec) % Hkv; - int t_idx = idx / (Hd2_vec * Hkv); - - int d = t_idx / clip_token_nums; - int c = t_idx % clip_token_nums; - - int r = c / S; - int seq_idx = d * S + (c % S); - int h_global = rank * Hkv + h_idx; - - const uint4* src_ptr = reinterpret_cast(static_cast(symm_kv_ptrs[r])); - int64_t src_idx = ((int64_t)seq_idx * (W * Hkv) + h_global) * Hd2_vec + hd_idx; - - local_kv[idx] = src_ptr[src_idx]; -} - -__global__ void gather_kv_kernel_scalar( - const int64_t* __restrict__ symm_kv_ptrs, - uint16_t* __restrict__ local_kv, - int D, int S, int W, int Hkv, int Hd2, int clip_token_nums, int rank -) { - int64_t total_elements = (int64_t)D * clip_token_nums * Hkv * Hd2; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) return; - - int hd_idx = idx % Hd2; - int h_idx = (idx / Hd2) % Hkv; - int t_idx = idx / (Hd2 * Hkv); - - int d = t_idx / clip_token_nums; - int c = t_idx % clip_token_nums; - - int r = c / S; - int seq_idx = d * S + (c % S); - int h_global = rank * Hkv + h_idx; - - const uint16_t* src_ptr = reinterpret_cast(static_cast(symm_kv_ptrs[r])); - int64_t src_idx = ((int64_t)seq_idx * (W * Hkv) + h_global) * Hd2 + hd_idx; - - local_kv[idx] = src_ptr[src_idx]; -} - -// --------------------------------------------------------------------------- -// Query Gather Kernel -// --------------------------------------------------------------------------- - -__global__ void gather_q_kernel_uint4( - const int64_t* __restrict__ symm_q_ptrs, - uint4* __restrict__ local_q, - int d, int S, int W, int Hq, int Hd_vec, int rank -) { - int64_t total_elements = (int64_t)W * S * Hq * Hd_vec; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) return; - - int hd_idx = idx % Hd_vec; - int h_idx = (idx / Hd_vec) % Hq; - int t_idx = idx / (Hd_vec * Hq); - - int r = t_idx / S; - int seq_idx = d * S + (t_idx % S); - int h_global = rank * Hq + h_idx; - - const uint4* src_ptr = reinterpret_cast(static_cast(symm_q_ptrs[r])); - int64_t src_idx = ((int64_t)seq_idx * (W * Hq) + h_global) * Hd_vec + hd_idx; - - local_q[idx] = src_ptr[src_idx]; -} - -__global__ void gather_q_kernel_scalar( - const int64_t* __restrict__ symm_q_ptrs, - uint16_t* __restrict__ local_q, - int d, int S, int W, int Hq, int Hd, int rank -) { - int64_t total_elements = (int64_t)W * S * Hq * Hd; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) return; - - int hd_idx = idx % Hd; - int h_idx = (idx / Hd) % Hq; - int t_idx = idx / (Hd * Hq); - - int r = t_idx / S; - int seq_idx = d * S + (t_idx % S); - int h_global = rank * Hq + h_idx; - - const uint16_t* src_ptr = reinterpret_cast(static_cast(symm_q_ptrs[r])); - int64_t src_idx = ((int64_t)seq_idx * (W * Hq) + h_global) * Hd + hd_idx; - - local_q[idx] = src_ptr[src_idx]; -} - -// --------------------------------------------------------------------------- -// Output Scatter Kernel -// --------------------------------------------------------------------------- - -__global__ void scatter_out_kernel_uint4( - const uint4* __restrict__ local_out, - const int64_t* __restrict__ symm_out_ptrs, - int d, int S, int W, int Hq, int Hd_vec, int rank -) { - int64_t total_elements = (int64_t)W * S * Hq * Hd_vec; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) return; - - int hd_idx = idx % Hd_vec; - int h_idx = (idx / Hd_vec) % Hq; - int t_idx = idx / (Hd_vec * Hq); - - int r = t_idx / S; - int seq_idx = d * S + (t_idx % S); - int h_global = rank * Hq + h_idx; - - uint4* dst_ptr = reinterpret_cast(static_cast(symm_out_ptrs[r])); - int64_t dst_idx = ((int64_t)seq_idx * (W * Hq) + h_global) * Hd_vec + hd_idx; - - dst_ptr[dst_idx] = local_out[idx]; -} - -__global__ void scatter_out_kernel_scalar( - const uint16_t* __restrict__ local_out, - const int64_t* __restrict__ symm_out_ptrs, - int d, int S, int W, int Hq, int Hd, int rank -) { - int64_t total_elements = (int64_t)W * S * Hq * Hd; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) return; - - int hd_idx = idx % Hd; - int h_idx = (idx / Hd) % Hq; - int t_idx = idx / (Hd * Hq); - - int r = t_idx / S; - int seq_idx = d * S + (t_idx % S); - int h_global = rank * Hq + h_idx; - - uint16_t* dst_ptr = reinterpret_cast(static_cast(symm_out_ptrs[r])); - int64_t dst_idx = ((int64_t)seq_idx * (W * Hq) + h_global) * Hd + hd_idx; - - dst_ptr[dst_idx] = local_out[idx]; -} - -// --------------------------------------------------------------------------- -// Python Bindings -// --------------------------------------------------------------------------- - -void launch_gather_kv( - torch::Tensor symm_kv_ptrs_tensor, - torch::Tensor local_kv, - int D, int S, int W, int Hkv, int Hd2, int clip_token_nums, int rank, - int64_t stream_ptr -) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - const int64_t* ptrs = symm_kv_ptrs_tensor.data_ptr(); - - if (Hd2 % 8 == 0) { - int Hd2_vec = Hd2 / 8; - int64_t total_elements = (int64_t)D * clip_token_nums * Hkv * Hd2_vec; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - gather_kv_kernel_uint4<<>>( - ptrs, reinterpret_cast(local_kv.data_ptr()), - D, S, W, Hkv, Hd2_vec, clip_token_nums, rank); - } else { - int64_t total_elements = (int64_t)D * clip_token_nums * Hkv * Hd2; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - gather_kv_kernel_scalar<<>>( - ptrs, reinterpret_cast(local_kv.data_ptr()), - D, S, W, Hkv, Hd2, clip_token_nums, rank); - } -} - -void launch_gather_q( - torch::Tensor symm_q_ptrs_tensor, - torch::Tensor local_q, - int d, int S, int W, int Hq, int Hd, int rank, - int64_t stream_ptr -) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - const int64_t* ptrs = symm_q_ptrs_tensor.data_ptr(); - - if (Hd % 8 == 0) { - int Hd_vec = Hd / 8; - int64_t total_elements = (int64_t)W * S * Hq * Hd_vec; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - gather_q_kernel_uint4<<>>( - ptrs, reinterpret_cast(local_q.data_ptr()), - d, S, W, Hq, Hd_vec, rank); - } else { - int64_t total_elements = (int64_t)W * S * Hq * Hd; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - gather_q_kernel_scalar<<>>( - ptrs, reinterpret_cast(local_q.data_ptr()), - d, S, W, Hq, Hd, rank); - } -} - -void launch_scatter_out( - torch::Tensor local_out, - torch::Tensor symm_out_ptrs_tensor, - int d, int S, int W, int Hq, int Hd, int rank, - int64_t stream_ptr -) { - cudaStream_t stream = reinterpret_cast(stream_ptr); - const int64_t* ptrs = symm_out_ptrs_tensor.data_ptr(); - - if (Hd % 8 == 0) { - int Hd_vec = Hd / 8; - int64_t total_elements = (int64_t)W * S * Hq * Hd_vec; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - scatter_out_kernel_uint4<<>>( - reinterpret_cast(local_out.data_ptr()), ptrs, - d, S, W, Hq, Hd_vec, rank); - } else { - int64_t total_elements = (int64_t)W * S * Hq * Hd; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - scatter_out_kernel_scalar<<>>( - reinterpret_cast(local_out.data_ptr()), ptrs, - d, S, W, Hq, Hd, rank); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_gather_kv", &launch_gather_kv); - m.def("launch_gather_q", &launch_gather_q); - m.def("launch_scatter_out", &launch_scatter_out); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("magi_cso_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(q_shape, kv_shape, dtype, device, group): - global _symm_cache - key = (q_shape, kv_shape, dtype, device, id(group)) - if key in _symm_cache: - return _symm_cache[key] - - q_buf = symm_mem.empty(q_shape, dtype=dtype, device=device) - q_hdl = symm_mem.rendezvous(q_buf, group) - - kv_buf = symm_mem.empty(kv_shape, dtype=dtype, device=device) - kv_hdl = symm_mem.rendezvous(kv_buf, group) - - out_buf = symm_mem.empty(q_shape, dtype=dtype, device=device) - out_hdl = symm_mem.rendezvous(out_buf, group) - - q_ptrs = torch.tensor(q_hdl.buffer_ptrs, dtype=torch.int64, device=device) - kv_ptrs = torch.tensor(kv_hdl.buffer_ptrs, dtype=torch.int64, device=device) - out_ptrs = torch.tensor(out_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - res = (q_buf, q_hdl, q_ptrs, kv_buf, kv_hdl, kv_ptrs, out_buf, out_hdl, out_ptrs) - _symm_cache[key] = res - return res - -class PipelineState: - def __init__(self, W, D, S, Hq, Hkv, Hd, clip_token_nums, dtype, device): - self.stream_in = torch.cuda.Stream(device=device) - self.stream_out = torch.cuda.Stream(device=device) - self.ev_kv_ready = torch.cuda.Event() - self.ev_q_ready = [torch.cuda.Event() for _ in range(D)] - self.ev_sdpa_done = [torch.cuda.Event() for _ in range(D)] - self.local_kv = torch.empty((D * clip_token_nums, Hkv, 2 * Hd), dtype=dtype, device=device) - self.local_q_buf = [torch.empty((W * S, Hq, Hd), dtype=dtype, device=device) for _ in range(2)] - -_state_cache = {} -def _get_pipeline_state(W, D, S, Hq, Hkv, Hd, clip_token_nums, dtype, device): - key = (W, D, S, Hq, Hkv, Hd, clip_token_nums, dtype, device) - if key not in _state_cache: - _state_cache[key] = PipelineState(W, D, S, Hq, Hkv, Hd, clip_token_nums, dtype, device) - return _state_cache[key] - -def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - q = q.unsqueeze(0).transpose(1, 2) - k = k.unsqueeze(0).transpose(1, 2) - v = v.unsqueeze(0).transpose(1, 2) - if k.shape[1] < q.shape[1]: - repeat = q.shape[1] // k.shape[1] - k = k.repeat_interleave(repeat, dim=1) - v = v.repeat_interleave(repeat, dim=1) - return F.scaled_dot_product_attention(q, k, v).squeeze(0).transpose(0, 1).contiguous() - - -@torch.no_grad() -def solution( - query: torch.Tensor, - key_value: torch.Tensor, - k_ranges: torch.Tensor, - cp_shuffle_num: int, - clip_token_nums: Optional[int] = None, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - assert query.element_size() == 2, "Kernels designed for 16-bit dtypes (bf16/fp16)" - group = group or dist.group.WORLD - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - - D = cp_shuffle_num - W = world_size - - # Handle implicit repeated heads for KV upfront - if key_value.shape[1] < W and W % key_value.shape[1] == 0: - key_value = key_value.repeat_interleave(W // key_value.shape[1], dim=1) - - tokens, heads_q, Hd = query.shape - _, heads_kv, Hd2 = key_value.shape - - if tokens % D != 0: - raise ValueError("query token count must divide cp_shuffle_num") - if heads_q % W != 0 or heads_kv % W != 0: - raise ValueError("heads must divide evenly across context ranks") - - S = tokens // D - Hq = heads_q // W - Hkv = heads_kv // W - - clip_token_nums = min(int(clip_token_nums or W * S), W * S) - - q_buf, q_hdl, q_ptrs, kv_buf, kv_hdl, kv_ptrs, out_buf, out_hdl, out_ptrs = _get_symm_state( - query.shape, key_value.shape, query.dtype, query.device, group - ) - - # Push inputs into contiguous symmetric memory for fast peer access - q_buf.copy_(query) - kv_buf.copy_(key_value) - q_hdl.barrier(channel=0) - - state = _get_pipeline_state(W, D, S, Hq, Hkv, Hd, clip_token_nums, query.dtype, query.device) - ext = _get_ext() - comp_stream = torch.cuda.current_stream() - - # Launch overlapped communication pipeline - with torch.cuda.stream(state.stream_in): - # 1. Fetch entire sequence KV async - ext.launch_gather_kv( - kv_ptrs, state.local_kv, D, S, W, Hkv, Hd2, clip_token_nums, rank, state.stream_in.cuda_stream - ) - state.ev_kv_ready.record(stream=state.stream_in) - - # 2. Pipeline fetching Query chunks - for d in range(D): - if d >= 2: - # Prevent overwriting Q buffers before SDPA finishes reading - state.ev_sdpa_done[d - 2].wait(stream=state.stream_in) - - ext.launch_gather_q( - q_ptrs, state.local_q_buf[d % 2], d, S, W, Hq, Hd, rank, state.stream_in.cuda_stream - ) - state.ev_q_ready[d].record(stream=state.stream_in) - - local_out_res = [] - - # Execute Compute (SDPA) - for d in range(D): - if d == 0: - state.ev_kv_ready.wait(stream=comp_stream) - state.ev_q_ready[d].wait(stream=comp_stream) - - q = state.local_q_buf[d % 2] - start = int(k_ranges[d, 0]) - end = int(k_ranges[d, 1]) - k = state.local_kv[start:end, :, :Hd] - v = state.local_kv[start:end, :, Hd:] - - # Compute exact SDPA reference, capturing resulting tensor - out = _sdpa(q, k, v) - local_out_res.append(out.contiguous()) - - state.ev_sdpa_done[d].record(stream=comp_stream) - - # Scatter Pipeline (pushes outputs directly into peers' buffers) - with torch.cuda.stream(state.stream_out): - for d in range(D): - state.ev_sdpa_done[d].wait(stream=state.stream_out) - ext.launch_scatter_out( - local_out_res[d], out_ptrs, d, S, W, Hq, Hd, rank, state.stream_out.cuda_stream - ) - - comp_stream.wait_stream(state.stream_out) - - # Prevent early exit before all scattered segments have arrived - q_hdl.barrier(channel=1) - - return out_buf.clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/79_magi1_tile_parallel_vae_decode_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/79_magi1_tile_parallel_vae_decode_cuda.py deleted file mode 100755 index bd59772..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/79_magi1_tile_parallel_vae_decode_cuda.py +++ /dev/null @@ -1,424 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem -from typing import List, Optional, Tuple - -from utils.cuda_helpers import compile_cuda_extension - - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void pull_decoded_kernel( - __nv_bfloat16* __restrict__ local_flat, - const int64_t* __restrict__ peer_ptrs, - const int* __restrict__ tile_owners, - const int64_t* __restrict__ tile_offsets, - const int64_t* __restrict__ tile_numels, - int total_tiles, - int rank -) { - int threads = blockDim.x * gridDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int tile_idx = 0; tile_idx < total_tiles; tile_idx++) { - int owner = tile_owners[tile_idx]; - if (owner == rank) continue; - - int64_t offset = tile_offsets[tile_idx]; - int64_t numel = tile_numels[tile_idx]; - const __nv_bfloat16* src = (const __nv_bfloat16*)peer_ptrs[owner]; - - for (int64_t i = tid; i < numel; i += threads) { - local_flat[offset + i] = src[offset + i]; - } - } -} - -__global__ void blend_and_assemble_kernel( - const __nv_bfloat16* __restrict__ decoded_flat, - __nv_bfloat16* __restrict__ final_output, - const int64_t* __restrict__ meta, - int total_tiles, - int B, int Total_T, int Total_H, int Total_W -) { - int threads = blockDim.x * gridDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int tile_idx = 0; tile_idx < total_tiles; tile_idx++) { - int64_t dec_t = meta[tile_idx * 19 + 3]; - int64_t dec_h = meta[tile_idx * 19 + 4]; - int64_t dec_w = meta[tile_idx * 19 + 5]; - int64_t kept_t = meta[tile_idx * 19 + 6]; - int64_t kept_h = meta[tile_idx * 19 + 7]; - int64_t kept_w = meta[tile_idx * 19 + 8]; - int64_t out_t = meta[tile_idx * 19 + 9]; - int64_t out_h = meta[tile_idx * 19 + 10]; - int64_t out_w = meta[tile_idx * 19 + 11]; - int64_t extent_t = meta[tile_idx * 19 + 12]; - int64_t extent_h = meta[tile_idx * 19 + 13]; - int64_t extent_w = meta[tile_idx * 19 + 14]; - int64_t prev_t = meta[tile_idx * 19 + 15]; - int64_t prev_h = meta[tile_idx * 19 + 16]; - int64_t prev_w = meta[tile_idx * 19 + 17]; - int64_t flat_offset = meta[tile_idx * 19 + 18]; - - int64_t num_kept = (int64_t)B * 3 * kept_t * kept_h * kept_w; - - for (int64_t i = tid; i < num_kept; i += threads) { - int64_t rem = i; - int64_t w = rem % kept_w; rem /= kept_w; - int64_t h = rem % kept_h; rem /= kept_h; - int64_t t = rem % kept_t; rem /= kept_t; - int64_t c = rem % 3; rem /= 3; - int64_t b = rem; - - int64_t local_idx = (((b * 3 + c) * dec_t + t) * dec_h + h) * dec_w + w; - float val = __bfloat162float(decoded_flat[flat_offset + local_idx]); - - if (prev_t != -1 && t < extent_t) { - int64_t p_offset = meta[prev_t * 19 + 18]; - int64_t p_dec_t = meta[prev_t * 19 + 3]; - int64_t p_dec_h = meta[prev_t * 19 + 4]; - int64_t p_dec_w = meta[prev_t * 19 + 5]; - - int64_t p_t = p_dec_t - extent_t + t; - int64_t p_idx = (((b * 3 + c) * p_dec_t + p_t) * p_dec_h + h) * p_dec_w + w; - float p_val = __bfloat162float(decoded_flat[p_offset + p_idx]); - float ratio = (float)t / (float)extent_t; - val = p_val * (1.0f - ratio) + val * ratio; - } - - if (prev_h != -1 && h < extent_h) { - int64_t p_offset = meta[prev_h * 19 + 18]; - int64_t p_dec_t = meta[prev_h * 19 + 3]; - int64_t p_dec_h = meta[prev_h * 19 + 4]; - int64_t p_dec_w = meta[prev_h * 19 + 5]; - - int64_t p_h = p_dec_h - extent_h + h; - int64_t p_idx = (((b * 3 + c) * p_dec_t + t) * p_dec_h + p_h) * p_dec_w + w; - float p_val = __bfloat162float(decoded_flat[p_offset + p_idx]); - float ratio = (float)h / (float)extent_h; - val = p_val * (1.0f - ratio) + val * ratio; - } - - if (prev_w != -1 && w < extent_w) { - int64_t p_offset = meta[prev_w * 19 + 18]; - int64_t p_dec_t = meta[prev_w * 19 + 3]; - int64_t p_dec_h = meta[prev_w * 19 + 4]; - int64_t p_dec_w = meta[prev_w * 19 + 5]; - - int64_t p_w = p_dec_w - extent_w + w; - int64_t p_idx = (((b * 3 + c) * p_dec_t + t) * p_dec_h + h) * p_dec_w + p_w; - float p_val = __bfloat162float(decoded_flat[p_offset + p_idx]); - float ratio = (float)w / (float)extent_w; - val = p_val * (1.0f - ratio) + val * ratio; - } - - int64_t O_t = out_t + t; - int64_t O_h = out_h + h; - int64_t O_w = out_w + w; - int64_t out_idx = (((b * 3 + c) * Total_T + O_t) * Total_H + O_h) * Total_W + O_w; - final_output[out_idx] = __float2bfloat16(val); - } - } -} - -void launch_pull_decoded( - torch::Tensor local_flat, - torch::Tensor peer_ptrs, - torch::Tensor tile_owners, - torch::Tensor tile_offsets, - torch::Tensor tile_numels, - int total_tiles, - int rank -) { - int threads = 256; - int blocks = 2048; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - pull_decoded_kernel<<>>( - (__nv_bfloat16*)local_flat.data_ptr(), - peer_ptrs.data_ptr(), - tile_owners.data_ptr(), - tile_offsets.data_ptr(), - tile_numels.data_ptr(), - total_tiles, - rank - ); -} - -void launch_blend_and_assemble( - torch::Tensor decoded_flat, - torch::Tensor final_output, - torch::Tensor meta, - int total_tiles, - int B, int Total_T, int Total_H, int Total_W -) { - int threads = 256; - int blocks = 2048; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - blend_and_assemble_kernel<<>>( - (const __nv_bfloat16*)decoded_flat.data_ptr(), - (__nv_bfloat16*)final_output.data_ptr(), - meta.data_ptr(), - total_tiles, - B, Total_T, Total_H, Total_W - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_pull_decoded", &launch_pull_decoded); - m.def("launch_blend_and_assemble", &launch_blend_and_assemble); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("magi1_vae_decode_ext", CUDA_SRC) - return _ext - - -def _index_undot(index: int, loop_size: List[int]) -> List[int]: - out: List[int] = [] - for size in reversed(loop_size): - out.append(index % size) - index //= size - return list(reversed(out)) - - -def _index_dot(index: List[int], loop_size: List[int]) -> int: - value = 0 - for dim, size in zip(index, loop_size): - value = value * size + dim - return value - - -def _split_tiles( - tile_numels: List[int], - group: Optional[dist.ProcessGroup], -) -> Tuple[List[int], List[int]]: - if group is None: - tile_indices = list(range(len(tile_numels))) - return tile_indices, tile_indices - - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - sorted_tiles = sorted( - range(len(tile_numels)), - key=lambda idx: tile_numels[idx], - reverse=True, - ) - per_rank = [sorted_tiles[r::world_size] for r in range(world_size)] - global_order = [idx for shard in per_rank for idx in shard] - return per_rank[rank], global_order - - -def _decode_tile(tile: torch.Tensor, spatial_upsample: int, temporal_upsample: int) -> torch.Tensor: - decoded = F.interpolate( - tile.float(), - scale_factor=(temporal_upsample, spatial_upsample, spatial_upsample), - mode="trilinear", - align_corners=False, - ) - if decoded.shape[1] < 3: - repeats = (3 + decoded.shape[1] - 1) // decoded.shape[1] - decoded = decoded.repeat(1, repeats, 1, 1, 1) - return decoded[:, :3].to(torch.bfloat16) - - -@torch.no_grad() -def solution( - z: torch.Tensor, - tile_latent_min_length: int, - tile_latent_min_height: int, - tile_latent_min_width: int, - spatial_tile_overlap_factor: float, - temporal_tile_overlap_factor: float, - spatial_upsample: int, - temporal_upsample: int, - sr_ratio: int = 1, - first_frame_as_image: bool = False, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - if dist.is_available() and dist.is_initialized(): - group = group or dist.group.WORLD - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - if rank == 0: - _get_ext() - dist.barrier(group=group) - _get_ext() - else: - group = None - world_size = 1 - rank = 0 - _get_ext() - - tile_latent_min_length = tile_latent_min_length + int(first_frame_as_image) - spatial_upsample = spatial_upsample * sr_ratio - stride_h = int(tile_latent_min_height * (1.0 - spatial_tile_overlap_factor)) - stride_w = int(tile_latent_min_width * (1.0 - spatial_tile_overlap_factor)) - stride_t = int(tile_latent_min_length * (1.0 - temporal_tile_overlap_factor)) - if min(stride_t, stride_h, stride_w) <= 0: - raise ValueError("tile overlap factors must leave a positive stride") - - real_t = tile_latent_min_length * temporal_upsample - real_h = tile_latent_min_height * spatial_upsample - real_w = tile_latent_min_width * spatial_upsample - blend_t = int(real_t * temporal_tile_overlap_factor) - blend_h = int(real_h * spatial_tile_overlap_factor) - blend_w = int(real_w * spatial_tile_overlap_factor) - keep_t = real_t - blend_t - keep_h = real_h - blend_h - keep_w = real_w - blend_w - - tiles_t = (z.shape[2] + stride_t - 1) // stride_t - tiles_h = (z.shape[3] + stride_h - 1) // stride_h - tiles_w = (z.shape[4] + stride_w - 1) // stride_w - loop_size = [tiles_t, tiles_h, tiles_w] - total_tiles = tiles_t * tiles_h * tiles_w - B = z.shape[0] - - # Precompute tile shapes, layout, and bounds dynamically - latent_tiles_shapes = [] - for tile_idx in range(total_tiles): - t_idx, h_idx, w_idx = _index_undot(tile_idx, loop_size) - t0 = t_idx * stride_t - h0 = h_idx * stride_h - w0 = w_idx * stride_w - len_t = min(z.shape[2] - t0, tile_latent_min_length) - len_h = min(z.shape[3] - h0, tile_latent_min_height) - len_w = min(z.shape[4] - w0, tile_latent_min_width) - latent_tiles_shapes.append((len_t, len_h, len_w)) - - decoded_shapes = [] - kept_shapes = [] - flat_offsets = [0] * total_tiles - cur_offset = 0 - for tile_idx, (t, h, w) in enumerate(latent_tiles_shapes): - dec_t = t * temporal_upsample - dec_h = h * spatial_upsample - dec_w = w * spatial_upsample - decoded_shapes.append((dec_t, dec_h, dec_w)) - - kept_t = min(dec_t, keep_t) - kept_h = min(dec_h, keep_h) - kept_w = min(dec_w, keep_w) - kept_shapes.append((kept_t, kept_h, kept_w)) - - flat_offsets[tile_idx] = cur_offset - cur_offset += B * 3 * dec_t * dec_h * dec_w - - total_decoded_numel = cur_offset - Total_T = sum(kept_shapes[_index_dot([i, 0, 0], loop_size)][0] for i in range(tiles_t)) - Total_H = sum(kept_shapes[_index_dot([0, i, 0], loop_size)][1] for i in range(tiles_h)) - Total_W = sum(kept_shapes[_index_dot([0, 0, i], loop_size)][2] for i in range(tiles_w)) - - out_t_offsets = [0] * tiles_t - cur = 0 - for i in range(tiles_t): - out_t_offsets[i] = cur - cur += kept_shapes[_index_dot([i, 0, 0], loop_size)][0] - - out_h_offsets = [0] * tiles_h - cur = 0 - for i in range(tiles_h): - out_h_offsets[i] = cur - cur += kept_shapes[_index_dot([0, i, 0], loop_size)][1] - - out_w_offsets = [0] * tiles_w - cur = 0 - for i in range(tiles_w): - out_w_offsets[i] = cur - cur += kept_shapes[_index_dot([0, 0, i], loop_size)][2] - - # Map rank assignments to tiles - latent_numels = [B * z.shape[1] * s[0] * s[1] * s[2] for s in latent_tiles_shapes] - local_indices, _ = _split_tiles(latent_numels, group) - tile_owners = [0] * total_tiles - if world_size > 1: - per_rank, _ = _split_tiles(latent_numels, None) - sorted_tiles = sorted(range(len(latent_numels)), key=lambda idx: latent_numels[idx], reverse=True) - assigned = [sorted_tiles[r::world_size] for r in range(world_size)] - for r in range(world_size): - for idx in assigned[r]: - tile_owners[idx] = r - - # Render meta array for CUDA - meta_list = [] - for tile_idx in range(total_tiles): - t_idx, h_idx, w_idx = _index_undot(tile_idx, loop_size) - dec_t, dec_h, dec_w = decoded_shapes[tile_idx] - kept_t, kept_h, kept_w = kept_shapes[tile_idx] - - extent_t = min(decoded_shapes[_index_dot([t_idx - 1, h_idx, w_idx], loop_size)][0], dec_t, blend_t) if t_idx > 0 else 0 - extent_h = min(decoded_shapes[_index_dot([t_idx, h_idx - 1, w_idx], loop_size)][1], dec_h, blend_h) if h_idx > 0 else 0 - extent_w = min(decoded_shapes[_index_dot([t_idx, h_idx, w_idx - 1], loop_size)][2], dec_w, blend_w) if w_idx > 0 else 0 - - prev_t = _index_dot([t_idx - 1, h_idx, w_idx], loop_size) if t_idx > 0 else -1 - prev_h = _index_dot([t_idx, h_idx - 1, w_idx], loop_size) if h_idx > 0 else -1 - prev_w = _index_dot([t_idx, h_idx, w_idx - 1], loop_size) if w_idx > 0 else -1 - - meta_list.append([ - t_idx, h_idx, w_idx, - dec_t, dec_h, dec_w, - kept_t, kept_h, kept_w, - out_t_offsets[t_idx], out_h_offsets[h_idx], out_w_offsets[w_idx], - extent_t, extent_h, extent_w, - prev_t, prev_h, prev_w, - flat_offsets[tile_idx] - ]) - meta = torch.tensor(meta_list, dtype=torch.int64, device=z.device) - - # Establish memory structures - hdl = None - if world_size > 1: - decoded_all_flat = symm_mem.empty(total_decoded_numel, dtype=torch.bfloat16, device=z.device) - hdl = symm_mem.rendezvous(decoded_all_flat, group=group) - else: - decoded_all_flat = torch.empty(total_decoded_numel, dtype=torch.bfloat16, device=z.device) - - # Decode and fill owned tiles locally - for tile_idx in local_indices: - t_idx, h_idx, w_idx = _index_undot(tile_idx, loop_size) - t0 = t_idx * stride_t - h0 = h_idx * stride_h - w0 = w_idx * stride_w - tile = z[:, :, t0:t0+tile_latent_min_length, h0:h0+tile_latent_min_height, w0:w0+tile_latent_min_width] - - dec = _decode_tile(tile, spatial_upsample, temporal_upsample).contiguous() - offset = flat_offsets[tile_idx] - decoded_all_flat[offset:offset+dec.numel()].copy_(dec.view(-1)) - - # Pull remaining non-owned tiles over NVLink - if world_size > 1: - hdl.barrier(channel=0) - peer_ptrs = torch.tensor(hdl.buffer_ptrs, dtype=torch.int64, device=z.device) - owners_tensor = torch.tensor(tile_owners, dtype=torch.int32, device=z.device) - offsets_tensor = torch.tensor(flat_offsets, dtype=torch.int64, device=z.device) - decoded_numels = [B * 3 * s[0] * s[1] * s[2] for s in decoded_shapes] - numels_tensor = torch.tensor(decoded_numels, dtype=torch.int64, device=z.device) - - _get_ext().launch_pull_decoded( - decoded_all_flat, peer_ptrs, owners_tensor, offsets_tensor, numels_tensor, total_tiles, rank - ) - hdl.barrier(channel=0) - - # Blend and assemble all final chunks completely fusing operations - final_output = torch.empty((B, 3, Total_T, Total_H, Total_W), dtype=torch.bfloat16, device=z.device) - _get_ext().launch_blend_and_assemble( - decoded_all_flat, final_output, meta, total_tiles, B, Total_T, Total_H, Total_W - ) - - if hdl is not None: - hdl.barrier(channel=0) - - return final_output \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/7_reducescatter_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/7_reducescatter_cuda.py deleted file mode 100755 index 0270ff2..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/7_reducescatter_cuda.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -Strategy: -1. NVSwitch Multimem Reductions: We use Hopper's `multimem.ld_reduce` to perform the reduce-scatter. Instead of passing data back and forth over the network, each rank writes to its symmetric memory buffer and then individually reads its specific output chunk via a hardware multicast pointer. The NVSwitch performs the reduction transparently on read. -2. Zero Network Over-Fetch: The kernel maps threads only to the exact sub-slice of the global reduction destined for the calling rank. This ensures we naturally perform a reduce-scatter without manually coordinating a scatter-phase or slicing an all-reduced buffer. -3. Stream-Aware Overlap: PyTorch stream semantics are maintained using `hdl.barrier()`. Multi-channel barriers coordinate read/write-visibility over NVLink independently per stream, permitting the device computation and reduction to seamlessly interleave with minimal CPU overhead. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void reducescatter_multimem_bf16_kernel( - uint64_t multicast_base, - __nv_bfloat16* __restrict__ out, - int64_t chunk_numel, - int rank -) { - int64_t total_vecs = chunk_numel / 8; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - - // Grid-stride loop processing 128 bits (8 x bf16) at a time per thread - for (; idx < total_vecs; idx += (int64_t)gridDim.x * blockDim.x) { - int64_t global_idx = (int64_t)rank * chunk_numel + idx * 8; - uint64_t addr = multicast_base + global_idx * 2; - - uint32_t r0, r1, r2, r3; - // Perform an in-switch load-and-reduce from all peers matching this multicast address - asm volatile( - "multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "l"(addr) - : "memory"); - - uint4* out_ptr = reinterpret_cast(out); - out_ptr[idx] = make_uint4(r0, r1, r2, r3); - } -} - -__global__ void reducescatter_fallback_bf16_kernel( - const long long* __restrict__ ptrs, - __nv_bfloat16* __restrict__ out, - int64_t chunk_numel, - int rank, - int world_size -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < chunk_numel; idx += (int64_t)gridDim.x * blockDim.x) { - int64_t global_idx = rank * chunk_numel + idx; - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const __nv_bfloat16* src = (const __nv_bfloat16*)ptrs[r]; - sum += __bfloat162float(src[global_idx]); - } - out[idx] = __float2bfloat16(sum); - } -} - -__global__ void reducescatter_fallback_f32_kernel( - const long long* __restrict__ ptrs, - float* __restrict__ out, - int64_t chunk_numel, - int rank, - int world_size -) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < chunk_numel; idx += (int64_t)gridDim.x * blockDim.x) { - int64_t global_idx = rank * chunk_numel + idx; - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const float* src = (const float*)ptrs[r]; - sum += src[global_idx]; - } - out[idx] = sum; - } -} - -void launch_multimem_reducescatter_bf16( - uint64_t multicast_ptr, - torch::Tensor out, - int64_t chunk_numel, - int rank -) { - int threads = 256; - int64_t total_vecs = chunk_numel / 8; - int blocks = (total_vecs + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - if (blocks == 0) blocks = 1; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - reducescatter_multimem_bf16_kernel<<>>( - multicast_ptr, - reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), - chunk_numel, - rank - ); -} - -void launch_fallback_reducescatter( - torch::Tensor ptrs_tensor, - torch::Tensor out, - int64_t chunk_numel, - int rank, - int dtype_enum -) { - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - - int threads = 256; - int blocks = (chunk_numel + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - if (blocks == 0) blocks = 1; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (dtype_enum == 0) { - reducescatter_fallback_bf16_kernel<<>>( - d_ptrs, reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), chunk_numel, rank, world_size - ); - } else { - reducescatter_fallback_f32_kernel<<>>( - d_ptrs, out.data_ptr(), chunk_numel, rank, world_size - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_reducescatter_bf16", &launch_multimem_reducescatter_bf16, "Multimem reduce-scatter BF16"); - m.def("launch_fallback_reducescatter", &launch_fallback_reducescatter, "Fallback reduce-scatter via UVA"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("reducescatter_symm_ext", CUDA_SRC) - return _ext - -_resource_cache = {} - -def _get_resources(shape, dtype, device): - key = (shape, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - # Pre-allocate buffer via Symmetric Memory to keep working sets in UVA address spaces - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, ptrs_tensor) - _resource_cache[key] = res - return res - -@torch.no_grad() -def solution(tensor: torch.Tensor) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - - world_size = dist.get_world_size() - rank = dist.get_rank() - - assert tensor.shape[0] % world_size == 0, \ - f"First dimension ({tensor.shape[0]}) must be divisible by world_size ({world_size})" - - input_tensor = tensor.contiguous() - N = input_tensor.numel() - chunk_size = input_tensor.shape[0] // world_size - - # Fast paths for zero element arrays - if N == 0: - out_shape = (chunk_size,) + input_tensor.shape[1:] - return torch.empty(out_shape, dtype=input_tensor.dtype, device=input_tensor.device) - - chunk_numel = N // world_size - buf, hdl, ptrs_tensor = _get_resources(input_tensor.shape, input_tensor.dtype, input_tensor.device) - - # Coordinate symmetric streams: ensure no peers are overwriting the buffer locally or via UVA - hdl.barrier(channel=0) - buf.copy_(input_tensor) - - # Ensure all ranks have mapped their latest slices into symmetric memory - hdl.barrier(channel=1) - - out_shape = (chunk_size,) + input_tensor.shape[1:] - out = torch.empty(out_shape, dtype=input_tensor.dtype, device=input_tensor.device) - - multicast_ptr = getattr(hdl, "multicast_ptr", 0) - use_multimem = ( - input_tensor.dtype == torch.bfloat16 - and chunk_numel % 8 == 0 - and multicast_ptr != 0 - ) - - # Perform chunked local reduction with Hopper Multimem or robust fallback loop natively on-device - if use_multimem: - _get_ext().launch_multimem_reducescatter_bf16(multicast_ptr, out, chunk_numel, rank) - else: - dtype_enum = 0 if input_tensor.dtype == torch.bfloat16 else 1 - _get_ext().launch_fallback_reducescatter(ptrs_tensor, out, chunk_numel, rank, dtype_enum) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/80_dinov2_distributed_knn_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/80_dinov2_distributed_knn_cuda.py deleted file mode 100755 index 89350de..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/80_dinov2_distributed_knn_cuda.py +++ /dev/null @@ -1,249 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional, Tuple - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include - -template -__global__ void gather_peer_queries_kernel( - const scalar_t* __restrict__ peer_ptr, - scalar_t* __restrict__ local_buf, - int64_t numel -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int64_t i = idx; i < numel; i += blockDim.x * gridDim.x) { - local_buf[i] = peer_ptr[i]; - } -} - -template -__global__ void scatter_peer_topk_kernel( - const scalar_t* __restrict__ local_sims, - const int64_t* __restrict__ local_labels, - scalar_t* __restrict__ peer_sims_buf, - int64_t* __restrict__ peer_labels_buf, - int64_t Q_peer, - int64_t K, - int my_rank, - int64_t max_Q -) { - int64_t numel = Q_peer * K; - int64_t threads = blockDim.x * gridDim.x; - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - - for (int64_t i = idx; i < numel; i += threads) { - int64_t row = i / K; - int64_t col = i % K; - - int64_t dst_idx = (my_rank * max_Q + row) * K + col; - peer_sims_buf[dst_idx] = local_sims[i]; - peer_labels_buf[dst_idx] = local_labels[i]; - } -} - -void gather_peer_queries( - int64_t peer_ptr_val, - torch::Tensor local_buf, - int64_t numel -) { - if (numel == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int threads = 256; - int blocks = std::min((int64_t)65535, (numel + threads - 1) / threads); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, local_buf.scalar_type(), "gather_peer", [&] { - const scalar_t* peer_ptr = reinterpret_cast(peer_ptr_val); - gather_peer_queries_kernel<<>>( - peer_ptr, - local_buf.data_ptr(), - numel - ); - }); -} - -void scatter_peer_topk( - torch::Tensor local_sims, - torch::Tensor local_labels, - int64_t peer_sim_ptr_val, - int64_t peer_label_ptr_val, - int64_t Q_peer, - int64_t K, - int my_rank, - int64_t max_Q -) { - if (Q_peer == 0 || K == 0) return; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - int64_t numel = Q_peer * K; - int threads = 256; - int blocks = std::min((int64_t)65535, (numel + threads - 1) / threads); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, local_sims.scalar_type(), "scatter_peer", [&] { - scalar_t* peer_sims_buf = reinterpret_cast(peer_sim_ptr_val); - int64_t* peer_labels_buf = reinterpret_cast(peer_label_ptr_val); - - scatter_peer_topk_kernel<<>>( - local_sims.data_ptr(), - local_labels.data_ptr(), - peer_sims_buf, - peer_labels_buf, - Q_peer, - K, - my_rank, - max_Q - ); - }); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gather_peer_queries", &gather_peer_queries, "Gather peer queries"); - m.def("scatter_peer_topk", &scatter_peer_topk, "Scatter peer topk"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("dinov2_knn_overlap", CUDA_SRC) - return _ext - -_symm_cache = None -def _get_symm_state(max_Q, D, max_k, w, dtype_q, dtype_l, device, group): - global _symm_cache - if _symm_cache is not None: - c = _symm_cache - if (c["max_Q"] >= max_Q and c["D"] == D and - c["max_k"] == max_k and c["w"] == w and - c["dtype_q"] == dtype_q and c["dtype_l"] == dtype_l and - c["group"] == group): - return c["sq"], c["hq"], c["ss"], c["hs"], c["sl"], c["hl"], c["max_Q"] - - alloc_max_Q = max(max_Q, 1) - sq = symm_mem.empty((alloc_max_Q, D), dtype=dtype_q, device=device) - hq = symm_mem.rendezvous(sq, group) - - ss = symm_mem.empty((w, alloc_max_Q, max_k), dtype=dtype_q, device=device) - hs = symm_mem.rendezvous(ss, group) - - sl = symm_mem.empty((w, alloc_max_Q, max_k), dtype=dtype_l, device=device) - hl = symm_mem.rendezvous(sl, group) - - _symm_cache = { - "max_Q": alloc_max_Q, "D": D, "max_k": max_k, "w": w, - "dtype_q": dtype_q, "dtype_l": dtype_l, "group": group, - "sq": sq, "hq": hq, "ss": ss, "hs": hs, "sl": sl, "hl": hl - } - return sq, hq, ss, hs, sl, hl, alloc_max_Q - - -_streams = None -def _get_streams(world_size): - global _streams - if _streams is None or len(_streams) < world_size: - _streams = [torch.cuda.Stream() for _ in range(world_size)] - return _streams[:world_size] - - -@torch.no_grad() -def solution( - test_features_rank: torch.Tensor, - train_features_rank_T: torch.Tensor, - train_labels_rank: torch.Tensor, - max_k: int, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - - group = group or dist.group.WORLD - rank = dist.get_rank(group=group) - world_size = dist.get_world_size(group=group) - device = test_features_rank.device - - if max_k > train_features_rank_T.shape[1]: - raise ValueError("max_k must not exceed the local train shard size") - - # Serialize compilation so only rank 0 builds - if rank == 0: - _get_ext() - dist.barrier(group=group) - - Q_r = test_features_rank.shape[0] - D = test_features_rank.shape[1] - - # Quick All-gather to discover peers' query sizes - q_tensor = torch.tensor([Q_r], dtype=torch.int64, device=device) - q_sizes_list = [torch.empty_like(q_tensor) for _ in range(world_size)] - dist.all_gather(q_sizes_list, q_tensor, group=group) - q_sizes = torch.cat(q_sizes_list).cpu().tolist() - max_Q = max(q_sizes) - - # Allocate / Reuse symmetric memory - sq, hq, ss, hs, sl, hl, alloc_max_Q = _get_symm_state( - max_Q, D, max_k, world_size, - test_features_rank.dtype, train_labels_rank.dtype, - device, group - ) - - ptr_q = hq.buffer_ptrs - ptr_s = hs.buffer_ptrs - ptr_l = hl.buffer_ptrs - - # Expose local queries into symmetric memory - if Q_r > 0: - sq[:Q_r, :].copy_(test_features_rank) - dist.barrier(group=group) - - current_stream = torch.cuda.current_stream() - streams = _get_streams(world_size) - - # Launch fully pipelined peer GEMMs and Top-K using separate streams - for peer in range(world_size): - Q_peer = q_sizes[peer] - stream = streams[peer] - stream.wait_stream(current_stream) - - with torch.cuda.stream(stream): - if Q_peer > 0: - local_peer_queries = torch.empty((Q_peer, D), dtype=test_features_rank.dtype, device=device) - - _get_ext().gather_peer_queries(int(ptr_q[peer]), local_peer_queries, Q_peer * D) - peer_sims = torch.matmul(local_peer_queries, train_features_rank_T) - - peer_topk_sims, indices = peer_sims.topk(max_k, dim=1, largest=True, sorted=True) - peer_topk_labels = torch.gather(train_labels_rank.expand(Q_peer, -1), 1, indices) - - _get_ext().scatter_peer_topk( - peer_topk_sims, peer_topk_labels, - int(ptr_s[peer]), int(ptr_l[peer]), - Q_peer, max_k, rank, alloc_max_Q - ) - - # Re-join concurrent streams before final Top-K - for stream in streams[:world_size]: - current_stream.wait_stream(stream) - - dist.barrier(group=group) - - # Final Merge phase: Top-K across gathered metrics - if Q_r == 0: - return torch.empty((0, max_k), dtype=test_features_rank.dtype, device=device), \ - torch.empty((0, max_k), dtype=train_labels_rank.dtype, device=device) - - # Sub-slice perfectly bound items written by peers into our memory - valid_sims = ss[:, :Q_r, :] - valid_labels = sl[:, :Q_r, :] - - # Concatenate W sets of Top-K outcomes logically equivalent to all peers' contributions - valid_sims = valid_sims.permute(1, 0, 2).reshape(Q_r, world_size * max_k) - valid_labels = valid_labels.permute(1, 0, 2).reshape(Q_r, world_size * max_k) - - final_topk_sims, final_indices = valid_sims.topk(max_k, dim=1, largest=True, sorted=True) - final_topk_labels = torch.gather(valid_labels, 1, final_indices) - - return final_topk_sims, final_topk_labels \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/81_dinov2_distributed_sinkhorn_knopp_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/81_dinov2_distributed_sinkhorn_knopp_cuda.py deleted file mode 100755 index a47fb7b..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/81_dinov2_distributed_sinkhorn_knopp_cuda.py +++ /dev/null @@ -1,404 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -template -__device__ __forceinline__ float to_float(T x) { return static_cast(x); } - -template <> -__device__ __forceinline__ float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } - -// Compute initial row sums and pack the scalar total_batch into the end of the buffer -template -__global__ void compute_U_init_kernel( - const T* __restrict__ T_mat, - float* __restrict__ local_U, - int B, int K, float tau_inv, - const float* __restrict__ n_masked_patches -) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < K) { - float sum = 0.0f; - for (int j = 0; j < B; ++j) { - float val = to_float(T_mat[j * K + i]); - sum += expf(val * tau_inv); - } - local_U[i] = sum * (float)K; - } - // Thread 0 seamlessly packs the local masked patches for fused all-reduce - if (i == 0 && n_masked_patches != nullptr) { - local_U[K] = *n_masked_patches; - } -} - -// Compute column scales V using blocked warp reductions -template -__global__ void compute_V_kernel( - const T* __restrict__ T_mat, - const float* __restrict__ U, - float* __restrict__ V, - int B, int K, float tau_inv, - const float* __restrict__ total_batch_ptr, int batch_offset -) { - int j = blockIdx.x; - float total_batch = total_batch_ptr[batch_offset]; - - float sum = 0.0f; - for (int i = threadIdx.x; i < K; i += blockDim.x) { - float val = to_float(T_mat[j * K + i]); - sum += expf(val * tau_inv) / U[i]; - } - - static __shared__ float shared[32]; - int lane = threadIdx.x % 32; - int warpId = threadIdx.x / 32; - - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - - if (lane == 0) shared[warpId] = sum; - __syncthreads(); - - if (warpId == 0) { - sum = (lane < (blockDim.x / 32)) ? shared[lane] : 0.0f; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - if (lane == 0) { - V[j] = sum * total_batch; - } - } -} - -// Update row scales U using coalesced global reads -template -__global__ void compute_U_kernel( - const T* __restrict__ T_mat, - const float* __restrict__ V, - float* __restrict__ local_U, - int B, int K, float tau_inv -) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < K) { - float sum = 0.0f; - for (int j = 0; j < B; ++j) { - float val = to_float(T_mat[j * K + i]); - sum += expf(val * tau_inv) / V[j]; - } - local_U[i] = sum * (float)K; - } -} - -// Resolve final matrix out-of-place -template -__global__ void compute_Final_kernel( - const T* __restrict__ T_mat, - const float* __restrict__ U, - const float* __restrict__ V, - float* __restrict__ Out, - int B, int K, float tau_inv, - const float* __restrict__ total_batch_ptr, int batch_offset -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < B * K) { - int j = idx / K; - int i = idx % K; - float val = to_float(T_mat[idx]); - float total_batch = total_batch_ptr[batch_offset]; - Out[idx] = expf(val * tau_inv) * total_batch / (U[i] * V[j]); - } -} - -// --------------------------------------------------------------------------- -// Acquire-Release Device Barrier Logic -// --------------------------------------------------------------------------- -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile( - "atom.global.acquire.sys.cas.b32 %0, [%1], 1, 0;" - : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 1u); -} - -__device__ void blockwise_barrier( - const uint64_t* __restrict__ signal_pad_ptrs, - uint64_t block_id, - int rank, - int world_size -) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) { - return; - } - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast( - remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast( - local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - - send_signal_acq_rel(send_addr); - wait_signal_acq_rel(wait_addr); -} - -__global__ void allreduce_f32_kernel( - const uint64_t* __restrict__ ptrs, - float* __restrict__ out, - const uint64_t* __restrict__ signal_pad_ptrs, - int64_t n, - int world_size, - int rank, - int64_t iteration_offset -) { - const uint64_t base_block_id = static_cast(blockIdx.x) + iteration_offset; - - __syncthreads(); // Ensure peers sync locally - blockwise_barrier(signal_pad_ptrs, base_block_id, rank, world_size); - __syncthreads(); // Wait for barrier subset to unlock blocks - - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const float* src = reinterpret_cast(ptrs[r]); - sum += src[idx]; - } - out[idx] = sum; - } - - // Fence to avoid peer buffer overwrite before collective read completes - __syncthreads(); - blockwise_barrier(signal_pad_ptrs, base_block_id + gridDim.x, rank, world_size); -} - -// --------------------------------------------------------------------------- -// Launchers -// --------------------------------------------------------------------------- -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -void launch_compute_U_init( - torch::Tensor T_mat, torch::Tensor local_U, - int B, int K, float tau_inv, torch::Tensor n_masked_patches -) { - CHECK_INPUT(T_mat); CHECK_INPUT(local_U); CHECK_INPUT(n_masked_patches); - int threads = 256; - int blocks = (K + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const float* n_ptr = n_masked_patches.data_ptr(); - - if (T_mat.dtype() == torch::kBFloat16) { - compute_U_init_kernel<<>>( - reinterpret_cast(T_mat.data_ptr()), - local_U.data_ptr(), B, K, tau_inv, n_ptr); - } else { - compute_U_init_kernel<<>>( - T_mat.data_ptr(), local_U.data_ptr(), B, K, tau_inv, n_ptr); - } -} - -void launch_compute_V( - torch::Tensor T_mat, torch::Tensor U, torch::Tensor V, - int B, int K, float tau_inv, torch::Tensor total_batch_tensor, int batch_offset -) { - int threads = 256; - int blocks = B; // Perfect mapping row to block - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (T_mat.dtype() == torch::kBFloat16) { - compute_V_kernel<<>>( - reinterpret_cast(T_mat.data_ptr()), - U.data_ptr(), V.data_ptr(), B, K, tau_inv, - total_batch_tensor.data_ptr(), batch_offset); - } else { - compute_V_kernel<<>>( - T_mat.data_ptr(), U.data_ptr(), V.data_ptr(), - B, K, tau_inv, total_batch_tensor.data_ptr(), batch_offset); - } -} - -void launch_compute_U( - torch::Tensor T_mat, torch::Tensor V, torch::Tensor local_U, - int B, int K, float tau_inv -) { - int threads = 256; - int blocks = (K + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (T_mat.dtype() == torch::kBFloat16) { - compute_U_kernel<<>>( - reinterpret_cast(T_mat.data_ptr()), - V.data_ptr(), local_U.data_ptr(), B, K, tau_inv); - } else { - compute_U_kernel<<>>( - T_mat.data_ptr(), V.data_ptr(), local_U.data_ptr(), - B, K, tau_inv); - } -} - -void launch_compute_Final( - torch::Tensor T_mat, torch::Tensor U, torch::Tensor V, torch::Tensor Out, - int B, int K, float tau_inv, torch::Tensor total_batch_tensor, int batch_offset -) { - int threads = 256; - int blocks = (B * K + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (T_mat.dtype() == torch::kBFloat16) { - compute_Final_kernel<<>>( - reinterpret_cast(T_mat.data_ptr()), - U.data_ptr(), V.data_ptr(), Out.data_ptr(), - B, K, tau_inv, total_batch_tensor.data_ptr(), batch_offset); - } else { - compute_Final_kernel<<>>( - T_mat.data_ptr(), U.data_ptr(), V.data_ptr(), Out.data_ptr(), - B, K, tau_inv, total_batch_tensor.data_ptr(), batch_offset); - } -} - -void launch_allreduce_f32( - torch::Tensor ptrs_tensor, torch::Tensor out, torch::Tensor signal_pad_ptrs_tensor, - int64_t n, int world_size, int rank, int64_t iteration_offset -) { - int threads = 256; - int blocks = std::min(1024, (int)((n + threads - 1) / threads)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const uint64_t* d_ptrs = reinterpret_cast(ptrs_tensor.data_ptr()); - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - - allreduce_f32_kernel<<>>( - d_ptrs, out.data_ptr(), d_signal, n, world_size, rank, iteration_offset); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_compute_U_init", &launch_compute_U_init); - m.def("launch_compute_V", &launch_compute_V); - m.def("launch_compute_U", &launch_compute_U); - m.def("launch_compute_Final", &launch_compute_Final); - m.def("launch_allreduce_f32", &launch_allreduce_f32); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("dinov2_sinkhorn_knopp_ext", CUDA_SRC) - return _ext - -_resource_cache = {} - -def _get_resources(n: int, dtype: torch.dtype, device: torch.device, group: dist.ProcessGroup): - key = (n, dtype, device, group) - if key in _resource_cache: - return _resource_cache[key] - - buf = symm_mem.empty(n, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - - out = torch.empty(n, device=device, dtype=dtype) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, out, ptrs_tensor) - _resource_cache[key] = res - return res - - -@torch.no_grad() -def solution( - teacher_output: torch.Tensor, - teacher_temp: float, - n_masked_patches_tensor: torch.Tensor, - n_iterations: int = 3, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - - # Input invariants - B, K = teacher_output.shape - tau_inv = 1.0 / teacher_temp - device = teacher_output.device - ext = _get_ext() - teacher_output = teacher_output.contiguous() - - if n_masked_patches_tensor.dtype != torch.float32: - n_masked_patches_tensor = n_masked_patches_tensor.float() - - # K + 1 reserves a slot for parallel scalar reduction of batch patches in the identical buffer - symm_buf, hdl, global_U, ptrs_tensor = _get_resources(K + 1, torch.float32, device, group) - - V = torch.empty(B, dtype=torch.float32, device=device) - out = torch.empty((B, K), dtype=torch.float32, device=device) - - # Sync block execution before altering local mapped peers - dist.barrier(group=group) - - ext.launch_compute_U_init( - teacher_output, symm_buf, B, K, tau_inv, n_masked_patches_tensor - ) - - barrier_id = 0 - threads = 256 - blocks_allreduce_K1 = min(1024, (K + 1 + threads - 1) // threads) - - # Coalesced collective communication mapping the global mass - ext.launch_allreduce_f32( - ptrs_tensor, global_U, hdl.signal_pad_ptrs_dev, K + 1, - world_size, rank, barrier_id - ) - barrier_id += 2 * blocks_allreduce_K1 - blocks_allreduce_K = min(1024, (K + threads - 1) // threads) - - # Unrolled analytic matrix iterations (updating projection vectors inplace) - for _ in range(n_iterations - 1): - ext.launch_compute_V( - teacher_output, global_U, V, B, K, tau_inv, global_U, K - ) - ext.launch_compute_U( - teacher_output, V, symm_buf, B, K, tau_inv - ) - ext.launch_allreduce_f32( - ptrs_tensor, global_U, hdl.signal_pad_ptrs_dev, K, - world_size, rank, barrier_id - ) - barrier_id += 2 * blocks_allreduce_K - - ext.launch_compute_V( - teacher_output, global_U, V, B, K, tau_inv, global_U, K - ) - - # Final matrix assembly mapped back to transposed layout - ext.launch_compute_Final( - teacher_output, global_U, V, out, B, K, tau_inv, global_U, K - ) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/82_sam3_allgather_iou_suppression_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/82_sam3_allgather_iou_suppression_cuda.py deleted file mode 100755 index e95458c..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/82_sam3_allgather_iou_suppression_cuda.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -from typing import List, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - -_NO_OBJ_LOGIT = -10.0 - -CUDA_SRC = r''' -#include -#include -#include -#include - -// Fused gather, binarization, and area accumulation. -// Reads bfloat16 patches from peer symmetric memory pointers, converts to f32, and extracts area. -__global__ void fetch_and_prep_kernel( - const long long* __restrict__ ptrs_masks, - const long long* __restrict__ ptrs_scores, - const int* __restrict__ offsets, - float* __restrict__ masks_global, - float* __restrict__ scores_global, - float* __restrict__ binary_masks_flat, - float* __restrict__ areas, - int N_total, - int H_W, - int world_size -) { - int i = blockIdx.x; // object index - if (i >= N_total) return; - - // Find rank owner of the global object i - int j = 0; - while (j < world_size - 1 && i >= offsets[j+1]) { - j++; - } - int local_i = i - offsets[j]; - - const __nv_bfloat16* src_mask = (const __nv_bfloat16*)ptrs_masks[j] + local_i * H_W; - float* dst_mask = masks_global + i * H_W; - float* dst_bin = binary_masks_flat + i * H_W; - - int tid = threadIdx.x; - int stride = blockDim.x; - - float local_area = 0.0f; - - // Safely vectorize loads if alignment matches - bool can_vectorize = (((uintptr_t)src_mask) % 16 == 0) && (H_W % 8 == 0); - - if (can_vectorize) { - int num_vec = H_W / 8; - for (int k = tid; k < num_vec; k += stride) { - ulonglong2 vec = *(const ulonglong2*)(src_mask + k * 8); - const __nv_bfloat16* vals = (const __nv_bfloat16*)&vec; - - #pragma unroll - for (int v = 0; v < 8; v++) { - float val_f32 = __bfloat162float(vals[v]); - dst_mask[k * 8 + v] = val_f32; - float bin_val = val_f32 > 0.0f ? 1.0f : 0.0f; - dst_bin[k * 8 + v] = bin_val; - local_area += bin_val; - } - } - for (int k = num_vec * 8 + tid; k < H_W; k += stride) { - float val_f32 = __bfloat162float(src_mask[k]); - dst_mask[k] = val_f32; - float bin_val = val_f32 > 0.0f ? 1.0f : 0.0f; - dst_bin[k] = bin_val; - local_area += bin_val; - } - } else { - for (int k = tid; k < H_W; k += stride) { - float val_f32 = __bfloat162float(src_mask[k]); - dst_mask[k] = val_f32; - float bin_val = val_f32 > 0.0f ? 1.0f : 0.0f; - dst_bin[k] = bin_val; - local_area += bin_val; - } - } - - // Shared-memory block reduction to aggregate object area - static __shared__ float shared_area[256]; - shared_area[tid] = local_area; - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_area[tid] += shared_area[tid + s]; - } - __syncthreads(); - } - - // Assign area and load matching score sequentially - if (tid == 0) { - areas[i] = shared_area[0]; - const __nv_bfloat16* src_score = (const __nv_bfloat16*)ptrs_scores[j] + local_i; - scores_global[i] = __bfloat162float(*src_score); - } -} - -// Single-pass symmetric compute for boolean suppressions. Overlaps logic mask application. -__global__ void compute_and_apply_suppress_kernel( - const float* __restrict__ intersection, - const float* __restrict__ areas, - const int64_t* __restrict__ last_occluded, - bool* __restrict__ to_suppress_out, - float* __restrict__ masks_global, - float iou_threshold, - bool reverse, - int N_total, - int H_W, - float no_obj_logit -) { - int k = blockIdx.x; - if (k >= N_total) return; - - __shared__ bool suppress; - if (threadIdx.x == 0) { - suppress = false; - int64_t last_k = last_occluded[k]; - - for (int other = 0; other < N_total; other++) { - if (k == other) continue; - - int i = k < other ? k : other; - int j = k < other ? other : k; - - float inter = intersection[i * N_total + j]; - float union_area = areas[i] + areas[j] - inter; - if (union_area < 1.0f) union_area = 1.0f; - float iou = inter / union_area; - - if (iou >= iou_threshold) { - int64_t last_other = last_occluded[other]; - bool cmp = reverse ? (last_k < last_other) : (last_k > last_other); - if (cmp && last_other > -1) { - suppress = true; - break; - } - } - } - to_suppress_out[k] = suppress; - } - - __syncthreads(); - - // Mask with No-Object Logit in place across local threads if suppressed - if (suppress) { - for (int p = threadIdx.x; p < H_W; p += blockDim.x) { - masks_global[k * H_W + p] = no_obj_logit; - } - } -} - -void fetch_and_prep( - torch::Tensor ptrs_masks, - torch::Tensor ptrs_scores, - torch::Tensor offsets, - torch::Tensor masks_global, - torch::Tensor scores_global, - torch::Tensor binary_masks_flat, - torch::Tensor areas, - int N_total, - int H_W, - int world_size -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (N_total > 0) { - fetch_and_prep_kernel<<>>( - ptrs_masks.data_ptr(), - ptrs_scores.data_ptr(), - offsets.data_ptr(), - masks_global.data_ptr(), - scores_global.data_ptr(), - binary_masks_flat.data_ptr(), - areas.data_ptr(), - N_total, - H_W, - world_size - ); - } -} - -void compute_and_suppress( - torch::Tensor intersection, - torch::Tensor areas, - torch::Tensor last_occluded, - torch::Tensor to_suppress_out, - torch::Tensor masks_global, - float iou_threshold, - bool reverse, - int N_total, - int H_W, - float no_obj_logit -) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (N_total > 0) { - compute_and_apply_suppress_kernel<<>>( - intersection.data_ptr(), - areas.data_ptr(), - last_occluded.data_ptr(), - to_suppress_out.data_ptr(), - masks_global.data_ptr(), - iou_threshold, - reverse, - N_total, - H_W, - no_obj_logit - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fetch_and_prep", &fetch_and_prep, "Fetch and prepare inputs"); - m.def("compute_and_suppress", &compute_and_suppress, "Compute IoU logic and suppress"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("sam3_iou_suppress_ext", CUDA_SRC) - return _ext - - -_symm_cache = {} - -def _get_symm_state(max_N_local, H, W, device, group_id, group): - key = (max_N_local, H, W, device, group_id) - if key in _symm_cache: - return _symm_cache[key] - - # BF16 slices to halve the UVA symmetric bandwidth constraints over the fast NVLink loop - masks_symm = symm_mem.empty((max_N_local, H, W), dtype=torch.bfloat16, device=device) - scores_symm = symm_mem.empty((max_N_local,), dtype=torch.bfloat16, device=device) - - hdl_masks = symm_mem.rendezvous(masks_symm, group=group) - hdl_scores = symm_mem.rendezvous(scores_symm, group=group) - - ptrs_masks_tensor = torch.tensor(hdl_masks.buffer_ptrs, dtype=torch.int64, device=device) - ptrs_scores_tensor = torch.tensor(hdl_scores.buffer_ptrs, dtype=torch.int64, device=device) - - res = (masks_symm, scores_symm, hdl_masks, hdl_scores, ptrs_masks_tensor, ptrs_scores_tensor) - _symm_cache[key] = res - return res - - -@torch.no_grad() -def solution( - low_res_masks_local: torch.Tensor, - obj_scores_local: torch.Tensor, - num_obj_per_gpu: List[int], - last_occluded: torch.Tensor, - iou_threshold: float = 0.7, - reverse: bool = False, - group: Optional[dist.ProcessGroup] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - group = group or dist.group.WORLD - rank = dist.get_rank(group=group) - world_size = dist.get_world_size(group=group) - device = low_res_masks_local.device - - if rank == 0: - _get_ext() - dist.barrier(group=group) - - expected = int(num_obj_per_gpu[rank]) - if low_res_masks_local.shape[0] != expected: - raise ValueError("local mask count does not match num_obj_per_gpu") - - H, W = low_res_masks_local.shape[1:] - H_W = H * W - N_total = sum(num_obj_per_gpu) - max_N_local = max(num_obj_per_gpu) if world_size > 0 else 0 - - if N_total == 0: - return ( - torch.empty((0, H, W), dtype=torch.float32, device=device), - torch.empty((0,), dtype=torch.float32, device=device), - torch.empty((0,), dtype=torch.bool, device=device) - ) - - masks_symm, scores_symm, hdl_masks, hdl_scores, ptrs_masks, ptrs_scores = _get_symm_state( - max_N_local, H, W, device, id(group), group - ) - - if expected > 0: - masks_symm[:expected].copy_(low_res_masks_local) - scores_symm[:expected].copy_(obj_scores_local) - - # Safe asynchronous block waiting for peer stream writes into symmetrical UVA buffers - hdl_masks.barrier(channel=0) - - masks_global = torch.empty((N_total, H, W), dtype=torch.float32, device=device) - scores_global = torch.empty((N_total,), dtype=torch.float32, device=device) - binary_masks_flat = torch.empty((N_total, H_W), dtype=torch.float32, device=device) - areas = torch.empty((N_total,), dtype=torch.float32, device=device) - to_suppress = torch.zeros((N_total,), dtype=torch.bool, device=device) - - offsets = [0] * (world_size + 1) - for i in range(world_size): - offsets[i+1] = offsets[i] + num_obj_per_gpu[i] - offsets_tensor = torch.tensor(offsets, dtype=torch.int32, device=device) - - last_occluded = last_occluded.to(device=device, dtype=torch.int64) - - # Fast UVA read patch overlaps cleanly with format conversion to f32 binarization properties - _get_ext().fetch_and_prep( - ptrs_masks, ptrs_scores, offsets_tensor, - masks_global, scores_global, binary_masks_flat, areas, - N_total, H_W, world_size - ) - - if N_total > 1: - # Standard highly optimized cublas handles precision dense pairwise comparisons fast-path - intersection = torch.mm(binary_masks_flat, binary_masks_flat.t()) - - # Inline checks bounding intersections alongside suppression logic assignments - _get_ext().compute_and_suppress( - intersection, areas, last_occluded, to_suppress, masks_global, - float(iou_threshold), bool(reverse), N_total, H_W, _NO_OBJ_LOGIT - ) - - return masks_global, scores_global, to_suppress \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/83_vocab_parallel_log_prob_topk_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/83_vocab_parallel_log_prob_topk_cuda.py deleted file mode 100755 index 5c8fbe0..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/83_vocab_parallel_log_prob_topk_cuda.py +++ /dev/null @@ -1,410 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void gather_and_init_kernel_2d( - const long long* peer_ptrs_int64, - float* __restrict__ out_logits, - int* __restrict__ out_indices, - int local_vocab, - int world_size, - int rank, - int local_tokens, - int start_token, - int chunk_size -) { - int chunk_token_idx = blockIdx.x; - int vocab_idx = blockIdx.y * blockDim.x + threadIdx.x; - int V = world_size * local_vocab; - - if (vocab_idx < V) { - int peer = vocab_idx / local_vocab; - int peer_vocab_idx = vocab_idx % local_vocab; - - int global_token_idx = rank * local_tokens + start_token + chunk_token_idx; - - const __nv_bfloat16* peer_ptr = reinterpret_cast(peer_ptrs_int64[peer]); - __nv_bfloat16 val = peer_ptr[global_token_idx * local_vocab + peer_vocab_idx]; - - int out_idx = chunk_token_idx * V + vocab_idx; - out_logits[out_idx] = __bfloat162float(val); - out_indices[out_idx] = vocab_idx; - } -} - -__global__ void init_offsets_kernel(int* offsets, int num_segments, int V) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx <= num_segments) { - offsets[idx] = idx * V; - } -} - -__global__ void filter_and_logprob_kernel( - const float* __restrict__ sorted_logits, - const int* __restrict__ sorted_indices, - const int64_t* __restrict__ target_local, - float* __restrict__ out_logprobs, - int V, - int top_k, - float top_p -) { - int token_idx = blockIdx.x; - int tid = threadIdx.x; - - const float* token_logits = sorted_logits + token_idx * V; - const int* token_indices = sorted_indices + token_idx * V; - int target = target_local[token_idx]; - - typedef cub::BlockReduce BlockReduce; - typedef cub::BlockScan BlockScan; - __shared__ union { - typename BlockReduce::TempStorage reduce; - typename BlockScan::TempStorage scan; - } temp_storage; - - __shared__ float s_max_val; - __shared__ float s_threshold_k; - __shared__ float s_block_sum; - __shared__ float s_scan_carry; - __shared__ float s_target_val; - __shared__ bool s_target_kept; - - if (tid == 0) { - s_max_val = token_logits[V - 1]; // Sorted ascending, max is at the end - s_threshold_k = -INFINITY; - if (top_k > 0 && top_k <= V) { - s_threshold_k = token_logits[V - top_k]; - } - s_scan_carry = 0; - s_target_val = -INFINITY; - s_target_kept = false; - } - __syncthreads(); - - float max_val = s_max_val; - float threshold_k = s_threshold_k; - - // Pass 1: Sum exponentials for top-k elements - float thread_sum = 0; - for (int i = tid; i < V; i += blockDim.x) { - float val = token_logits[i]; - if (val >= threshold_k) { - thread_sum += expf(val - max_val); - } - } - - float sum_exp = BlockReduce(temp_storage.reduce).Sum(thread_sum); - if (tid == 0) { - s_block_sum = sum_exp; - } - __syncthreads(); - - sum_exp = s_block_sum; - - // Pass 2: Block-level scan for top-p and final sum_exp - float thread_final_sum = 0; - - for (int chunk = 0; chunk < V; chunk += blockDim.x) { - int i = chunk + tid; - float val = -INFINITY; - float prob = 0; - bool valid_k = false; - int orig_idx = -1; - - if (i < V) { - val = token_logits[i]; - orig_idx = token_indices[i]; - if (val >= threshold_k) { - valid_k = true; - prob = expf(val - max_val) / sum_exp; - } - } - - float chunk_cumsum; - BlockScan(temp_storage.scan).InclusiveSum(prob, chunk_cumsum); - __syncthreads(); - - float global_cumsum = s_scan_carry + chunk_cumsum; - - if (i < V && valid_k) { - bool keep_p = true; - if (top_p < 1.0f) { - keep_p = (global_cumsum > 1.0f - top_p) || (i == V - 1); - } - if (keep_p) { - thread_final_sum += expf(val - max_val); - } - if (orig_idx == target) { - if (keep_p) { - s_target_val = val; - s_target_kept = true; - } - } - } - - if (tid == blockDim.x - 1) { - s_scan_carry += chunk_cumsum; - } - __syncthreads(); - } - - float final_sum_exp = BlockReduce(temp_storage.reduce).Sum(thread_final_sum); - __shared__ float s_final_sum_exp; - if (tid == 0) { - s_final_sum_exp = final_sum_exp; - } - __syncthreads(); - - if (tid == 0) { - if (s_target_kept) { - out_logprobs[token_idx] = s_target_val - (max_val + logf(s_final_sum_exp)); - } else { - out_logprobs[token_idx] = -INFINITY; - } - } -} - -__global__ void gather_logprobs_kernel( - const long long* peer_ptrs_int64, - float* __restrict__ out, - int local_tokens, - int world_size -) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = world_size * local_tokens; - if (idx < total) { - int peer = idx / local_tokens; - int token_idx = idx % local_tokens; - const float* peer_ptr = reinterpret_cast(peer_ptrs_int64[peer]); - out[idx] = peer_ptr[token_idx]; - } -} - -void gather_and_init( - torch::Tensor peer_ptrs, - torch::Tensor out_logits, - torch::Tensor out_indices, - int local_vocab, - int world_size, - int rank, - int local_tokens, - int start_token, - int chunk_size -) { - const long long* d_peers = (const long long*)peer_ptrs.data_ptr(); - float* d_out_logits = out_logits.data_ptr(); - int* d_out_indices = out_indices.data_ptr(); - - int V = world_size * local_vocab; - dim3 block(256); - dim3 grid(chunk_size, (V + 255)/256); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - gather_and_init_kernel_2d<<>>( - d_peers, d_out_logits, d_out_indices, local_vocab, world_size, - rank, local_tokens, start_token, chunk_size - ); -} - -void init_offsets(torch::Tensor offsets, int num_segments, int V) { - int* d_offsets = offsets.data_ptr(); - int threads = 256; - int blocks = (num_segments + 1 + threads - 1) / threads; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - init_offsets_kernel<<>>(d_offsets, num_segments, V); -} - -void cub_segmented_sort( - torch::Tensor keys_in, - torch::Tensor keys_out, - torch::Tensor vals_in, - torch::Tensor vals_out, - torch::Tensor offsets, - int num_segments -) { - size_t temp_storage_bytes = 0; - float* d_keys_in = keys_in.data_ptr(); - float* d_keys_out = keys_out.data_ptr(); - int* d_vals_in = vals_in.data_ptr(); - int* d_vals_out = vals_out.data_ptr(); - int* d_offsets = offsets.data_ptr(); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - // Dry-run (O(1) host overhead) - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, temp_storage_bytes, - d_keys_in, d_keys_out, d_vals_in, d_vals_out, - keys_in.numel(), num_segments, d_offsets, d_offsets + 1, - 0, sizeof(float)*8, stream - ); - - auto temp_storage = torch::empty({(long)temp_storage_bytes}, keys_in.options().dtype(torch::kUInt8)); - - cub::DeviceSegmentedRadixSort::SortPairs( - temp_storage.data_ptr(), temp_storage_bytes, - d_keys_in, d_keys_out, d_vals_in, d_vals_out, - keys_in.numel(), num_segments, d_offsets, d_offsets + 1, - 0, sizeof(float)*8, stream - ); -} - -void filter_and_logprob( - torch::Tensor sorted_logits, - torch::Tensor sorted_indices, - torch::Tensor target_local, - torch::Tensor out_logprobs, - int V, - int top_k, - float top_p -) { - int chunk_size = out_logprobs.numel(); - if (chunk_size == 0) return; - - dim3 block(512); - dim3 grid(chunk_size); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - filter_and_logprob_kernel<<>>( - sorted_logits.data_ptr(), - sorted_indices.data_ptr(), - target_local.data_ptr(), - out_logprobs.data_ptr(), - V, top_k, top_p - ); -} - -void gather_logprobs( - torch::Tensor peer_ptrs, - torch::Tensor out, - int local_tokens, - int world_size -) { - const long long* d_peers = (const long long*)peer_ptrs.data_ptr(); - float* d_out = out.data_ptr(); - - int total = world_size * local_tokens; - int threads = 256; - int blocks = (total + threads - 1) / threads; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - gather_logprobs_kernel<<>>( - d_peers, d_out, local_tokens, world_size - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gather_and_init", &gather_and_init); - m.def("init_offsets", &init_offsets); - m.def("cub_segmented_sort", &cub_segmented_sort); - m.def("filter_and_logprob", &filter_and_logprob); - m.def("gather_logprobs", &gather_logprobs); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("vp_logprob_topk_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_state(name, shape, dtype, device, group): - key = (name, tuple(shape), dtype, device) - if key in _symm_cache: - return _symm_cache[key] - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - ptrs = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - res = (buf, hdl, ptrs) - _symm_cache[key] = res - return res - -@torch.no_grad() -def solution( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - tp_group: Optional[dist.ProcessGroup] = None, - top_k: Optional[int] = None, - top_p: float = 1.0, -) -> torch.Tensor: - tp_group = tp_group or dist.group.WORLD - world_size = dist.get_world_size(tp_group) - rank = dist.get_rank(tp_group) - - batch, seq_len, local_vocab = vocab_parallel_logits.shape - num_tokens = batch * seq_len - if num_tokens % world_size != 0: - raise ValueError(f"B*S={num_tokens} must be divisible by tensor parallel size {world_size}") - - local_tokens = num_tokens // world_size - V = world_size * local_vocab - device = vocab_parallel_logits.device - - target_local = target.reshape(-1)[rank * local_tokens : (rank + 1) * local_tokens].contiguous() - target_local = target_local.to(device=device, dtype=torch.int64) - - buf_logits, hdl_logits, ptrs_logits = _get_symm_state("logits", vocab_parallel_logits.shape, vocab_parallel_logits.dtype, device, tp_group) - buf_logprobs, hdl_logprobs, ptrs_logprobs = _get_symm_state("logprobs", (local_tokens,), torch.float32, device, tp_group) - - buf_logits.copy_(vocab_parallel_logits) - hdl_logits.barrier(channel=0) - - # 2-stage chunking over execution streams for overlapping execution & UVA reads - num_chunks = 2 - if local_tokens < num_chunks or local_tokens % num_chunks != 0: - num_chunks = 1 - chunk_size = local_tokens // num_chunks - - ext = _get_ext() - if not hasattr(ext, 'streams'): - ext.streams = [torch.cuda.Stream() for _ in range(num_chunks)] - streams = ext.streams[:num_chunks] - - local_logits_keys = torch.empty((local_tokens, V), device=device, dtype=torch.float32) - local_indices_vals = torch.empty((local_tokens, V), device=device, dtype=torch.int32) - sorted_keys = torch.empty_like(local_logits_keys) - sorted_vals = torch.empty_like(local_indices_vals) - - offsets = torch.empty(chunk_size + 1, device=device, dtype=torch.int32) - ext.init_offsets(offsets, chunk_size, V) - - default_stream = torch.cuda.current_stream() - - for i in range(num_chunks): - with torch.cuda.stream(streams[i]): - streams[i].wait_stream(default_stream) - start = i * chunk_size - end = start + chunk_size - - keys_in = local_logits_keys[start:end] - vals_in = local_indices_vals[start:end] - keys_out = sorted_keys[start:end] - vals_out = sorted_vals[start:end] - tgt_in = target_local[start:end] - lp_out = buf_logprobs[start:end] - - ext.gather_and_init(ptrs_logits, keys_in, vals_in, local_vocab, world_size, rank, local_tokens, start, chunk_size) - ext.cub_segmented_sort(keys_in, keys_out, vals_in, vals_out, offsets, chunk_size) - ext.filter_and_logprob(keys_out, vals_out, tgt_in, lp_out, V, top_k if top_k is not None else 0, top_p) - - for s in streams: - default_stream.wait_stream(s) - - hdl_logprobs.barrier(channel=0) - - out_logprobs = torch.empty(world_size * local_tokens, device=device, dtype=torch.float32) - ext.gather_logprobs(ptrs_logprobs, out_logprobs, local_tokens, world_size) - - return out_logprobs.reshape(batch, seq_len) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/84_vocab_parallel_log_prob_topk_chunked_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/84_vocab_parallel_log_prob_topk_chunked_cuda.py deleted file mode 100755 index 84d3f54..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/84_vocab_parallel_log_prob_topk_chunked_cuda.py +++ /dev/null @@ -1,354 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Kernel 1: Push logits directly to peers' symmetric memory (All-to-All + Permute) -// --------------------------------------------------------------------------- -__global__ void push_logits_kernel_vec( - const __nv_bfloat16* __restrict__ local_logits, - const uint64_t* __restrict__ recv_buf_ptrs, - int current_tokens, - int local_vocab, - int world_size, - int my_rank -) { - int local_tokens = current_tokens / world_size; - int total_sends = world_size * local_tokens; - - int send_idx = blockIdx.x; - if (send_idx >= total_sends) return; - - int dst_rank = send_idx / local_tokens; - int t_local = send_idx % local_tokens; - - int t_chunk = dst_rank * local_tokens + t_local; - - const __nv_bfloat16* src = local_logits + t_chunk * local_vocab; - // Destination offset inherently accomplishes the sequence interleave permute - __nv_bfloat16* dst = reinterpret_cast<__nv_bfloat16*>(recv_buf_ptrs[dst_rank]) - + t_local * (world_size * local_vocab) - + my_rank * local_vocab; - - if (local_vocab % 8 == 0) { - int num_vec = local_vocab / 8; - const uint4* src_vec = reinterpret_cast(src); - uint4* dst_vec = reinterpret_cast(dst); - for (int i = threadIdx.x; i < num_vec; i += blockDim.x) { - dst_vec[i] = src_vec[i]; - } - } else if (local_vocab % 4 == 0) { - int num_vec = local_vocab / 4; - const uint64_t* src_vec = reinterpret_cast(src); - uint64_t* dst_vec = reinterpret_cast(dst); - for (int i = threadIdx.x; i < num_vec; i += blockDim.x) { - dst_vec[i] = src_vec[i]; - } - } else { - for (int i = threadIdx.x; i < local_vocab; i += blockDim.x) { - dst[i] = src[i]; - } - } -} - -// --------------------------------------------------------------------------- -// Kernel 2: Fused log_softmax, target gather, and All-Gather push -// --------------------------------------------------------------------------- -__global__ void fused_log_softmax_gather_push_kernel( - const __nv_bfloat16* __restrict__ filtered_logits, - const int64_t* __restrict__ target_local, - const uint64_t* __restrict__ logprobs_buf_ptrs, - int local_tokens, - int full_vocab, - int world_size, - int my_rank -) { - int t_local = blockIdx.x; - if (t_local >= local_tokens) return; - - int64_t target_idx = target_local[t_local]; - const __nv_bfloat16* logits = filtered_logits + t_local * full_vocab; - - // 1. Thread-local max - float thread_max = -1e20f; - for (int i = threadIdx.x; i < full_vocab; i += blockDim.x) { - float val = __bfloat162float(logits[i]); - if (val > thread_max) thread_max = val; - } - - // 2. Block-wide max reduction - static __shared__ float shared_max[32]; - int lane = threadIdx.x % 32; - int warp = threadIdx.x / 32; - - float warp_max = thread_max; - for (int offset = 16; offset > 0; offset /= 2) { - warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset)); - } - if (lane == 0) shared_max[warp] = warp_max; - __syncthreads(); - - float block_max = -1e20f; - if (threadIdx.x < (blockDim.x / 32)) { - block_max = shared_max[threadIdx.x]; - } - for (int offset = 16; offset > 0; offset /= 2) { - block_max = fmaxf(block_max, __shfl_down_sync(0xffffffff, block_max, offset)); - } - if (threadIdx.x == 0) shared_max[0] = block_max; - __syncthreads(); - block_max = shared_max[0]; - - // 3. Thread-local sum and target extraction - float thread_sum = 0.0f; - float target_val = 0.0f; - for (int i = threadIdx.x; i < full_vocab; i += blockDim.x) { - float val = __bfloat162float(logits[i]); - if (i == target_idx) target_val = val; - thread_sum += expf(val - block_max); - } - - // 4. Block-wide sum reduction - static __shared__ float shared_sum[32]; - static __shared__ float shared_target[32]; - - float warp_sum = thread_sum; - float warp_target = target_val; - for (int offset = 16; offset > 0; offset /= 2) { - warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset); - warp_target += __shfl_down_sync(0xffffffff, warp_target, offset); - } - if (lane == 0) { - shared_sum[warp] = warp_sum; - shared_target[warp] = warp_target; - } - __syncthreads(); - - float block_sum = 0.0f; - float block_target = 0.0f; - if (threadIdx.x < (blockDim.x / 32)) { - block_sum = shared_sum[threadIdx.x]; - block_target = shared_target[threadIdx.x]; - } - for (int offset = 16; offset > 0; offset /= 2) { - block_sum += __shfl_down_sync(0xffffffff, block_sum, offset); - block_target += __shfl_down_sync(0xffffffff, block_target, offset); - } - if (threadIdx.x == 0) shared_sum[0] = block_sum; - if (threadIdx.x == 0) shared_target[0] = block_target; - __syncthreads(); - - block_sum = shared_sum[0]; - block_target = shared_target[0]; - - // 5. Final log-prob logic and P2P All-Gather - if (threadIdx.x == 0) { - float log_prob = block_target - (block_max + logf(block_sum)); - int global_token_idx = my_rank * local_tokens + t_local; - - for (int dst = 0; dst < world_size; ++dst) { - float* dst_ptr = reinterpret_cast(logprobs_buf_ptrs[dst]); - dst_ptr[global_token_idx] = log_prob; - } - } -} - -// --------------------------------------------------------------------------- -// Python Bindings -// --------------------------------------------------------------------------- -void launch_push_logits( - torch::Tensor local_logits, - torch::Tensor recv_buf_ptrs, - int current_tokens, - int local_vocab, - int world_size, - int my_rank -) { - int local_tokens = current_tokens / world_size; - int total_sends = world_size * local_tokens; - if (total_sends == 0) return; - - int threads = 256; - int blocks = total_sends; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - push_logits_kernel_vec<<>>( - reinterpret_cast(local_logits.data_ptr()), - reinterpret_cast(recv_buf_ptrs.data_ptr()), - current_tokens, - local_vocab, - world_size, - my_rank - ); -} - -void launch_fused_log_softmax_gather_push( - torch::Tensor filtered_logits, - torch::Tensor target_local, - torch::Tensor logprobs_buf_ptrs, - int local_tokens, - int full_vocab, - int world_size, - int my_rank -) { - if (local_tokens == 0) return; - - int threads = 256; - int blocks = local_tokens; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - fused_log_softmax_gather_push_kernel<<>>( - reinterpret_cast(filtered_logits.data_ptr()), - target_local.data_ptr(), - reinterpret_cast(logprobs_buf_ptrs.data_ptr()), - local_tokens, - full_vocab, - world_size, - my_rank - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_push_logits", &launch_push_logits); - m.def("launch_fused_log_softmax_gather_push", &launch_fused_log_softmax_gather_push); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("vocab_parallel_logprob_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def _get_symm_buffers(chunk_tokens, local_vocab, world_size, dtype, device, group): - key = (chunk_tokens, local_vocab, world_size, dtype, device) - if key in _symm_cache: - return _symm_cache[key] - - max_local_tokens = chunk_tokens // world_size - full_vocab = local_vocab * world_size - - recv_buf = symm_mem.empty((max_local_tokens, full_vocab), dtype=dtype, device=device) - recv_hdl = symm_mem.rendezvous(recv_buf, group=group) - recv_ptrs = torch.tensor(recv_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - logprobs_buf = symm_mem.empty((chunk_tokens,), dtype=torch.float32, device=device) - logprobs_hdl = symm_mem.rendezvous(logprobs_buf, group=group) - logprobs_ptrs = torch.tensor(logprobs_hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (recv_buf, recv_hdl, recv_ptrs, logprobs_buf, logprobs_hdl, logprobs_ptrs) - _symm_cache[key] = res - return res - -def _apply_top_k_top_p(logits: torch.Tensor, top_k: Optional[int], top_p: float) -> torch.Tensor: - need_k = top_k is not None and top_k > 0 - need_p = top_p is not None and top_p < 1.0 - if not need_k and not need_p: - return logits - - original_shape = logits.shape - vocab_size = logits.shape[-1] - logits_2d = logits.reshape(-1, vocab_size) - - if need_k: - top_k = min(int(top_k), vocab_size) - - if need_k and not need_p: - top_k_values, _ = torch.topk(logits_2d, top_k, dim=-1) - threshold = top_k_values[..., -1:].expand_as(logits_2d) - filtered = logits_2d.masked_fill(logits_2d < threshold, float("-inf")) - return filtered.reshape(original_shape) - - sorted_logits, sorted_idx = logits_2d.sort(dim=-1, descending=False) - - if need_k: - top_k_index = sorted_logits.shape[-1] - top_k - threshold = sorted_logits[..., top_k_index : top_k_index + 1] - sorted_logits = sorted_logits.masked_fill(sorted_logits < threshold, float("-inf")) - - sorted_probs = sorted_logits.softmax(dim=-1) - top_p_mask = torch.cumsum(sorted_probs, dim=-1) > 1 - top_p - top_p_mask[..., -1] = True - sorted_logits = sorted_logits.masked_fill(~top_p_mask, float("-inf")) - filtered = sorted_logits.scatter(dim=-1, index=sorted_idx, src=sorted_logits) - return filtered.reshape(original_shape) - - -@torch.no_grad() -def solution( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - tp_group: Optional[dist.ProcessGroup] = None, - top_k: Optional[int] = None, - top_p: float = 1.0, - chunk_size: int = 1, -) -> torch.Tensor: - tp_group = tp_group or dist.group.WORLD - world_size = dist.get_world_size(group=tp_group) - rank = dist.get_rank(group=tp_group) - batch, seq_len, local_vocab = vocab_parallel_logits.shape - - num_tokens = batch * seq_len - chunk_tokens = batch * max(1, int(chunk_size)) - - if num_tokens % world_size != 0: - raise ValueError(f"B*S={num_tokens} must be divisible by tensor parallel size {world_size}") - if chunk_tokens % world_size != 0: - raise ValueError(f"B*chunk_size={chunk_tokens} must be divisible by tp size {world_size}") - - ext = _get_ext() - - recv_buf, recv_hdl, recv_ptrs, logprobs_buf, logprobs_hdl, logprobs_ptrs = _get_symm_buffers( - chunk_tokens, local_vocab, world_size, vocab_parallel_logits.dtype, vocab_parallel_logits.device, tp_group - ) - - logits_2d = vocab_parallel_logits.reshape(num_tokens, local_vocab) - target_flat = target.reshape(-1) - - full_vocab = local_vocab * world_size - pieces = [] - - for start in range(0, num_tokens, chunk_tokens): - end = min(start + chunk_tokens, num_tokens) - current = end - start - local_tokens = current // world_size - - logits_chunk = logits_2d[start:end].contiguous() - target_chunk = target_flat[start:end] - target_local = target_chunk[rank * local_tokens : (rank + 1) * local_tokens].contiguous().long() - - # Step 1: Push symmetric UVA vocabulary transpose - ext.launch_push_logits( - logits_chunk, recv_ptrs, current, local_vocab, world_size, rank - ) - - # Synchronize streams across TP-domain before relying on fetched symmetric data - recv_hdl.barrier(channel=0) - - # Step 2: Extract locally active sequence chunk and apply PyTorch complex filtering - seq_logits = recv_buf[:local_tokens] - filtered = _apply_top_k_top_p(seq_logits, top_k=top_k, top_p=top_p) - - # Step 3: Fast fused reduction yielding target token prob -> broadcast directly to peer buffers - ext.launch_fused_log_softmax_gather_push( - filtered, target_local, logprobs_ptrs, local_tokens, full_vocab, world_size, rank - ) - - # Ensure block reductions are flushed globally safely - logprobs_hdl.barrier(channel=0) - - pieces.append(logprobs_buf[:current].clone()) - - return torch.cat(pieces, dim=0).reshape(batch, seq_len) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/85_vocab_parallel_log_prob_topk_chunked_backward_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/85_vocab_parallel_log_prob_topk_chunked_backward_cuda.py deleted file mode 100755 index 1792b2d..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/85_vocab_parallel_log_prob_topk_chunked_backward_cuda.py +++ /dev/null @@ -1,390 +0,0 @@ -""" -Strategy: -1. Replaced `torch.distributed.all_to_all_single` with direct peer-to-peer gathers and scatters using `torch.distributed._symmetric_memory` and NVLink. -2. Fused the gradient computation, target masking, and scatter phase into a single custom CUDA kernel (`fused_grad_scatter_kernel`). This avoids materializing and reading the large intermediate gradient sequence in device memory. -3. Implemented a chunked double-buffering pipeline with asynchronous CUDA streams and events. This hides the P2P communication and compute of the current chunk behind the D2D copying of the next/previous chunks. -4. Used vectorized memory accesses (`int4`) in the gather kernel when `local_vocab` is aligned, maximizing the Hopper architecture's memory bandwidth utilization. -""" - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional, Tuple - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void gather_vp_to_seq_kernel( - const int64_t* __restrict__ ptrs, - int64_t offset_elements, - at::BFloat16* __restrict__ out, - int local_tokens, - int local_vocab, - int world_size, - int rank -) { - int64_t total_elements = (int64_t)local_tokens * world_size * local_vocab; - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx < total_elements) { - int v = idx % local_vocab; - int p = (idx / local_vocab) % world_size; - int t = idx / (local_vocab * world_size); - - int peer_token_idx = rank * local_tokens + t; - int64_t peer_offset = peer_token_idx * local_vocab + v; - - const at::BFloat16* peer_ptr = (const at::BFloat16*)ptrs[p] + offset_elements; - out[idx] = peer_ptr[peer_offset]; - } -} - -__global__ void gather_vp_to_seq_kernel_vec8( - const int64_t* __restrict__ ptrs, - int64_t offset_elements, - at::BFloat16* __restrict__ out, - int local_tokens, - int local_vocab, - int world_size, - int rank -) { - int64_t total_vecs = ((int64_t)local_tokens * world_size * local_vocab) / 8; - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx < total_vecs) { - int vec_v = idx % (local_vocab / 8); - int p = (idx / (local_vocab / 8)) % world_size; - int t = idx / ((local_vocab / 8) * world_size); - - int peer_token_idx = rank * local_tokens + t; - int64_t peer_offset = peer_token_idx * (local_vocab / 8) + vec_v; - - const int4* peer_ptr = (const int4*)( ((const at::BFloat16*)ptrs[p]) + offset_elements ); - int4* out_ptr = (int4*)out; - - out_ptr[idx] = peer_ptr[peer_offset]; - } -} - -__global__ void fused_grad_scatter_kernel( - const float* __restrict__ probs, - const int64_t* __restrict__ target_local, - const at::BFloat16* __restrict__ grad_local, - const bool* __restrict__ keep_mask, - const int64_t* __restrict__ ptrs, - int64_t offset_elements, - int local_tokens, - int local_vocab, - int world_size, - int rank -) { - int64_t total_elements = (int64_t)local_tokens * world_size * local_vocab; - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx < total_elements) { - int v = idx % local_vocab; - int p = (idx / local_vocab) % world_size; - int t = idx / (local_vocab * world_size); - - int global_v = p * local_vocab + v; - - float p_val = probs[idx]; - float g = -p_val; - - int64_t target_idx = target_local[t]; - if (target_idx < 0) { - target_idx += (world_size * local_vocab); - } - - if (global_v == target_idx) { - g += 1.0f; - } - - float grad_out_val = __bfloat162float(grad_local[t]); - g *= grad_out_val; - - if (keep_mask != nullptr) { - if (!keep_mask[idx]) { - g = 0.0f; - } - } - - int peer_token_idx = rank * local_tokens + t; - int64_t peer_offset = peer_token_idx * local_vocab + v; - - at::BFloat16* peer_ptr = (at::BFloat16*)ptrs[p] + offset_elements; - peer_ptr[peer_offset] = __float2bfloat16(g); - } -} - -void launch_gather( - torch::Tensor ptrs, - int64_t offset_elements, - torch::Tensor out, - int local_tokens, - int local_vocab, - int world_size, - int rank -) { - int64_t total_elements = (int64_t)local_tokens * world_size * local_vocab; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (local_vocab % 8 == 0 && total_elements % 8 == 0) { - int threads = 256; - int blocks = (total_elements / 8 + threads - 1) / threads; - gather_vp_to_seq_kernel_vec8<<>>( - ptrs.data_ptr(), - offset_elements, - (at::BFloat16*)out.data_ptr(), - local_tokens, local_vocab, world_size, rank - ); - } else { - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - gather_vp_to_seq_kernel<<>>( - ptrs.data_ptr(), - offset_elements, - (at::BFloat16*)out.data_ptr(), - local_tokens, local_vocab, world_size, rank - ); - } -} - -void launch_fused_scatter( - torch::Tensor probs, - torch::Tensor target_local, - torch::Tensor grad_local, - std::optional keep_mask, - torch::Tensor ptrs, - int64_t offset_elements, - int local_tokens, - int local_vocab, - int world_size, - int rank -) { - int64_t total_elements = (int64_t)local_tokens * world_size * local_vocab; - int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const bool* mask_ptr = nullptr; - if (keep_mask.has_value()) { - mask_ptr = keep_mask.value().data_ptr(); - } - - fused_grad_scatter_kernel<<>>( - probs.data_ptr(), - target_local.data_ptr(), - (const at::BFloat16*)grad_local.data_ptr(), - mask_ptr, - ptrs.data_ptr(), - offset_elements, - local_tokens, local_vocab, world_size, rank - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_gather", &launch_gather); - m.def("launch_fused_scatter", &launch_fused_scatter); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("chunked_vp_backward_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_symm_buffers(chunk_tokens, local_vocab, dtype, device, group): - key = (chunk_tokens, local_vocab, dtype, device, group) - if key in _symm_cache: - return _symm_cache[key] - - symm_logits = symm_mem.empty((2, chunk_tokens, local_vocab), dtype=dtype, device=device) - hdl_logits = symm_mem.rendezvous(symm_logits, group=group) - logits_ptrs = torch.tensor(hdl_logits.buffer_ptrs, device=device, dtype=torch.int64) - - symm_grad = symm_mem.empty((2, chunk_tokens, local_vocab), dtype=dtype, device=device) - hdl_grad = symm_mem.rendezvous(symm_grad, group=group) - grad_ptrs = torch.tensor(hdl_grad.buffer_ptrs, device=device, dtype=torch.int64) - - res = (symm_logits, hdl_logits, logits_ptrs, symm_grad, hdl_grad, grad_ptrs) - _symm_cache[key] = res - return res - -_stream_cache = {} - -def _get_streams(device): - if device not in _stream_cache: - _stream_cache[device] = (torch.cuda.Stream(device=device), torch.cuda.Stream(device=device)) - return _stream_cache[device] - - -def _apply_top_k_top_p( - logits: torch.Tensor, - top_k: Optional[int], - top_p: float, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - need_k = top_k is not None and top_k > 0 - need_p = top_p is not None and top_p < 1.0 - if not need_k and not need_p: - return logits, None - - original_shape = logits.shape - vocab_size = logits.shape[-1] - logits_2d = logits.reshape(-1, vocab_size) - if need_k: - top_k = min(int(top_k), vocab_size) - - if need_k and not need_p: - top_k_values, _ = torch.topk(logits_2d, top_k, dim=-1) - threshold = top_k_values[..., -1:].expand_as(logits_2d) - keep_mask = logits_2d >= threshold - filtered = logits_2d.masked_fill(~keep_mask, float("-inf")) - return filtered.reshape(original_shape), keep_mask.reshape(original_shape) - - sorted_logits, sorted_idx = logits_2d.sort(dim=-1, descending=False) - top_k_mask = None - if need_k: - top_k_index = sorted_logits.shape[-1] - top_k - threshold = sorted_logits[..., top_k_index : top_k_index + 1] - top_k_mask = sorted_logits >= threshold - sorted_logits = sorted_logits.masked_fill(~top_k_mask, float("-inf")) - - sorted_probs = sorted_logits.softmax(dim=-1) - top_p_mask = torch.cumsum(sorted_probs, dim=-1) > 1 - top_p - top_p_mask[..., -1] = True - sorted_logits = sorted_logits.masked_fill(~top_p_mask, float("-inf")) - - keep_sorted = top_p_mask if top_k_mask is None else top_p_mask & top_k_mask - filtered = sorted_logits.scatter(dim=-1, index=sorted_idx, src=sorted_logits) - keep_mask = keep_sorted.scatter(dim=-1, index=sorted_idx, src=keep_sorted) - return filtered.reshape(original_shape), keep_mask.reshape(original_shape) - - -@torch.no_grad() -def solution( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - grad_output: torch.Tensor, - tp_group: Optional[dist.ProcessGroup] = None, - top_k: Optional[int] = None, - top_p: float = 1.0, - chunk_size: int = 1, -) -> torch.Tensor: - - tp_group = tp_group or dist.group.WORLD - world_size = dist.get_world_size(group=tp_group) - rank = dist.get_rank(group=tp_group) - - if rank == 0: - _get_ext() - dist.barrier(group=tp_group) - - batch, seq_len, local_vocab = vocab_parallel_logits.shape - num_tokens = batch * seq_len - chunk_tokens = batch * max(1, int(chunk_size)) - - if num_tokens % world_size != 0: - raise ValueError(f"B*S={num_tokens} must be divisible by tensor parallel size {world_size}") - if chunk_tokens % world_size != 0: - raise ValueError(f"B*chunk_size={chunk_tokens} must be divisible by tp size {world_size}") - - device = vocab_parallel_logits.device - dtype = vocab_parallel_logits.dtype - - logits_2d = vocab_parallel_logits.contiguous().reshape(num_tokens, local_vocab) - target_flat = target.contiguous().reshape(-1) - grad_flat = grad_output.contiguous().reshape(-1) - - grad_logits_2d = torch.empty_like(logits_2d) - - symm_logits, hdl_logits, logits_ptrs, symm_grad, hdl_grad, grad_ptrs = _get_symm_buffers( - chunk_tokens, local_vocab, dtype, device, tp_group - ) - - s_compute = torch.cuda.current_stream() - s_copy, s_out = _get_streams(device) - - events_out = [torch.cuda.Event() for _ in range(2)] - events_copy = [torch.cuda.Event() for _ in range(2)] - events_gather = [torch.cuda.Event() for _ in range(2)] - - chunks = [] - for start in range(0, num_tokens, chunk_tokens): - end = min(start + chunk_tokens, num_tokens) - chunks.append((start, end)) - num_chunks = len(chunks) - - if num_chunks > 0: - c0_start, c0_end = chunks[0] - with torch.cuda.stream(s_copy): - symm_logits[0][:c0_end - c0_start].copy_(logits_2d[c0_start:c0_end]) - events_copy[0].record(s_copy) - - for i in range(num_chunks): - b = i % 2 - nxt_b = (i + 1) % 2 - start, end = chunks[i] - current = end - start - local_tokens = current // world_size - - events_out[b].wait(s_compute) - events_copy[b].wait(s_compute) - hdl_logits.barrier(channel=b) - - seq_logits = torch.empty((local_tokens, world_size * local_vocab), dtype=dtype, device=device) - offset_elements = b * chunk_tokens * local_vocab - _get_ext().launch_gather( - logits_ptrs, offset_elements, seq_logits, - local_tokens, local_vocab, world_size, rank - ) - - events_gather[b].record(s_compute) - - if i + 1 < num_chunks: - nxt_start, nxt_end = chunks[i+1] - with torch.cuda.stream(s_copy): - events_gather[nxt_b].wait(s_copy) - symm_logits[nxt_b][:nxt_end - nxt_start].copy_(logits_2d[nxt_start:nxt_end]) - events_copy[nxt_b].record(s_copy) - - filtered, keep_mask = _apply_top_k_top_p(seq_logits, top_k=top_k, top_p=top_p) - filtered = filtered.contiguous() - if keep_mask is not None: - keep_mask = keep_mask.contiguous() - - probs = F.softmax(filtered.float(), dim=-1).contiguous() - - target_local = target_flat[start:end][rank * local_tokens : (rank + 1) * local_tokens] - grad_local = grad_flat[start:end][rank * local_tokens : (rank + 1) * local_tokens] - - _get_ext().launch_fused_scatter( - probs, target_local, grad_local, keep_mask, - grad_ptrs, offset_elements, - local_tokens, local_vocab, world_size, rank - ) - - hdl_grad.barrier(channel=b) - - event_scatter_done = torch.cuda.Event() - event_scatter_done.record(s_compute) - event_scatter_done.wait(s_out) - - with torch.cuda.stream(s_out): - grad_logits_2d[start:end].copy_(symm_grad[b][:current]) - events_out[b].record(s_out) - - s_compute.wait_stream(s_out) - return grad_logits_2d.reshape(batch, seq_len, local_vocab) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/86_distributed_sample_sort_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/86_distributed_sample_sort_cuda.py deleted file mode 100755 index e895861..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/86_distributed_sample_sort_cuda.py +++ /dev/null @@ -1,387 +0,0 @@ -from typing import List, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - - -CUDA_SRC = r''' -#include -#include -#include -#include - -__global__ void extract_samples_kernel( - const __nv_bfloat16* __restrict__ sorted_local, - int64_t local_n, - int64_t* __restrict__ my_meta, - int sort_rank, - int active_count) -{ - int i = threadIdx.x; - if (i < active_count) { - float val = __int_as_float(0x7f800000); // +inf - int s_rank = -1; - int s_pos = -1; - if (sort_rank >= 0 && local_n > 0) { - int valid_count = min(active_count, (int)local_n); - if (i < valid_count) { - int64_t pos; - if (active_count < local_n) { - pos = ((int64_t)(i + 1) * local_n) / active_count - 1; - } else { - pos = i; - } - val = __bfloat162float(sorted_local[pos]); - s_rank = sort_rank; - s_pos = pos; - } - } - my_meta[10 + i*3 + 0] = __float_as_int(val); - my_meta[10 + i*3 + 1] = s_rank; - my_meta[10 + i*3 + 2] = s_pos; - } -} - -__global__ void compute_boundaries_kernel( - const __nv_bfloat16* __restrict__ sorted_local, - int64_t local_n, - const uint64_t* __restrict__ meta_ptrs, - int world_size, - int rank, - int sort_rank, - int active_count, - const int* __restrict__ active_ranks) -{ - if (threadIdx.x != 0) return; - - float s_vals[64]; - int s_ranks[64]; - int s_poses[64]; - int s_count = 0; - - // Gather all samples from active peers - for (int r = 0; r < active_count; ++r) { - int peer_rank = active_ranks[r]; - int64_t* peer_meta = (int64_t*)meta_ptrs[peer_rank]; - for (int i = 0; i < active_count; ++i) { - int s_r = peer_meta[10 + i*3 + 1]; - if (s_r >= 0) { - s_vals[s_count] = __int_as_float(peer_meta[10 + i*3 + 0]); - s_ranks[s_count] = s_r; - s_poses[s_count] = peer_meta[10 + i*3 + 2]; - s_count++; - } - } - } - - // Sort samples - for (int i = 1; i < s_count; ++i) { - float v = s_vals[i]; - int sr = s_ranks[i]; - int sp = s_poses[i]; - int j = i - 1; - while (j >= 0) { - bool swap = false; - if (s_vals[j] > v) swap = true; - else if (s_vals[j] == v && s_ranks[j] > sr) swap = true; - else if (s_vals[j] == v && s_ranks[j] == sr && s_poses[j] > sp) swap = true; - - if (swap) { - s_vals[j+1] = s_vals[j]; - s_ranks[j+1] = s_ranks[j]; - s_poses[j+1] = s_poses[j]; - j--; - } else { - break; - } - } - s_vals[j+1] = v; - s_ranks[j+1] = sr; - s_poses[j+1] = sp; - } - - // Pick splitters - float split_vals[8]; - int split_ranks[8]; - int split_poses[8]; - for (int k = 0; k < active_count - 1; ++k) { - int index = (k + 1) * s_count / active_count - 1; - if (index < 0) index = 0; - if (index >= s_count) index = s_count - 1; - split_vals[k] = s_vals[index]; - split_ranks[k] = s_ranks[index]; - split_poses[k] = s_poses[index]; - } - - // Binary search boundaries - int64_t boundaries[9]; - boundaries[0] = 0; - boundaries[active_count] = local_n; - - for (int k = 0; k < active_count - 1; ++k) { - float val = split_vals[k]; - int s_r = split_ranks[k]; - int s_p = split_poses[k]; - - int64_t low = 0; - int64_t high = local_n; - if (sort_rank > s_r) { - while (low < high) { - int64_t mid = low + (high - low) / 2; - if (__bfloat162float(sorted_local[mid]) < val) low = mid + 1; - else high = mid; - } - } else if (sort_rank < s_r) { - while (low < high) { - int64_t mid = low + (high - low) / 2; - if (__bfloat162float(sorted_local[mid]) <= val) low = mid + 1; - else high = mid; - } - } else { - low = s_p + 1; - } - boundaries[k + 1] = low; - } - - // Monotonicity fix - for (int k = 1; k <= active_count; ++k) { - if (boundaries[k] < boundaries[k-1]) boundaries[k] = boundaries[k-1]; - if (boundaries[k] > local_n) boundaries[k] = local_n; - } - - // Save boundaries internally - int64_t* my_meta = (int64_t*)meta_ptrs[rank]; - for (int k = 0; k <= active_count; ++k) { - my_meta[200 + k] = boundaries[k]; - } - - // Push send counts - for (int k = 0; k < active_count; ++k) { - int dest_rank = active_ranks[k]; - int64_t count = boundaries[k+1] - boundaries[k]; - int64_t* dest_meta = (int64_t*)meta_ptrs[dest_rank]; - dest_meta[400 + rank] = count; - } -} - -__global__ void compute_recv_offsets_kernel( - int64_t* __restrict__ my_meta, - int world_size) -{ - if (threadIdx.x != 0) return; - int64_t offset = 0; - for (int r = 0; r < world_size; ++r) { - my_meta[500 + r] = offset; - offset += my_meta[400 + r]; - } - my_meta[1] = offset; // Store total_recv -} - -__global__ void gather_merged_sizes_kernel( - const uint64_t* __restrict__ meta_ptrs, - int64_t* __restrict__ my_meta, - int world_size) -{ - if (threadIdx.x != 0) return; - for (int r = 0; r < world_size; ++r) { - my_meta[800 + r] = ((int64_t*)meta_ptrs[r])[1]; - } -} - -__global__ void push_a2a_payload_kernel( - const __nv_bfloat16* __restrict__ sorted_local, - const uint64_t* __restrict__ a2a_ptrs, - const uint64_t* __restrict__ meta_ptrs, - int rank, - int active_count, - const int* __restrict__ active_ranks) -{ - int bucket = blockIdx.y; - int dest_rank = active_ranks[bucket]; - - int64_t* my_meta = (int64_t*)meta_ptrs[rank]; - int64_t bucket_start = my_meta[200 + bucket]; - int64_t bucket_end = my_meta[200 + bucket + 1]; - int64_t count = bucket_end - bucket_start; - - int64_t dest_offset = ((int64_t*)meta_ptrs[dest_rank])[500 + rank]; - __nv_bfloat16* dest_buf = (__nv_bfloat16*)a2a_ptrs[dest_rank]; - - for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += gridDim.x * blockDim.x) { - dest_buf[dest_offset + i] = sorted_local[bucket_start + i]; - } -} - -__global__ void push_final_kernel( - const __nv_bfloat16* __restrict__ merged, - const uint64_t* __restrict__ final_ptrs, - int64_t bucket_start, - int64_t bucket_end, - int64_t target_start, - int64_t target_end, - int dest_rank) -{ - int64_t start = max(bucket_start, target_start); - int64_t end = min(bucket_end, target_end); - if (start >= end) return; - - int64_t count = end - start; - int64_t src_offset = start - bucket_start; - int64_t dst_offset = start - target_start; - - __nv_bfloat16* dest_buf = (__nv_bfloat16*)final_ptrs[dest_rank]; - - for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += gridDim.x * blockDim.x) { - dest_buf[dst_offset + i] = merged[src_offset + i]; - } -} - -void extract_samples(torch::Tensor sorted_local, int64_t local_n, torch::Tensor my_meta, int sort_rank, int active_count) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - extract_samples_kernel<<<1, 32, 0, stream>>>( - (__nv_bfloat16*)sorted_local.data_ptr(), local_n, my_meta.data_ptr(), sort_rank, active_count); -} - -void compute_boundaries(torch::Tensor sorted_local, int64_t local_n, torch::Tensor meta_ptrs, int world_size, int rank, int sort_rank, int active_count, torch::Tensor active_ranks) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - compute_boundaries_kernel<<<1, 32, 0, stream>>>( - (__nv_bfloat16*)sorted_local.data_ptr(), local_n, (const uint64_t*)meta_ptrs.data_ptr(), - world_size, rank, sort_rank, active_count, active_ranks.data_ptr()); -} - -void compute_recv_offsets(torch::Tensor my_meta, int world_size) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - compute_recv_offsets_kernel<<<1, 32, 0, stream>>>(my_meta.data_ptr(), world_size); -} - -void gather_merged_sizes(torch::Tensor meta_ptrs, torch::Tensor my_meta, int world_size) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - gather_merged_sizes_kernel<<<1, 32, 0, stream>>>((const uint64_t*)meta_ptrs.data_ptr(), my_meta.data_ptr(), world_size); -} - -void push_a2a_payload(torch::Tensor sorted_local, torch::Tensor a2a_ptrs, torch::Tensor meta_ptrs, int rank, int active_count, torch::Tensor active_ranks) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - dim3 grid(128, active_count); - dim3 block(256); - push_a2a_payload_kernel<<>>( - (__nv_bfloat16*)sorted_local.data_ptr(), (const uint64_t*)a2a_ptrs.data_ptr(), - (const uint64_t*)meta_ptrs.data_ptr(), rank, active_count, active_ranks.data_ptr()); -} - -void push_final(torch::Tensor merged, torch::Tensor final_ptrs, int64_t bucket_start, int64_t bucket_end, int64_t target_start, int64_t target_end, int dest_rank) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - push_final_kernel<<<128, 256, 0, stream>>>( - (__nv_bfloat16*)merged.data_ptr(), (const uint64_t*)final_ptrs.data_ptr(), - bucket_start, bucket_end, target_start, target_end, dest_rank); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("extract_samples", &extract_samples); - m.def("compute_boundaries", &compute_boundaries); - m.def("compute_recv_offsets", &compute_recv_offsets); - m.def("gather_merged_sizes", &gather_merged_sizes); - m.def("push_a2a_payload", &push_a2a_payload); - m.def("push_final", &push_final); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("symm_sample_sort_ext", CUDA_SRC) - return _ext - -_symm_cache = {} -def get_symm_buffer(name: str, min_size: int, dtype: torch.dtype, device: torch.device): - if name in _symm_cache: - buf, hdl = _symm_cache[name] - if buf.numel() >= min_size and buf.dtype == dtype: - return buf, hdl - alloc_size = max(min_size, 1024 * 1024) if name != "meta" else min_size - buf = symm_mem.empty(alloc_size, dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - _symm_cache[name] = (buf, hdl) - return buf, hdl - - -@torch.no_grad() -def solution(local_shard: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - device = local_shard.device - local_n = local_shard.numel() - - # Initial sizes exchange via tight gather buffer - size_t = torch.tensor([local_n], dtype=torch.long, device=device) - gathered_sizes_t = torch.empty((world_size,), dtype=torch.long, device=device) - dist.all_gather_into_tensor(gathered_sizes_t, size_t, group=group) - sizes = gathered_sizes_t.cpu().tolist() - - total_N = sum(sizes) - if total_N == 0: - return local_shard.new_empty(0) - - active_ranks = [i for i, s in enumerate(sizes) if s > 0] - sort_rank = active_ranks.index(rank) if rank in active_ranks else -1 - active_count = len(active_ranks) - - # Pre-allocate fully overlapping max-span device-side buffers - meta_buf, meta_hdl = get_symm_buffer("meta", 1024, torch.int64, device) - a2a_buf, a2a_hdl = get_symm_buffer("a2a", total_N, torch.bfloat16, device) - - base = total_N // world_size - extra = total_N % world_size - my_target_size = base + (1 if rank < extra else 0) - final_buf, final_hdl = get_symm_buffer("final", base + 1, torch.bfloat16, device) - - ext = _get_ext() - sorted_local = local_shard.sort().values - active_ranks_t = torch.tensor(active_ranks, dtype=torch.int32, device=device) - meta_ptrs = torch.tensor(meta_hdl.buffer_ptrs, dtype=torch.int64, device=device) - a2a_ptrs = torch.tensor(a2a_hdl.buffer_ptrs, dtype=torch.int64, device=device) - final_ptrs = torch.tensor(final_hdl.buffer_ptrs, dtype=torch.int64, device=device) - - # Initialize / synchronize state buffers for deterministic offsets and exchanges - meta_buf.zero_() - dist.barrier(group) - - ext.extract_samples(sorted_local, local_n, meta_buf, sort_rank, active_count) - dist.barrier(group) - - ext.compute_boundaries(sorted_local, local_n, meta_ptrs, world_size, rank, sort_rank, active_count, active_ranks_t) - dist.barrier(group) - - ext.compute_recv_offsets(meta_buf, world_size) - dist.barrier(group) - - # Overlap exact destination footprint sizes retrieval with asynchronous payload exchange - ext.gather_merged_sizes(meta_ptrs, meta_buf, world_size) - if active_count > 0 and local_n > 0: - ext.push_a2a_payload(sorted_local, a2a_ptrs, meta_ptrs, rank, active_count, active_ranks_t) - - dist.barrier(group) - meta_cpu = meta_buf.cpu() - total_recv = int(meta_cpu[1].item()) - merged_sizes = meta_cpu[800 : 800 + world_size].tolist() - - # Intermediate variable sizes merged sort - if total_recv > 0: - merged = a2a_buf[:total_recv].sort().values - else: - merged = torch.empty(0, dtype=torch.bfloat16, device=device) - - bucket_start = sum(merged_sizes[:rank]) - bucket_end = bucket_start + merged.numel() - - # Exact redistribution kernel sweeps exact offset bounds concurrently directly via symmetric P2P writes - for dest in range(world_size): - target_start = dest * base + min(dest, extra) - target_end = target_start + base + (1 if dest < extra else 0) - ext.push_final(merged, final_ptrs, bucket_start, bucket_end, target_start, target_end, dest) - - dist.barrier(group) - return final_buf[:my_target_size].clone() \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/87_tp_muon_orthogonalization_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/87_tp_muon_orthogonalization_cuda.py deleted file mode 100755 index 597cade..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/87_tp_muon_orthogonalization_cuda.py +++ /dev/null @@ -1,396 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from typing import Optional, Sequence -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// --------------------------------------------------------------------------- -// Blockwise Barriers (Device-Side Synchronization) -// --------------------------------------------------------------------------- -__device__ __forceinline__ void send_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile("atom.global.relaxed.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_relaxed(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile("atom.global.sys.relaxed.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 1u); -} - -__device__ __forceinline__ void send_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile("atom.global.release.sys.cas.b32 %0, [%1], 0, 1;" : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 0u); -} - -__device__ __forceinline__ void wait_signal_acq_rel(uint32_t* addr) { - uint32_t tmp; - do { - asm volatile("atom.global.sys.acquire.cas.b32 %0, [%1], 1, 0;" : "=r"(tmp) : "l"(addr) : "memory"); - } while (tmp != 1u); -} - -__device__ void blockwise_barrier_relaxed(const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast(remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast(local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_relaxed(send_addr); - wait_signal_relaxed(wait_addr); -} - -__device__ void blockwise_barrier_acq_rel(const uint64_t* __restrict__ signal_pad_ptrs, uint64_t block_id, int rank, int world_size) { - unsigned int flat_tid = threadIdx.x; - if (flat_tid >= (unsigned int)world_size) return; - uint64_t local_base = signal_pad_ptrs[rank]; - uint64_t remote_base = signal_pad_ptrs[flat_tid]; - uint32_t* send_addr = reinterpret_cast(remote_base + block_id * (uint64_t)world_size + (uint64_t)rank); - uint32_t* wait_addr = reinterpret_cast(local_base + block_id * (uint64_t)world_size + (uint64_t)flat_tid); - send_signal_acq_rel(send_addr); - wait_signal_acq_rel(wait_addr); -} - -// --------------------------------------------------------------------------- -// Multimem operations & Kernels -// --------------------------------------------------------------------------- -__device__ __forceinline__ void multimem_ld_reduce_bf16x4(const uint64_t* addr, uint32_t& r0, uint32_t& r1, uint32_t& r2, uint32_t& r3) { - asm volatile("multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0, %1, %2, %3}, [%4];" : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "l"(addr) : "memory"); -} - -__device__ __forceinline__ void multimem_st_bf16x4(const uint64_t* addr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1, %2, %3, %4};" : : "l"(addr), "r"(x), "r"(y), "r"(z), "r"(w) : "memory"); -} - -__global__ void multimem_allreduce_bf16_kernel( - uint64_t multicast_base, - const uint64_t* __restrict__ signal_pad_ptrs, - int64_t numel_128, - int world_size, - int rank, - int block_stride -) { - const uint64_t block_id = static_cast(blockIdx.x); - blockwise_barrier_relaxed(signal_pad_ptrs, block_id, rank, world_size); - __syncthreads(); - - const int64_t numel_per_rank = (numel_128 + (int64_t)world_size - 1) / (int64_t)world_size; - const int num_programs = gridDim.x; - const int tid = threadIdx.x; - - for (int64_t block_start = (int64_t)block_id * (int64_t)block_stride; block_start < numel_per_rank; block_start += (int64_t)num_programs * (int64_t)block_stride) { - const int64_t offsets = block_start + (int64_t)tid; - if (offsets >= numel_per_rank) continue; - const int64_t idx = (int64_t)rank * numel_per_rank + offsets; - uint64_t* ptrs = reinterpret_cast(multicast_base) + idx * 2; - uint32_t x, y, z, w; - multimem_ld_reduce_bf16x4(ptrs, x, y, z, w); - multimem_st_bf16x4(ptrs, x, y, z, w); - } - __syncthreads(); - blockwise_barrier_acq_rel(signal_pad_ptrs, block_id, rank, world_size); -} - -void launch_multimem_allreduce_bf16( - uint64_t multicast_ptr, - torch::Tensor signal_pad_ptrs_tensor, - int64_t numel, - int world_size, - int rank, - int num_blocks, - int block_size, - int block_stride -) { - const uint64_t* d_signal = reinterpret_cast(signal_pad_ptrs_tensor.data_ptr()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - multimem_allreduce_bf16_kernel<<>>(multicast_ptr, d_signal, numel, world_size, rank, block_stride); -} - -// --------------------------------------------------------------------------- -// Fallback & Standard All-Reduce Kernels -// --------------------------------------------------------------------------- -__global__ void allreduce_bf16_kernel(const long long* __restrict__ ptrs, __nv_bfloat16* __restrict__ out, int world_size, int64_t n) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - sum += __bfloat162float(((const __nv_bfloat16*)ptrs[r])[idx]); - } - out[idx] = __float2bfloat16(sum); - } -} - -__global__ void allreduce_f32_kernel(const long long* __restrict__ ptrs, float* __restrict__ out, int world_size, int64_t n) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - float sum = 0.0f; - #pragma unroll - for (int r = 0; r < world_size; ++r) { - sum += ((const float*)ptrs[r])[idx]; - } - out[idx] = sum; - } -} - -void launch_allreduce(torch::Tensor ptrs_tensor, torch::Tensor out, int64_t n, int dtype_enum) { - int world_size = ptrs_tensor.size(0); - const long long* d_ptrs = (const long long*)ptrs_tensor.data_ptr(); - int threads = 512; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (dtype_enum == 0) { - allreduce_bf16_kernel<<>>(d_ptrs, (__nv_bfloat16*)out.data_ptr(), world_size, n); - } else { - allreduce_f32_kernel<<>>(d_ptrs, out.data_ptr(), world_size, n); - } -} - -// --------------------------------------------------------------------------- -// Fused Normalization & Output Cast Kernels -// --------------------------------------------------------------------------- -__global__ void scale_and_cast_bf16_kernel(const float* __restrict__ x, const float* __restrict__ norm_sq, __nv_bfloat16* __restrict__ out, float eps, int64_t n) { - __shared__ float s_scale; - if (threadIdx.x == 0) { - float n_sq = *norm_sq; - s_scale = rsqrtf(n_sq < eps ? eps : n_sq); - } - __syncthreads(); - float scale = s_scale; - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - out[idx] = __float2bfloat16(x[idx] * scale); - } -} - -void launch_scale_and_cast_bf16(torch::Tensor x, torch::Tensor norm_sq, torch::Tensor out, float eps) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - scale_and_cast_bf16_kernel<<>>(x.data_ptr(), norm_sq.data_ptr(), reinterpret_cast<__nv_bfloat16*>(out.data_ptr()), eps, n); -} - -__global__ void cast_to_f32_kernel(const __nv_bfloat16* __restrict__ x, float* __restrict__ out, int64_t n) { - int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - for (; idx < n; idx += (int64_t)gridDim.x * blockDim.x) { - out[idx] = __bfloat162float(x[idx]); - } -} - -__global__ void cast_to_f32_and_transpose_kernel(const __nv_bfloat16* __restrict__ x, float* __restrict__ out, int64_t rows, int64_t cols) { - __shared__ float tile[32][33]; - int x_idx = blockIdx.x * 32 + threadIdx.x; - int y_idx = blockIdx.y * 32 + threadIdx.y; - if (x_idx < cols && y_idx < rows) { - tile[threadIdx.y][threadIdx.x] = __bfloat162float(x[y_idx * cols + x_idx]); - } - __syncthreads(); - x_idx = blockIdx.y * 32 + threadIdx.x; - y_idx = blockIdx.x * 32 + threadIdx.y; - if (x_idx < rows && y_idx < cols) { - out[y_idx * rows + x_idx] = tile[threadIdx.x][threadIdx.y]; - } -} - -void launch_cast_to_f32(torch::Tensor x, torch::Tensor out, bool transpose) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (!transpose) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - if (blocks > 65535) blocks = 65535; - cast_to_f32_kernel<<>>(reinterpret_cast(x.data_ptr()), out.data_ptr(), n); - } else { - int64_t rows = x.size(0); - int64_t cols = x.size(1); - dim3 threads(32, 32); - dim3 blocks((cols + 31) / 32, (rows + 31) / 32); - cast_to_f32_and_transpose_kernel<<>>(reinterpret_cast(x.data_ptr()), out.data_ptr(), rows, cols); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_multimem_allreduce_bf16", &launch_multimem_allreduce_bf16); - m.def("launch_allreduce", &launch_allreduce); - m.def("launch_scale_and_cast_bf16", &launch_scale_and_cast_bf16); - m.def("launch_cast_to_f32", &launch_cast_to_f32); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("muon_tp_orthogonalization_ext", CUDA_SRC) - return _ext - -_resource_cache = {} -def _get_resources(shape, dtype, device, group): - key = (shape, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, group) - out = torch.empty(shape, device=device, dtype=dtype) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - res = (buf, hdl, out, ptrs_tensor) - _resource_cache[key] = res - return res - -WARP_SIZE = 32 -MAX_NUM_BLOCKS = 4 -MAX_BLOCK_SIZE = 1024 -BYTES_PER_THREAD = 16 - -def _multimem_launch_config(numel: int, world_size: int) -> tuple[int, int, int]: - numel_per_thread = BYTES_PER_THREAD // 2 - num_threads = (numel // numel_per_thread + world_size - 1) // world_size - if num_threads < MAX_BLOCK_SIZE: - block_size = 1 - while block_size < num_threads: block_size *= 2 - num_blocks = 1 - else: - block_size = MAX_BLOCK_SIZE - num_blocks = min((num_threads + MAX_BLOCK_SIZE - 1) // MAX_BLOCK_SIZE, MAX_NUM_BLOCKS) - return num_blocks, block_size, block_size - -_COEFFICIENTS: dict[str, Sequence[tuple[float, float, float]]] = { - "simple": ((3.4445, -4.7750, 2.0315),), - "quintic": ( - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ), - "polar_express": ( - (8.2051, -22.9019, 16.4607), - (4.0664, -2.8612, 0.5184), - (3.9096, -2.8234, 0.5250), - (3.2856, -2.4647, 0.5074), - (2.2779, -1.6447, 0.4162), - (1.8726, -1.2307, 0.3585), - (1.8564, -1.2132, 0.3568), - (1.8750, -1.2500, 0.3750), - ), - "aol": ( - (4.0098, -7.0585, 2.4635), - (3.4585, -5.5479, 2.5959), - (2.7573, -3.2939, 1.4254), - (2.7215, -3.0494, 1.3169), - ), -} - -def _coefficient_at(coefficients: Sequence[tuple[float, float, float]], step: int) -> tuple[float, float, float]: - return coefficients[step % len(coefficients)] - -@torch.no_grad() -def solution( - x: torch.Tensor, - steps: int = 5, - coefficient_type: str = "quintic", - partition_dim: int = 1, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - assert x.ndim == 2 and x.dtype == torch.float32 - coefficients = _COEFFICIENTS[coefficient_type] - - if partition_dim == 0: - x_work = x.mT.contiguous() - elif partition_dim == 1: - x_work = x.contiguous() - - N, M = x_work.shape - _get_ext() # JIT compile - - buf_norm, hdl_norm, out_norm, ptrs_norm = _get_resources((1,), torch.float32, x.device, group) - buf_gram, hdl_gram, out_gram, ptrs_gram = _get_resources((N, N), torch.bfloat16, x.device, group) - - compute_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream() - - x_flat = x_work.flatten() - buf_norm.copy_(torch.dot(x_flat, x_flat)) - compute_stream.synchronize() - hdl_norm.barrier(channel=0) - - _get_ext().launch_allreduce(ptrs_norm, out_norm, 1, 1) - - x_work_bf16 = torch.empty((N, M), dtype=torch.bfloat16, device=x.device) - _get_ext().launch_scale_and_cast_bf16(x_work, out_norm, x_work_bf16, 1e-7) - - update = torch.empty((N, N), dtype=torch.bfloat16, device=x.device) - x_work_next = torch.empty_like(x_work_bf16) - - NUM_CHUNKS = 4 - chunk_size = N // NUM_CHUNKS - use_multimem = True - - if (N * N) % 8 != 0 or not hasattr(hdl_gram, 'multicast_ptr') or hdl_gram.multicast_ptr is None or N % NUM_CHUNKS != 0 or (chunk_size * N) % 8 != 0: - use_multimem = False - NUM_CHUNKS = 1 - chunk_size = N - - events_comp = [torch.cuda.Event() for _ in range(NUM_CHUNKS)] - events_comm = [torch.cuda.Event() for _ in range(NUM_CHUNKS)] - - for step in range(steps): - a, b, c = _coefficient_at(coefficients, step) - - if use_multimem: - for i in range(NUM_CHUNKS): - start = i * chunk_size - end = start + chunk_size - torch.matmul(x_work_bf16[start:end], x_work_bf16.mT, out=buf_gram[start:end]) - events_comp[i].record(compute_stream) - - comm_stream.wait_event(events_comp[i]) - with torch.cuda.stream(comm_stream): - numel_chunk = chunk_size * N - numel_128 = numel_chunk // 8 - num_blocks, block_size, block_stride = _multimem_launch_config(numel_chunk, hdl_gram.world_size) - chunk_multicast_ptr = int(hdl_gram.multicast_ptr) + (start * N * 2) - - _get_ext().launch_multimem_allreduce_bf16( - chunk_multicast_ptr, hdl_gram.signal_pad_ptrs_dev, numel_128, - hdl_gram.world_size, hdl_gram.rank, num_blocks, block_size, block_stride - ) - events_comm[i].record(comm_stream) - - for i in range(NUM_CHUNKS): - compute_stream.wait_event(events_comm[i]) - result_gram = buf_gram - else: - torch.matmul(x_work_bf16, x_work_bf16.mT, out=buf_gram) - compute_stream.synchronize() - hdl_gram.barrier(channel=0) - _get_ext().launch_allreduce(ptrs_gram, out_gram, N * N, 0) - compute_stream.synchronize() - hdl_gram.barrier(channel=0) - result_gram = out_gram - - torch.addmm(result_gram, result_gram, result_gram, beta=b, alpha=c, out=update) - torch.addmm(x_work_bf16, update, x_work_bf16, beta=a, alpha=1.0, out=x_work_next) - - x_work_bf16, x_work_next = x_work_next, x_work_bf16 - - out_f32 = torch.empty_like(x) - _get_ext().launch_cast_to_f32(x_work_bf16, out_f32, partition_dim == 0) - return out_f32 \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/88_conv2d_boundary_exchange_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/88_conv2d_boundary_exchange_cuda.py deleted file mode 100755 index ca817bf..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/88_conv2d_boundary_exchange_cuda.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -Strategy: -- **Device-Side Communication**: We replace `dist.all_gather` with a custom P2P boundary exchange using `torch.distributed._symmetric_memory`. Each rank exposes only its top and bottom boundary rows (size `[2, B, C, padding, W]`) in a symmetric memory buffer. -- **Zero-Copy & UVA**: We use UVA device pointers to load boundary data directly from adjacent ranks into our local `padded_x` buffer, avoiding any host-side collectives or intermediate buffers. -- **Compute-Communication Overlap**: The local extraction of boundaries and the copy of the core local tensor into `padded_x` are fused into a single asynchronous kernel. This kernel executes immediately, overlapping the data preparation with the barrier wait (`hdl.barrier()`), effectively hiding the local setup latency. Finally, a single optimized `F.conv2d` call runs on the continuous patched buffer. -""" - -from typing import Optional - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.distributed._symmetric_memory as symm_mem - -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include -#include - -template -__device__ __forceinline__ T get_zero(); - -template<> -__device__ __forceinline__ float get_zero() { return 0.0f; } - -template<> -__device__ __forceinline__ __half get_zero<__half>() { - return __float2half(0.0f); -} - -template<> -__device__ __forceinline__ __nv_bfloat16 get_zero<__nv_bfloat16>() { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __float2bfloat16(0.0f); -#else - unsigned short val = 0; - return *reinterpret_cast<__nv_bfloat16*>(&val); -#endif -} - -template -__global__ void pack_and_pad_kernel( - const T* __restrict__ x, - T* __restrict__ symm_buf, - T* __restrict__ padded_x, - int B, int C, int H, int W, int boundary, - int64_t numel_x -) { - int64_t H_padded = H + 2 * boundary; - for (int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - idx < numel_x; - idx += (int64_t)gridDim.x * blockDim.x) { - - int w = idx % W; - int h = (idx / W) % H; - int c = (idx / (W * H)) % C; - int b = idx / (W * H * C); - - int64_t padded_h = h + boundary; - int64_t padded_idx = ((int64_t)b * C * H_padded * W) + - ((int64_t)c * H_padded * W) + - (padded_h * W) + w; - - T val = x[idx]; - padded_x[padded_idx] = val; - - if (h < boundary) { - int64_t symm_idx = ((int64_t)b * C * boundary * W) + - ((int64_t)c * boundary * W) + - (h * W) + w; - symm_buf[symm_idx] = val; - } else if (h >= H - boundary) { - int h_b = h - (H - boundary); - int64_t symm_idx = ((int64_t)1 * B * C * boundary * W) + - ((int64_t)b * C * boundary * W) + - ((int64_t)c * boundary * W) + - (h_b * W) + w; - symm_buf[symm_idx] = val; - } - } -} - -template -__global__ void unpack_peers_kernel( - T* __restrict__ padded_x, - const T* __restrict__ peer_top_buf, - const T* __restrict__ peer_bottom_buf, - int B, int C, int H_padded, int W, int boundary, - int rank, int world_size, - int64_t total_boundary_numel -) { - int64_t numel_boundary = total_boundary_numel / 2; - for (int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - idx < total_boundary_numel; - idx += (int64_t)gridDim.x * blockDim.x) { - - bool is_top = idx < numel_boundary; - int64_t bnd_idx = is_top ? idx : (idx - numel_boundary); - - int w = bnd_idx % W; - int h = (bnd_idx / W) % boundary; - int c = (bnd_idx / (W * boundary)) % C; - int b = bnd_idx / (W * boundary * C); - - if (is_top) { - int64_t out_idx = ((int64_t)b * C * H_padded * W) + - ((int64_t)c * H_padded * W) + - (h * W) + w; - if (rank > 0 && peer_top_buf != nullptr) { - padded_x[out_idx] = peer_top_buf[bnd_idx]; - } else { - padded_x[out_idx] = get_zero(); - } - } else { - int h_out = H_padded - boundary + h; - int64_t out_idx = ((int64_t)b * C * H_padded * W) + - ((int64_t)c * H_padded * W) + - (h_out * W) + w; - if (rank < world_size - 1 && peer_bottom_buf != nullptr) { - padded_x[out_idx] = peer_bottom_buf[bnd_idx]; - } else { - padded_x[out_idx] = get_zero(); - } - } - } -} - -void launch_pack_and_pad( - torch::Tensor x, - torch::Tensor symm_buf, - torch::Tensor padded_x, - int boundary -) { - int B = x.size(0); - int C = x.size(1); - int H = x.size(2); - int W = x.size(3); - - int64_t numel_x = x.numel(); - int threads = 256; - int blocks = std::min((numel_x + threads - 1) / threads, 65535); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.dtype() == torch::kBFloat16) { - pack_and_pad_kernel<__nv_bfloat16><<>>( - reinterpret_cast(x.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(symm_buf.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(padded_x.data_ptr()), - B, C, H, W, boundary, numel_x - ); - } else if (x.dtype() == torch::kFloat32) { - pack_and_pad_kernel<<>>( - x.data_ptr(), - symm_buf.data_ptr(), - padded_x.data_ptr(), - B, C, H, W, boundary, numel_x - ); - } else if (x.dtype() == torch::kFloat16) { - pack_and_pad_kernel<__half><<>>( - reinterpret_cast(x.data_ptr()), - reinterpret_cast<__half*>(symm_buf.data_ptr()), - reinterpret_cast<__half*>(padded_x.data_ptr()), - B, C, H, W, boundary, numel_x - ); - } -} - -void launch_unpack_peers( - torch::Tensor padded_x, - int64_t peer_top_ptr, - int64_t peer_bottom_ptr, - int boundary, - int rank, - int world_size -) { - int B = padded_x.size(0); - int C = padded_x.size(1); - int H_padded = padded_x.size(2); - int W = padded_x.size(3); - - int64_t numel_boundary = (int64_t)B * C * boundary * W; - int64_t total_boundary_numel = numel_boundary * 2; - int threads = 256; - int blocks = std::min((total_boundary_numel + threads - 1) / threads, 65535); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (padded_x.dtype() == torch::kBFloat16) { - unpack_peers_kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(padded_x.data_ptr()), - reinterpret_cast(peer_top_ptr), - reinterpret_cast(peer_bottom_ptr), - B, C, H_padded, W, boundary, rank, world_size, total_boundary_numel - ); - } else if (padded_x.dtype() == torch::kFloat32) { - unpack_peers_kernel<<>>( - padded_x.data_ptr(), - reinterpret_cast(peer_top_ptr), - reinterpret_cast(peer_bottom_ptr), - B, C, H_padded, W, boundary, rank, world_size, total_boundary_numel - ); - } else if (padded_x.dtype() == torch::kFloat16) { - unpack_peers_kernel<__half><<>>( - reinterpret_cast<__half*>(padded_x.data_ptr()), - reinterpret_cast(peer_top_ptr), - reinterpret_cast(peer_bottom_ptr), - B, C, H_padded, W, boundary, rank, world_size, total_boundary_numel - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_pack_and_pad", &launch_pack_and_pad, "Pack boundary and init local tensor"); - m.def("launch_unpack_peers", &launch_unpack_peers, "Unpack UVA boundary rows"); -} -''' - -_ext = None - - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("conv2d_boundary_cuda", CUDA_SRC) - return _ext - - -_symm_cache = {} - - -def _get_symm_state(B, C, boundary, W, dtype, device, group): - key = (B, C, boundary, W, dtype, device, id(group)) - if key in _symm_cache: - return _symm_cache[key] - - # 2 buffers: index 0 for top boundary to send, index 1 for bottom boundary to send - buf = symm_mem.empty((2, B, C, boundary, W), dtype=dtype, device=device) - hdl = symm_mem.rendezvous(buf, group) - _symm_cache[key] = (buf, hdl) - return buf, hdl - - -@torch.no_grad() -def solution( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - stride: int = 1, - padding: int = 1, - group: Optional[dist.ProcessGroup] = None, -) -> torch.Tensor: - group = group or dist.group.WORLD - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - boundary = int(padding) - - if boundary == 0 or world_size == 1: - return F.conv2d(x, weight, bias, stride=stride, padding=padding) - - x = x.contiguous() - - if rank == 0: - _get_ext() - dist.barrier(group=group) - - ext = _get_ext() - - B, C, H, W = x.shape - H_padded = H + 2 * boundary - - padded_x = torch.empty((B, C, H_padded, W), dtype=x.dtype, device=x.device) - symm_buf, hdl = _get_symm_state(B, C, boundary, W, x.dtype, x.device, group) - - # Fused operation: asynchronously slice/copy the local tensor and prepare peer payloads - ext.launch_pack_and_pad(x, symm_buf, padded_x, boundary) - - # Ensure memory writes to the symmetric buffers are globally visible - hdl.barrier(channel=0) - - element_size = x.element_size() - # Offset pointer by the full size of one boundary to access the bottom boundary - offset = B * C * boundary * W * element_size - - peer_top_ptr = 0 - peer_bottom_ptr = 0 - - if rank > 0: - # rank - 1's bottom boundary (index 1) - peer_top_ptr = int(hdl.buffer_ptrs[rank - 1]) + offset - if rank < world_size - 1: - # rank + 1's top boundary (index 0) - peer_bottom_ptr = int(hdl.buffer_ptrs[rank + 1]) - - ext.launch_unpack_peers( - padded_x, - peer_top_ptr, - peer_bottom_ptr, - boundary, - rank, - world_size - ) - - return F.conv2d(padded_x, weight, bias, stride=stride, padding=(0, padding)) \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/8_alltoall_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/8_alltoall_cuda.py deleted file mode 100755 index b8aadb6..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/8_alltoall_cuda.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -Strategy: -- **Device-Side Communication:** Instead of NCCL, we use `torch.distributed._symmetric_memory` to allocate UVA-accessible device memory. A custom CUDA kernel performs parallel "PULL" operations directly over NVLink, allowing each rank to independently fetch its designated data chunk from peers' symmetric buffers without host synchronization. -- **Maximized Memory Bandwidth:** The kernel dynamically verifies memory alignment and degrades gracefully, using 128-bit (`uint4`) vectorized loads and stores whenever chunks are 16-byte aligned. This maximizes NVLink and global memory bus utilization. -- **Compute-Communication Overlap & Stream Semantics:** By utilizing `symm_mem.rendezvous.barrier()`, the synchronization is fully stream-ordered. The entire operation—local copy, peer synchronization, and PULL data movement—is enqueued asynchronously, allowing the host to immediately return and the CUDA scheduler to execute the collective efficiently while the CPU proceeds. -""" - -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -__global__ void pull_all_to_all_kernel( - const uint64_t* __restrict__ peer_ptrs, - uint8_t* __restrict__ out, - int64_t chunk_bytes, - int world_size, - int rank -) { - // Each block in the Y dimension handles reading from a specific peer - int peer = blockIdx.y; - int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; - - const uint8_t* peer_buf = reinterpret_cast(peer_ptrs[peer]); - int64_t peer_start = (int64_t)rank * chunk_bytes; - int64_t out_start = (int64_t)peer * chunk_bytes; - - // Attempt the widest possible vectorized load/store based on alignment - if ((reinterpret_cast(peer_buf) % 16 == 0) && - (reinterpret_cast(out) % 16 == 0) && - (peer_start % 16 == 0) && - (out_start % 16 == 0) && - (chunk_bytes % 16 == 0)) - { - int64_t chunk_128b = chunk_bytes / 16; - const uint4* peer_buf_128 = reinterpret_cast(peer_buf + peer_start); - uint4* out_buf_128 = reinterpret_cast(out + out_start); - - for (int64_t offset = tid; offset < chunk_128b; offset += (int64_t)gridDim.x * blockDim.x) { - out_buf_128[offset] = peer_buf_128[offset]; - } - } - else if ((reinterpret_cast(peer_buf) % 8 == 0) && - (reinterpret_cast(out) % 8 == 0) && - (peer_start % 8 == 0) && - (out_start % 8 == 0) && - (chunk_bytes % 8 == 0)) - { - int64_t chunk_64b = chunk_bytes / 8; - const uint64_t* peer_buf_64 = reinterpret_cast(peer_buf + peer_start); - uint64_t* out_buf_64 = reinterpret_cast(out + out_start); - - for (int64_t offset = tid; offset < chunk_64b; offset += (int64_t)gridDim.x * blockDim.x) { - out_buf_64[offset] = peer_buf_64[offset]; - } - } - else if ((reinterpret_cast(peer_buf) % 4 == 0) && - (reinterpret_cast(out) % 4 == 0) && - (peer_start % 4 == 0) && - (out_start % 4 == 0) && - (chunk_bytes % 4 == 0)) - { - int64_t chunk_32b = chunk_bytes / 4; - const uint32_t* peer_buf_32 = reinterpret_cast(peer_buf + peer_start); - uint32_t* out_buf_32 = reinterpret_cast(out + out_start); - - for (int64_t offset = tid; offset < chunk_32b; offset += (int64_t)gridDim.x * blockDim.x) { - out_buf_32[offset] = peer_buf_32[offset]; - } - } - else if ((reinterpret_cast(peer_buf) % 2 == 0) && - (reinterpret_cast(out) % 2 == 0) && - (peer_start % 2 == 0) && - (out_start % 2 == 0) && - (chunk_bytes % 2 == 0)) - { - int64_t chunk_16b = chunk_bytes / 2; - const uint16_t* peer_buf_16 = reinterpret_cast(peer_buf + peer_start); - uint16_t* out_buf_16 = reinterpret_cast(out + out_start); - - for (int64_t offset = tid; offset < chunk_16b; offset += (int64_t)gridDim.x * blockDim.x) { - out_buf_16[offset] = peer_buf_16[offset]; - } - } - else { - for (int64_t offset = tid; offset < chunk_bytes; offset += (int64_t)gridDim.x * blockDim.x) { - out[out_start + offset] = peer_buf[peer_start + offset]; - } - } -} - -void launch_pull_all_to_all( - torch::Tensor peer_ptrs_tensor, - torch::Tensor out, - int64_t chunk_bytes, - int world_size, - int rank -) { - const uint64_t* d_ptrs = reinterpret_cast(peer_ptrs_tensor.data_ptr()); - uint8_t* d_out = reinterpret_cast(out.data_ptr()); - - int threads = 256; - int max_blocks_x = 1024; - - int64_t max_elements = chunk_bytes; - if (chunk_bytes % 16 == 0) max_elements = chunk_bytes / 16; - else if (chunk_bytes % 8 == 0) max_elements = chunk_bytes / 8; - else if (chunk_bytes % 4 == 0) max_elements = chunk_bytes / 4; - else if (chunk_bytes % 2 == 0) max_elements = chunk_bytes / 2; - - int blocks_x = std::min((int)((max_elements + threads - 1) / threads), max_blocks_x); - if (blocks_x < 1) blocks_x = 1; - - dim3 blocks(blocks_x, world_size, 1); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - pull_all_to_all_kernel<<>>( - d_ptrs, d_out, chunk_bytes, world_size, rank - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_pull_all_to_all", &launch_pull_all_to_all, "Pull-based All-to-All P2P kernel"); -} -''' - -_ext = None - -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("all_to_all_pull_ext", CUDA_SRC) - return _ext - -_resource_cache = {} - -def _get_resources(shape, dtype, device): - """Caches symmetric memory buffers and rendezvous handles to prevent reallocation overhead.""" - key = (shape, dtype, device) - if key in _resource_cache: - return _resource_cache[key] - - buf = symm_mem.empty(shape, device=device, dtype=dtype) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - res = (buf, hdl, ptrs_tensor) - _resource_cache[key] = res - return res - -@torch.no_grad() -def solution( - tensor: torch.Tensor, -) -> torch.Tensor: - assert dist.is_initialized(), "torch.distributed must be initialized" - assert tensor.is_contiguous(), "Input tensor must be contiguous" - - world_size = dist.get_world_size() - rank = dist.get_rank() - - assert tensor.shape[0] == world_size, \ - f"First dimension ({tensor.shape[0]}) must equal world_size ({world_size})" - - buf, hdl, ptrs_tensor = _get_resources(tensor.shape, tensor.dtype, tensor.device) - - # Copy local data into symmetric memory so it's accessible to peers - buf.copy_(tensor) - - # Enqueue stream-ordered barrier: wait for all peers to finish writing their buffers - hdl.barrier(channel=0) - - # Calculate bytes per rank chunk - chunk_bytes = (tensor.numel() // world_size) * tensor.element_size() - out = torch.empty_like(tensor) - - # PULL execution: read slices asynchronously over NVLink/UVA directly to local output - _get_ext().launch_pull_all_to_all(ptrs_tensor, out, chunk_bytes, world_size, rank) - - # Enqueue stream-ordered barrier: prevent overwriting the buffer in subsequent calls - # before peers have safely concluded pulling - hdl.barrier(channel=0) - - return out \ No newline at end of file diff --git a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/9_layernorm_backward_cuda.py b/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/9_layernorm_backward_cuda.py deleted file mode 100755 index ba2873d..0000000 --- a/solutions_cuda_bf16_h100_8_google_gemini-3-pro-preview/9_layernorm_backward_cuda.py +++ /dev/null @@ -1,328 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem -from utils.cuda_helpers import compile_cuda_extension - -CUDA_SRC = r''' -#include -#include -#include -#include -#include - -// ---------------------------------------------------------------------------- -// Type conversions mapping to float and float2 -// ---------------------------------------------------------------------------- - -template __device__ __forceinline__ float to_float(T v); -template <> __device__ __forceinline__ float to_float(__nv_bfloat16 v) { return __bfloat162float(v); } -template <> __device__ __forceinline__ float to_float(__half v) { return __half2float(v); } -template <> __device__ __forceinline__ float to_float(float v) { return v; } - -template __device__ __forceinline__ T from_float(float v); -template <> __device__ __forceinline__ __nv_bfloat16 from_float(float v) { return __float2bfloat16(v); } -template <> __device__ __forceinline__ __half from_float(float v) { return __float2half(v); } -template <> __device__ __forceinline__ float from_float(float v) { return v; } - -template __device__ __forceinline__ float2 cvt_float2(T2 v); -template <> __device__ __forceinline__ float2 cvt_float2<__nv_bfloat16, __nv_bfloat162>(__nv_bfloat162 v) { return __bfloat1622float2(v); } -template <> __device__ __forceinline__ float2 cvt_float2<__half, __half2>(__half2 v) { return __half22float2(v); } - -// ---------------------------------------------------------------------------- -// Kernel 1: Local Fused Reduction (Scalar Fallback) -// ---------------------------------------------------------------------------- -template -__global__ void local_reduce_scalar_kernel( - const T* __restrict__ X_hat, - const T* __restrict__ dY, - float* __restrict__ d_gamma_local, - float* __restrict__ d_beta_local, - int B, int H -) { - int h = blockIdx.x * blockDim.x + threadIdx.x; - if (h >= H) return; - - int b_chunk = (B + gridDim.y - 1) / gridDim.y; - int b_begin = blockIdx.y * b_chunk; - int b_end = b_begin + b_chunk; - if (b_end > B) b_end = B; - - float sg = 0.0f; - float sb = 0.0f; - - // Load elements natively, perform precision multiply-add in FP32 - for (int b = b_begin; b < b_end; ++b) { - float x = to_float(X_hat[b * H + h]); - float dy = to_float(dY[b * H + h]); - sg += dy * x; - sb += dy; - } - - // Safely scatter partial accumulation into the symmetric zeroed buffer - atomicAdd(&d_gamma_local[h], sg); - atomicAdd(&d_beta_local[h], sb); -} - -// ---------------------------------------------------------------------------- -// Kernel 1: Local Fused Reduction (Vectorized 8-element loads for BF16/FP16) -// ---------------------------------------------------------------------------- -template -__global__ void local_reduce_vec8_kernel( - const T* __restrict__ X_hat, - const T* __restrict__ dY, - float* __restrict__ d_gamma_local, - float* __restrict__ d_beta_local, - int B, int H -) { - int h_vec = blockIdx.x * blockDim.x + threadIdx.x; - int h_start = h_vec * 8; - if (h_start >= H) return; - - int b_chunk = (B + gridDim.y - 1) / gridDim.y; - int b_begin = blockIdx.y * b_chunk; - int b_end = b_begin + b_chunk; - if (b_end > B) b_end = B; - - float sg[8] = {0}; - float sb[8] = {0}; - - // Vectorized read mapping: 1x float4 grabs 16 bytes = 8x 16-bit elements - const float4* x_ptr = reinterpret_cast(X_hat); - const float4* dy_ptr = reinterpret_cast(dY); - int vec_H = H / 8; - - for (int b = b_begin; b < b_end; ++b) { - int idx = b * vec_H + h_vec; - float4 x_v = x_ptr[idx]; - float4 dy_v = dy_ptr[idx]; - - const T2* x_h2 = (const T2*)&x_v; - const T2* dy_h2 = (const T2*)&dy_v; - - #pragma unroll - for (int i = 0; i < 4; ++i) { - float2 x_f2 = cvt_float2(x_h2[i]); - float2 dy_f2 = cvt_float2(dy_h2[i]); - sg[i*2 + 0] += dy_f2.x * x_f2.x; - sg[i*2 + 1] += dy_f2.y * x_f2.y; - sb[i*2 + 0] += dy_f2.x; - sb[i*2 + 1] += dy_f2.y; - } - } - - // Scatter atomic adds blockwise chunk sums into unified global symmetric memory - for (int i = 0; i < 8; ++i) { - atomicAdd(&d_gamma_local[h_start + i], sg[i]); - atomicAdd(&d_beta_local[h_start + i], sb[i]); - } -} - -// ---------------------------------------------------------------------------- -// Kernel 2: NVLink Peer Pointers Cross-Rank Reduce -// ---------------------------------------------------------------------------- -template -__global__ void cross_rank_reduce_kernel( - const long long* __restrict__ ptrs, - T* __restrict__ out_gamma, - T* __restrict__ out_beta, - int world_size, - int H -) { - int h = blockIdx.x * blockDim.x + threadIdx.x; - if (h >= H) return; - - float sum_g = 0.0f; - float sum_b = 0.0f; - - #pragma unroll - for (int r = 0; r < world_size; ++r) { - const float* base = (const float*)ptrs[r]; - sum_g += base[h]; // First half of buffer is gamma - sum_b += base[H + h]; // Second half of buffer is beta - } - - out_gamma[h] = from_float(sum_g); - out_beta[h] = from_float(sum_b); -} - -// ---------------------------------------------------------------------------- -// C++ PyBind Dispatchers -// ---------------------------------------------------------------------------- -void launch_fused_layernorm_backward( - torch::Tensor X_hat, - torch::Tensor dY, - torch::Tensor buf, - int dtype_enum -) { - int B = X_hat.size(0); - int H = X_hat.size(1); - - // Spread along the batch dim if batch is large enough to saturate Hopper execution - int blocks_y = std::min(32, (B + 127) / 128); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - // 0 out symmetric partial accumulation buffer mapping before filling - cudaMemsetAsync(buf.data_ptr(), 0, 2 * H * sizeof(float), stream); - - float* d_gamma_local = buf.data_ptr(); - float* d_beta_local = buf.data_ptr() + H; - - if (dtype_enum == 0) { // bf16 - if (H % 8 == 0) { - dim3 threads(128); - dim3 blocks(H / 8 / 128 + (H / 8 % 128 != 0), blocks_y); - local_reduce_vec8_kernel<__nv_bfloat16, __nv_bfloat162><<>>( - reinterpret_cast(X_hat.data_ptr()), - reinterpret_cast(dY.data_ptr()), - d_gamma_local, d_beta_local, B, H - ); - } else { - dim3 threads(256); - dim3 blocks((H + 255) / 256, blocks_y); - local_reduce_scalar_kernel<__nv_bfloat16><<>>( - reinterpret_cast(X_hat.data_ptr()), - reinterpret_cast(dY.data_ptr()), - d_gamma_local, d_beta_local, B, H - ); - } - } else if (dtype_enum == 1) { // fp16 - if (H % 8 == 0) { - dim3 threads(128); - dim3 blocks(H / 8 / 128 + (H / 8 % 128 != 0), blocks_y); - local_reduce_vec8_kernel<__half, __half2><<>>( - reinterpret_cast(X_hat.data_ptr()), - reinterpret_cast(dY.data_ptr()), - d_gamma_local, d_beta_local, B, H - ); - } else { - dim3 threads(256); - dim3 blocks((H + 255) / 256, blocks_y); - local_reduce_scalar_kernel<__half><<>>( - reinterpret_cast(X_hat.data_ptr()), - reinterpret_cast(dY.data_ptr()), - d_gamma_local, d_beta_local, B, H - ); - } - } else { // fp32 - dim3 threads(256); - dim3 blocks((H + 255) / 256, blocks_y); - local_reduce_scalar_kernel<<>>( - X_hat.data_ptr(), dY.data_ptr(), - d_gamma_local, d_beta_local, B, H - ); - } -} - -void launch_cross_rank_reduce( - torch::Tensor ptrs, - torch::Tensor out_gamma, - torch::Tensor out_beta, - int world_size, - int H, - int dtype_enum -) { - int threads = 256; - int blocks = (H + threads - 1) / threads; - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - const long long* d_ptrs = (const long long*)ptrs.data_ptr(); - - if (dtype_enum == 0) { - cross_rank_reduce_kernel<__nv_bfloat16><<>>( - d_ptrs, - reinterpret_cast<__nv_bfloat16*>(out_gamma.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(out_beta.data_ptr()), - world_size, H - ); - } else if (dtype_enum == 1) { - cross_rank_reduce_kernel<__half><<>>( - d_ptrs, - reinterpret_cast<__half*>(out_gamma.data_ptr()), - reinterpret_cast<__half*>(out_beta.data_ptr()), - world_size, H - ); - } else { - cross_rank_reduce_kernel<<>>( - d_ptrs, out_gamma.data_ptr(), out_beta.data_ptr(), world_size, H - ); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_fused_layernorm_backward", &launch_fused_layernorm_backward, "Fused local layernorm backward over B"); - m.def("launch_cross_rank_reduce", &launch_cross_rank_reduce, "Cross-rank symmetric memory reduce"); -} -''' - -_ext = None -def _get_ext(): - global _ext - if _ext is None: - _ext = compile_cuda_extension("fused_layernorm_bw_ext", CUDA_SRC) - return _ext - -_symm_cache = {} - -def _get_resources(H: int, device: torch.device): - key = (H, device) - if key in _symm_cache: - return _symm_cache[key] - - # Pre-allocate a single FP32 symmetric buffer of size (2, H) to hold both gamma and beta local sums cleanly - buf = symm_mem.empty((2, H), device=device, dtype=torch.float32) - hdl = symm_mem.rendezvous(buf, dist.group.WORLD) - - # Create the tensor of UVA pointers holding symmetrical peering routes globally - ptrs_tensor = torch.tensor(hdl.buffer_ptrs, device=device, dtype=torch.int64) - - _symm_cache[key] = (buf, hdl, ptrs_tensor) - return buf, hdl, ptrs_tensor - -@torch.no_grad() -def solution( - X_hat: torch.Tensor, - dY: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - if not dist.is_initialized(): - d_beta = dY.sum(dim=0) - d_gamma = (dY * X_hat).sum(dim=0) - return d_gamma, d_beta - - B, H = X_hat.shape - dtype = X_hat.dtype - device = X_hat.device - - if dtype == torch.bfloat16: - dtype_enum = 0 - elif dtype == torch.float16: - dtype_enum = 1 - elif dtype == torch.float32: - dtype_enum = 2 - else: - # Graceful fallback for non-supported exotic dtypes - d_beta = dY.sum(dim=0) - d_gamma = (dY * X_hat).sum(dim=0) - dist.all_reduce(d_beta, op=dist.ReduceOp.SUM) - dist.all_reduce(d_gamma, op=dist.ReduceOp.SUM) - return d_gamma, d_beta - - buf, hdl, ptrs_tensor = _get_resources(H, device) - - if dist.get_rank() == 0: - _get_ext() - dist.barrier() - ext = _get_ext() - - # Phase 1: Local Fused Elementwise + Reduction into [2, H] Float32 Symmetric Memory Buffer - ext.launch_fused_layernorm_backward(X_hat.contiguous(), dY.contiguous(), buf, dtype_enum) - - # Device side sync acting natively on current stream ensures chunk stores have fully landed on all peers - hdl.barrier(channel=0) - - d_gamma = torch.empty(H, device=device, dtype=dtype) - d_beta = torch.empty(H, device=device, dtype=dtype) - - # Phase 2: Direct NVLink access cross-rank sum bypassing heavy NCCL dispatch - ext.launch_cross_rank_reduce(ptrs_tensor, d_gamma, d_beta, hdl.world_size, H, dtype_enum) - - return d_gamma, d_beta \ No newline at end of file